Commit 9f541aa

Taylor Robie authored and pytorchmergebot committed
[Profiler] Optimize reportMemoryUsage (#71538)
Summary:
Pull Request resolved: #71538

`reportMemoryUsage` is kind of awful. It does a bunch of string writes and such that make it VERY expensive. Just moving that work off the hot path reduces the overhead for `profile_memory` from ~6.5 us to ~1.2 us (an 85% reduction in the Kineto contribution to profiling overhead).

Test Plan: Ran ubenchmark with `--op empty --stressTestKineto --kinetoProfileMemory`

Reviewed By: swolchok

Differential Revision: D32730167

fbshipit-source-id: fe18e8fa3881967cad8fa1c26c71c805e9b034e5
(cherry picked from commit 0d394cb)
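Context for the diff that follows: the fix is a standard deferred-logging split. The allocator callback, which sits on the hot path, now only appends a small POD record under the lock; the string formatting and Kineto activity construction run once, later, when the trace is materialized. Below is a minimal self-contained sketch of that pattern. The names (`Record`, `EventSink`, `report`, `materialize`) are hypothetical, not PyTorch APIs.

#include <chrono>
#include <cstdint>
#include <deque>
#include <iostream>
#include <mutex>

// Minimal sketch of the hot-path / materialize split this commit applies.
struct Record {              // stands in for MemoryEventData
  int64_t start_time_us;
  const void* ptr;
  int64_t alloc_size;
};

class EventSink {
 public:
  // Hot path: one lock, one push_back of a small POD. No string work here.
  void report(const void* ptr, int64_t alloc_size) {
    std::lock_guard<std::mutex> guard(mutex_);
    records_.push_back({nowUs(), ptr, alloc_size});
  }

  // Cold path: runs once when the profile is finalized; all expensive
  // formatting and downstream-API calls are deferred to this point.
  void materialize() {
    std::lock_guard<std::mutex> guard(mutex_);
    for (const auto& r : records_) {
      std::cout << "[memory] t=" << r.start_time_us << "us ptr=" << r.ptr
                << " size=" << r.alloc_size << "\n";
    }
    records_.clear();
  }

 private:
  static int64_t nowUs() {
    return std::chrono::duration_cast<std::chrono::microseconds>(
               std::chrono::steady_clock::now().time_since_epoch())
        .count();
  }

  std::mutex mutex_;
  std::deque<Record> records_;  // deque: stable growth, no reallocation copies
};

int main() {
  EventSink sink;
  int x = 0;
  sink.report(&x, sizeof(x));  // cheap while the workload runs
  sink.materialize();          // expensive work happens after profiling ends
}

The trade-off is memory: deferred records accumulate until materialization, which is acceptable for a profiler whose trace is already bounded by the profiling window.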
1 parent 24c91e2 commit 9f541aa

1 file changed

torch/csrc/autograd/profiler_kineto.cpp

Lines changed: 46 additions & 23 deletions
@@ -155,6 +155,19 @@ struct OpEventData {
   torch::profiler::impl::CUDAEventStub cuda_event_end_ = nullptr;
 };
 
+struct MemoryEventData {
+  int64_t start_time;
+  void* ptr;
+  int64_t alloc_size;
+  int64_t total_allocated;
+  int64_t total_reserved;
+  uint64_t threadID;
+  torch::profiler::impl::kineto::DeviceAndResource kineto_info;
+  c10::DeviceType device_type;
+  c10::DeviceIndex device_index;
+};
+static_assert(std::is_pod<MemoryEventData>::value, "Non-POD member of MemoryEventData.");
+
 // Assumption: Total threads number will not exceed 2^16-1, and total ops will
 // not exceed 2^48 -1.
 static inline uint64_t getForwardThreadKey(uint64_t tid, uint64_t seqNr) {
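A note on the `static_assert` in this hunk: it is what keeps the hot-path write cheap. A POD record has no constructors, destructors, or hidden allocations, so copying it into the deque is essentially a memcpy. A standalone illustration of the guard; the `CheapSample` and `CostlySample` types are hypothetical, not from the codebase.

#include <cstdint>
#include <string>
#include <type_traits>

struct CheapSample {   // same flavor as MemoryEventData: all trivial members
  int64_t time_us;
  const void* ptr;
};

struct CostlySample {  // adding a std::string member would break the guarantee
  int64_t time_us;
  std::string name;
};

static_assert(std::is_pod<CheapSample>::value,
              "stays trivially copyable on the hot path");
static_assert(!std::is_pod<CostlySample>::value,
              "std::string makes the record non-POD");

int main() {}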
@@ -204,29 +217,16 @@ struct KinetoThreadLocalState : public ProfilerThreadLocalStateBase {
       int64_t total_reserved,
       c10::Device device) override {
     if (config_.profile_memory && config_.state != ProfilerState::Disabled) {
-      std::lock_guard<std::mutex> guard(state_mutex_);
-      auto start_time = getTimeUs();
-      if (cpu_trace_) {
-        torch::profiler::impl::kineto::recordThreadInfo();
-        cpu_trace_.addMemoryUsageActivity(
-            kMemoryEventName,
-            torch::profiler::impl::kineto::kineto_ids(),
-            start_time,
-            device,
-            ptr,
-            alloc_size,
-            total_allocated,
-            total_reserved);
-      }
-
-      kineto_events_.emplace_back();
-      auto& evt = kineto_events_.back();
-      evt.name(kMemoryEventName)
-          .startUs(start_time)
-          .deviceIndex(device.index())
-          .deviceType(device.type())
-          .nBytes(alloc_size)
-          .startThreadId(at::RecordFunction::currentThreadId());
+      memory_events_.push_back(
+          {getTimeUs(),
+           ptr,
+           alloc_size,
+           total_allocated,
+           total_reserved,
+           at::RecordFunction::currentThreadId(),
+           torch::profiler::impl::kineto::kineto_ids(),
+           device.type(),
+           device.index()});
     }
   }
 
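One subtlety in the new `push_back`: the braced initializer is matched positionally against `MemoryEventData`'s members, so the argument order must track the declaration order exactly. A toy illustration; the `Event` type is hypothetical.

#include <cstdint>

struct Event {
  int64_t time_us;
  int64_t size_bytes;
};

int main() {
  // Fields are matched by position, not by name: time_us = 7, size_bytes = 42.
  // Swapping the two arguments would silently swap the fields, since both
  // members have the same type; the order must follow the declaration.
  Event e{7, 42};
  return (e.time_us == 7 && e.size_bytes == 42) ? 0 : 1;
}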

@@ -264,6 +264,28 @@ struct KinetoThreadLocalState : public ProfilerThreadLocalStateBase {
 
   void materializeOpEvents() {
     std::lock_guard<std::mutex> guard(state_mutex_);
+
+    for (const auto& e : memory_events_) {
+      cpu_trace_.addMemoryUsageActivity(
+          kMemoryEventName,
+          e.kineto_info,
+          e.start_time,
+          c10::Device(e.device_type, e.device_index),
+          e.ptr,
+          e.alloc_size,
+          e.total_allocated,
+          e.total_reserved);
+
+      kineto_events_.emplace_back();
+      auto& evt = kineto_events_.back();
+      evt.name(kMemoryEventName)
+          .startUs(e.start_time)
+          .deviceIndex(e.device_index)
+          .deviceType(e.device_type)
+          .nBytes(e.alloc_size)
+          .startThreadId(e.threadID);
+    }
+
     for (const auto& e : op_events_) {
       if (e.end_us_ < e.start_us_) {
         // We initialize end_us_ to the smallest int64_t, so this means that
@@ -585,6 +607,7 @@ struct KinetoThreadLocalState : public ProfilerThreadLocalStateBase {
   uint64_t start_time_;
   std::set<torch::profiler::impl::ActivityType> activities_;
   std::deque<OpEventData> op_events_;
+  std::deque<MemoryEventData> memory_events_;
   torch::profiler::impl::kineto::TraceWrapper cpu_trace_;
   std::vector<KinetoEvent> kineto_events_;
   // Optional, if event post-processing is enabled.
