# Repository URL to install this package:
# Version: 1.23.2
# -------------------------------------------------------------------------
# Copyright (R) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import nvtx
from cuda import cudart
class NvtxHelper:
    """Annotate pipeline stages with NVTX ranges and time them with CUDA events.

    For each stage name given to the constructor, a pair of CUDA events
    (``"<stage>-start"`` and ``"<stage>-stop"``) is created. ``start_profile``
    and ``stop_profile`` open/close an NVTX range for a stage and record the
    matching event on stream 0. ``print_latency`` reports the elapsed time
    between each stage's start and stop events.

    Args:
        stages: Iterable of stage names to time. It is copied into a list so
            a one-shot iterator (e.g. a generator) is not exhausted before
            ``print_latency`` walks it again.
    """

    def __init__(self, stages):
        # Copy: events are created for exactly these stages, and print_latency
        # iterates them a second time later.
        self.stages = list(stages)
        # cudaEventCreate returns an (error, event) tuple; keep the event handle.
        self.events = {
            f"{stage}-{marker}": cudart.cudaEventCreate()[1]
            for stage in self.stages
            for marker in ("start", "stop")
        }
        # stage name -> NVTX range id returned by nvtx.start_range.
        self.markers = {}

    def start_profile(self, stage, color="blue"):
        """Open an NVTX range for ``stage`` and record its start event.

        A stage not passed to the constructor still gets an NVTX range, but
        no CUDA event is recorded for it.

        Args:
            stage: Stage name; also used as the NVTX range message.
            color: NVTX range color, forwarded to ``nvtx.start_range``.
        """
        self.markers[stage] = nvtx.start_range(message=stage, color=color)
        event_name = f"{stage}-start"
        if event_name in self.events:
            # Enqueue the event on stream handle 0 (the default stream).
            cudart.cudaEventRecord(self.events[event_name], 0)

    def stop_profile(self, stage):
        """Record the stop event for ``stage`` and close its NVTX range.

        Args:
            stage: Stage name previously passed to ``start_profile``.

        Raises:
            KeyError: if no range is open for ``stage`` (``start_profile`` was
                not called for it, or it has already been stopped).
        """
        # Pop first: a missing or already-closed range fails before any event
        # is recorded, and a stale range id is never ended twice.
        range_id = self.markers.pop(stage)
        event_name = f"{stage}-stop"
        if event_name in self.events:
            cudart.cudaEventRecord(self.events[event_name], 0)
        nvtx.end_range(range_id)

    def print_latency(self):
        """Print the GPU time elapsed between each stage's start and stop events.

        Blocks until each stage's stop event has completed, because
        ``cudaEventElapsedTime`` returns ``cudaErrorNotReady`` while the
        recorded work is still in flight. If the elapsed time cannot be read
        (e.g. one of the stage's events was never recorded), the stage is
        reported as unavailable with the CUDA error instead of printing a
        meaningless number.
        """
        for stage in self.stages:
            start = self.events[f"{stage}-start"]
            stop = self.events[f"{stage}-stop"]
            cudart.cudaEventSynchronize(stop)
            err, latency = cudart.cudaEventElapsedTime(start, stop)
            if err != cudart.cudaError_t.cudaSuccess:
                print(f"{stage}: latency unavailable ({err})")
                continue
            # cudaEventElapsedTime reports milliseconds.
            print(f"{stage}: {latency:.2f} ms")