@@ -155,6 +155,19 @@ struct OpEventData {
155
155
torch::profiler::impl::CUDAEventStub cuda_event_end_ = nullptr ;
156
156
};
157
157
158
+ struct MemoryEventData {
159
+ int64_t start_time;
160
+ void * ptr;
161
+ int64_t alloc_size;
162
+ int64_t total_allocated;
163
+ int64_t total_reserved;
164
+ uint64_t threadID;
165
+ torch::profiler::impl::kineto::DeviceAndResource kineto_info;
166
+ c10::DeviceType device_type;
167
+ c10::DeviceIndex device_index;
168
+ };
169
+ static_assert (std::is_pod<MemoryEventData>::value, " Non-POD member of MemoryEventData." );
170
+
158
171
// Assumption: Total threads number will not exceed 2^16-1, and total ops will
159
172
// not exceed 2^48 -1.
160
173
static inline uint64_t getForwardThreadKey (uint64_t tid, uint64_t seqNr) {
@@ -204,29 +217,16 @@ struct KinetoThreadLocalState : public ProfilerThreadLocalStateBase {
204
217
int64_t total_reserved,
205
218
c10::Device device) override {
206
219
if (config_.profile_memory && config_.state != ProfilerState::Disabled) {
207
- std::lock_guard<std::mutex> guard (state_mutex_);
208
- auto start_time = getTimeUs ();
209
- if (cpu_trace_) {
210
- torch::profiler::impl::kineto::recordThreadInfo ();
211
- cpu_trace_.addMemoryUsageActivity (
212
- kMemoryEventName ,
213
- torch::profiler::impl::kineto::kineto_ids (),
214
- start_time,
215
- device,
216
- ptr,
217
- alloc_size,
218
- total_allocated,
219
- total_reserved);
220
- }
221
-
222
- kineto_events_.emplace_back ();
223
- auto & evt = kineto_events_.back ();
224
- evt.name (kMemoryEventName )
225
- .startUs (start_time)
226
- .deviceIndex (device.index ())
227
- .deviceType (device.type ())
228
- .nBytes (alloc_size)
229
- .startThreadId (at::RecordFunction::currentThreadId ());
220
+ memory_events_.push_back (
221
+ {getTimeUs (),
222
+ ptr,
223
+ alloc_size,
224
+ total_allocated,
225
+ total_reserved,
226
+ at::RecordFunction::currentThreadId (),
227
+ torch::profiler::impl::kineto::kineto_ids (),
228
+ device.type (),
229
+ device.index ()});
230
230
}
231
231
}
232
232
@@ -264,6 +264,28 @@ struct KinetoThreadLocalState : public ProfilerThreadLocalStateBase {
264
264
265
265
void materializeOpEvents () {
266
266
std::lock_guard<std::mutex> guard (state_mutex_);
267
+
268
+ for (const auto & e : memory_events_) {
269
+ cpu_trace_.addMemoryUsageActivity (
270
+ kMemoryEventName ,
271
+ e.kineto_info ,
272
+ e.start_time ,
273
+ c10::Device (e.device_type , e.device_index ),
274
+ e.ptr ,
275
+ e.alloc_size ,
276
+ e.total_allocated ,
277
+ e.total_reserved );
278
+
279
+ kineto_events_.emplace_back ();
280
+ auto & evt = kineto_events_.back ();
281
+ evt.name (kMemoryEventName )
282
+ .startUs (e.start_time )
283
+ .deviceIndex (e.device_index )
284
+ .deviceType (e.device_type )
285
+ .nBytes (e.alloc_size )
286
+ .startThreadId (e.threadID );
287
+ }
288
+
267
289
for (const auto & e : op_events_) {
268
290
if (e.end_us_ < e.start_us_ ) {
269
291
// We initialize end_us_ to the smallest int64_t, so this means that
@@ -585,6 +607,7 @@ struct KinetoThreadLocalState : public ProfilerThreadLocalStateBase {
585
607
uint64_t start_time_;
586
608
std::set<torch::profiler::impl::ActivityType> activities_;
587
609
std::deque<OpEventData> op_events_;
610
+ std::deque<MemoryEventData> memory_events_;
588
611
torch::profiler::impl::kineto::TraceWrapper cpu_trace_;
589
612
std::vector<KinetoEvent> kineto_events_;
590
613
// Optional, if event post-processing is enabled.
0 commit comments