# Repository URL to install this package:
# Version: 1.23.2
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
"""
Benchmark performance of the SAM2 image encoder or image decoder with ORT or PyTorch. See benchmark_sam2.sh for usage.
"""
import argparse
import csv
import statistics
import time
from collections.abc import Mapping
from datetime import datetime
import torch
from image_decoder import SAM2ImageDecoder
from image_encoder import SAM2ImageEncoder
from sam2_utils import decoder_shape_dict, encoder_shape_dict, load_sam2_model
from onnxruntime import InferenceSession, SessionOptions, get_available_providers
from onnxruntime.transformers.io_binding_helper import CudaSession
class TestConfig:
    """Settings for a single SAM2 benchmark run of the image encoder or the image decoder.

    Every constructor argument is stored unchanged as an attribute of the same name. The
    constructor asserts that ``model_type`` is a known SAM2 variant, that ``height`` and
    ``width`` lie in [160, 4096], and that the image encoder is only used at 1024x1024.
    """

    def __init__(
        self,
        model_type: str,
        onnx_path: str,
        sam2_dir: str,
        device: torch.device,
        component: str = "image_encoder",
        provider="CPUExecutionProvider",
        torch_compile_mode="max-autotune",
        batch_size: int = 1,
        height: int = 1024,
        width: int = 1024,
        num_labels: int = 1,
        num_points: int = 1,
        num_masks: int = 1,
        multi_mask_output: bool = False,
        use_tf32: bool = True,
        enable_cuda_graph: bool = False,
        dtype=torch.float32,
        prefer_nhwc: bool = False,
        warm_up: int = 5,
        enable_nvtx_profile: bool = False,
        enable_ort_profile: bool = False,
        enable_torch_profile: bool = False,
        repeats: int = 1000,
        verbose: bool = False,
    ):
        known_models = ["sam2_hiera_tiny", "sam2_hiera_small", "sam2_hiera_large", "sam2_hiera_base_plus"]
        assert model_type in known_models
        assert 160 <= height <= 4096
        assert 160 <= width <= 4096

        # Assignment order is significant: __repr__ prints vars(self), which follows insertion order.
        # Model and file locations.
        self.model_type = model_type
        self.onnx_path = onnx_path
        self.sam2_dir = sam2_dir
        # What to run and with which backend settings.
        self.component = component
        self.provider = provider
        self.torch_compile_mode = torch_compile_mode
        # Input sizes.
        self.batch_size = batch_size
        self.height = height
        self.width = width
        self.num_labels = num_labels
        self.num_points = num_points
        self.num_masks = num_masks
        self.multi_mask_output = multi_mask_output
        # Device and numerics.
        self.device = device
        self.use_tf32 = use_tf32
        self.enable_cuda_graph = enable_cuda_graph
        self.dtype = dtype
        self.prefer_nhwc = prefer_nhwc
        # Measurement and profiling.
        self.warm_up = warm_up
        self.enable_nvtx_profile = enable_nvtx_profile
        self.enable_ort_profile = enable_ort_profile
        self.enable_torch_profile = enable_torch_profile
        self.repeats = repeats
        self.verbose = verbose

        is_encoder = self.component == "image_encoder"
        if is_encoder:
            assert (self.height, self.width) == (1024, 1024), "Only image size 1024x1024 is allowed for image encoder."

    def __repr__(self):
        """Return the instance attribute dictionary rendered as a string."""
        return str(vars(self))

    def shape_dict(self) -> Mapping[str, list[int]]:
        """Return the I/O shape mapping for the configured component (delegates to sam2_utils helpers)."""
        if self.component != "image_encoder":
            return decoder_shape_dict(self.height, self.width, self.num_labels, self.num_points, self.num_masks)
        return encoder_shape_dict(self.batch_size, self.height, self.width)

    def random_inputs(self) -> Mapping[str, torch.Tensor]:
        """Create random input tensors on ``self.device`` for the configured component.

        Encoder: a single ``image`` tensor of shape (batch_size, 3, height, width).
        Decoder: fixed-size image features/embeddings with batch 1, prompt tensors sized by
        ``num_labels``/``num_points``, and ``original_image_size`` = [height, width].
        ``point_labels`` comes from ``randint(0, 1)`` and is therefore all zeros.
        """
        float_kwargs = {"dtype": self.dtype, "device": self.device}
        if self.component == "image_encoder":
            image = torch.randn(self.batch_size, 3, self.height, self.width, **float_kwargs)
            return {"image": image}

        labels, points = self.num_labels, self.num_points
        # Tensors are created in this exact order so random number consumption matches the dict layout.
        inputs = {}
        inputs["image_features_0"] = torch.rand(1, 32, 256, 256, **float_kwargs)
        inputs["image_features_1"] = torch.rand(1, 64, 128, 128, **float_kwargs)
        inputs["image_embeddings"] = torch.rand(1, 256, 64, 64, **float_kwargs)
        inputs["point_coords"] = torch.randint(0, 1024, (labels, points, 2), **float_kwargs)
        inputs["point_labels"] = torch.randint(0, 1, (labels, points), dtype=torch.int32, device=self.device)
        inputs["input_masks"] = torch.zeros(labels, 1, 256, 256, **float_kwargs)
        inputs["has_input_masks"] = torch.ones(labels, **float_kwargs)
        inputs["original_image_size"] = torch.tensor([self.height, self.width], dtype=torch.int32, device=self.device)
        return inputs
def create_ort_session(config: TestConfig, session_options=None) -> InferenceSession:
    """Create an ONNX Runtime InferenceSession for ``config.onnx_path``.

    When ``config.provider`` is "CUDAExecutionProvider", CUDA provider options are built for the
    device in ``config.device``. TF32, CUDA graph and NHWC settings are taken from the config, and
    CPUExecutionProvider is appended as a fallback. Any other provider value yields a CPU-only session.

    Args:
        config: benchmark settings.
        session_options: optional onnxruntime SessionOptions, passed through unchanged.

    Returns:
        The created InferenceSession.
    """
    if config.verbose:
        print(f"create session for {vars(config)}")

    if config.provider == "CUDAExecutionProvider":
        # config.device may be a torch.device or a string such as "cuda" / "cuda:1". Normalize it and use
        # the current CUDA device only when no explicit index is given. Previously, a string device always
        # used the current device (ignoring "cuda:N"), and torch.device("cuda") produced device_id=None.
        device = config.device if isinstance(config.device, torch.device) else torch.device(config.device)
        device_id = torch.cuda.current_device() if device.index is None else device.index

        provider_options = CudaSession.get_cuda_provider_options(device_id, config.enable_cuda_graph)
        provider_options["use_tf32"] = int(config.use_tf32)
        if config.prefer_nhwc:
            provider_options["prefer_nhwc"] = 1
        providers = [(config.provider, provider_options), "CPUExecutionProvider"]
    else:
        providers = ["CPUExecutionProvider"]

    ort_session = InferenceSession(config.onnx_path, session_options, providers=providers)
    return ort_session
def create_session(config: TestConfig, session_options=None) -> CudaSession:
    """Build an ORT session for ``config`` and wrap it in a CudaSession with pre-allocated I/O buffers.

    The wrapper is bound to ``config.device`` and ``config.enable_cuda_graph``. Its buffers are
    allocated from ``config.shape_dict()``.
    """
    wrapped = CudaSession(
        create_ort_session(config, session_options),
        config.device,
        config.enable_cuda_graph,
    )
    wrapped.allocate_buffers(config.shape_dict())
    return wrapped
class OrtTestSession:
    """A wrapper of ORT session to test relevance and performance.

    On construction it creates a CudaSession for ``config`` and one set of random inputs,
    which every ``infer()`` call reuses.
    """

    def __init__(self, config: TestConfig, session_options=None):
        # The session is created before the inputs, as before, so random state is consumed in the same order.
        self.ort_session = create_session(config, session_options)
        self.feed_dict = config.random_inputs()

    def infer(self):
        """Run the wrapped session once on the stored inputs and return its outputs."""
        session, feeds = self.ort_session, self.feed_dict
        return session.infer(feeds)
def measure_latency(cuda_session: "CudaSession", input_dict):
    """Return the elapsed time, in seconds, of one ``cuda_session.infer(input_dict)`` call.

    The call's outputs are discarded. ``time.perf_counter`` is used instead of ``time.time``
    because it is monotonic and high-resolution. ``time.time`` can jump with clock adjustments
    and has coarse resolution on some platforms, which distorts millisecond-scale latencies.

    NOTE(review): on GPU this only covers device work if ``infer`` waits for completion — confirm
    CudaSession.infer synchronizes before returning.
    """
    start = time.perf_counter()
    cuda_session.infer(input_dict)
    return time.perf_counter() - start
def _nvtx_profile_once(run_once, message: str) -> None:
    """Run ``run_once`` a single time between cudaProfilerStart/Stop inside an nvtx "one_run" range."""
    import nvtx  # noqa: PLC0415
    from cuda import cudart  # noqa: PLC0415

    cudart.cudaProfilerStart()
    print(message)
    with nvtx.annotate("one_run"):
        run_once()
    cudart.cudaProfilerStop()


def _torch_profile_once(run_once, message: str, record_name: str, trace_path: str) -> None:
    """Profile one ``run_once`` call with torch.profiler (CPU + CUDA), print the top-10 table and save a Chrome trace."""
    with torch.profiler.profile(
        activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
        record_shapes=True,
    ) as prof:
        print(message)
        with torch.profiler.record_function(record_name):
            run_once()
    # Results are only available once the profiler context has exited.
    print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
    prof.export_chrome_trace(trace_path)


def _mean_latency(run_once, repeats: int, synchronize: bool) -> float:
    """Call ``run_once`` ``repeats`` times and return the mean number of seconds per call.

    With ``synchronize`` set, ``torch.cuda.synchronize()`` runs after every call so queued GPU
    work is included in each measured iteration.
    """
    start = time.perf_counter()
    for _ in range(repeats):
        run_once()
        if synchronize:
            torch.cuda.synchronize()
    return (time.perf_counter() - start) / repeats


def run_torch(config: TestConfig):
    """Benchmark the PyTorch SAM2 image encoder or image decoder described by ``config``.

    Steps:
    1. Load the SAM2 model, optionally compiling it with ``torch.compile`` (CUDA only).
    2. Run ``config.warm_up`` warm-up iterations.
    3. Optionally run one nvtx-profiled and/or one torch-profiled iteration (CUDA only).
    4. Time ``config.repeats`` iterations.

    Everything runs under ``torch.inference_mode()``. Autocast to ``config.dtype`` is enabled on
    CUDA when the dtype is not float32.

    Returns:
        Average seconds per iteration, or None when ``config.repeats`` is 0.
    """
    device_type = config.device.type
    is_cuda = device_type == "cuda"
    use_compile = is_cuda and config.torch_compile_mode != "none"

    # Turn on TF32 for Ampere (compute capability >= 8) GPUs, which could help when data type is float32.
    # Query the configured device rather than always device 0.
    if is_cuda and config.use_tf32 and torch.cuda.get_device_properties(config.device).major >= 8:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

    enabled_auto_cast = is_cuda and config.dtype != torch.float32

    with torch.inference_mode(), torch.autocast(device_type=device_type, dtype=config.dtype, enabled=enabled_auto_cast):
        sam2_model = load_sam2_model(config.sam2_dir, config.model_type, device=config.device)

        if config.component == "image_encoder":
            if use_compile:
                sam2_model.image_encoder.forward = torch.compile(
                    sam2_model.image_encoder.forward,
                    mode=config.torch_compile_mode,  # "reduce-overhead" if you want to reduce latency of first run.
                    fullgraph=True,
                    dynamic=False,
                )

            image_shape = config.shape_dict()["image"]
            img = torch.randn(image_shape).to(device=config.device, dtype=config.dtype)
            sam2_encoder = SAM2ImageEncoder(sam2_model)

            def run_once():
                _image_features_0, _image_features_1, _image_embeddings = sam2_encoder(img)

            if use_compile:
                print(f"Running warm up. It will take a while since torch compile mode is {config.torch_compile_mode}.")
            for _ in range(config.warm_up):
                run_once()

            if is_cuda and config.enable_nvtx_profile:
                _nvtx_profile_once(
                    lambda: sam2_encoder(img, enable_nvtx_profile=True),
                    "Start nvtx profiling on encoder ...",
                )
            if is_cuda and config.enable_torch_profile:
                _torch_profile_once(
                    run_once, "Start torch profiling on encoder ...", "encoder", "torch_image_encoder.json"
                )
        else:
            # Random decoder inputs are only needed on this path, so they are created here rather than up front.
            ort_inputs = config.random_inputs()
            decoder_input_names = (
                "image_features_0",
                "image_features_1",
                "image_embeddings",
                "point_coords",
                "point_labels",
                "input_masks",
                "has_input_masks",
                "original_image_size",
            )
            torch_inputs = tuple(ort_inputs[name] for name in decoder_input_names)

            sam2_decoder = SAM2ImageDecoder(
                sam2_model,
                multimask_output=config.multi_mask_output,
            )
            if use_compile:
                sam2_decoder.forward = torch.compile(
                    sam2_decoder.forward,
                    mode=config.torch_compile_mode,
                    fullgraph=True,
                    dynamic=False,
                )

            def run_once():
                _masks, _iou_predictions, _low_res_masks = sam2_decoder(*torch_inputs)

            # warm up
            for _ in range(config.warm_up):
                run_once()

            if is_cuda and config.enable_nvtx_profile:
                _nvtx_profile_once(
                    lambda: sam2_decoder(*torch_inputs, enable_nvtx_profile=True),
                    "Start nvtx profiling on decoder...",
                )
            if is_cuda and config.enable_torch_profile:
                _torch_profile_once(
                    run_once, "Start torch profiling on decoder ...", "decoder", "torch_image_decoder.json"
                )

        if config.repeats == 0:
            return None

        print(f"Start {config.repeats} runs of performance tests...")
        return _mean_latency(run_once, config.repeats, synchronize=is_cuda)
def _run_ort(config: TestConfig, intra_op_num_threads: int) -> float | None:
    """Benchmark ``config.onnx_path`` with ONNX Runtime and return the mean latency in seconds.

    Returns None when warm-up raises (the error is printed) or when ``config.repeats`` is 0.
    """
    sess_options = SessionOptions()
    sess_options.intra_op_num_threads = intra_op_num_threads
    if config.enable_ort_profile:
        sess_options.enable_profiling = True
        sess_options.log_severity_level = 4
        sess_options.log_verbosity_level = 0

    session = create_session(config, sess_options)
    input_dict = config.random_inputs()

    # warm up session
    try:
        for _ in range(config.warm_up):
            _ = measure_latency(session, input_dict)
    except Exception as e:
        print(f"Failed to run {config=}. Exception: {e}")
        return None

    if config.enable_nvtx_profile:
        import nvtx  # noqa: PLC0415
        from cuda import cudart  # noqa: PLC0415

        cudart.cudaProfilerStart()
        with nvtx.annotate("one_run"):
            _ = session.infer(input_dict)
        cudart.cudaProfilerStop()

    # Profiling covers only the warm-up and nvtx runs above, not the timed runs below.
    if config.enable_ort_profile:
        session.ort_session.end_profiling()

    if config.repeats == 0:
        return None

    latency_list = [measure_latency(session, input_dict) for _ in range(config.repeats)]
    average_latency = statistics.mean(latency_list)
    del session
    return average_latency


def run_test(
    args: argparse.Namespace,
    csv_writer: csv.DictWriter | None = None,
):
    """Run one benchmark configured by the parsed command-line ``args``.

    Builds a TestConfig, benchmarks with ONNX Runtime (``args.engine == "ort"``) or PyTorch, then
    prints the config and a result row and writes the row to ``csv_writer`` when one is given.
    Nothing is recorded when ``args.repeats`` is 0 or when the benchmark fails (the error is printed).
    """
    use_gpu: bool = args.use_gpu
    enable_cuda_graph: bool = args.use_cuda_graph
    repeats: int = args.repeats

    if use_gpu:
        device = torch.device("cuda", torch.cuda.current_device())
        provider = "CUDAExecutionProvider"
    else:
        device = torch.device("cpu")
        enable_cuda_graph = False  # CUDA graphs are only meaningful with the CUDA provider.
        provider = "CPUExecutionProvider"

    dtypes = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}
    config = TestConfig(
        model_type=args.model_type,
        onnx_path=args.onnx_path,
        sam2_dir=args.sam2_dir,
        component=args.component,
        provider=provider,
        batch_size=args.batch_size,
        height=args.height,
        width=args.width,
        # Previously omitted: the torch decoder always ran with multimask_output=False even though the
        # CSV row reported args.multimask_output.
        multi_mask_output=args.multimask_output,
        device=device,
        use_tf32=True,
        enable_cuda_graph=enable_cuda_graph,
        dtype=dtypes[args.dtype],
        prefer_nhwc=args.prefer_nhwc,
        repeats=args.repeats,
        warm_up=args.warm_up,
        enable_nvtx_profile=args.enable_nvtx_profile,
        enable_ort_profile=args.enable_ort_profile,
        enable_torch_profile=args.enable_torch_profile,
        torch_compile_mode=args.torch_compile_mode,
        verbose=False,
    )

    if args.engine == "ort":
        average_latency = _run_ort(config, args.intra_op_num_threads)
    else:  # torch
        with torch.no_grad():
            try:
                average_latency = run_torch(config)
            except Exception as e:
                print(f"Failed to run {config=}. Exception: {e}")
                return

    # None means the run failed or no timed runs were requested.
    if repeats == 0 or average_latency is None:
        return

    engine = args.engine + ":" + ("cuda" if use_gpu else "cpu")
    row = {
        "model_type": args.model_type,
        "component": args.component,
        "dtype": args.dtype,
        "use_gpu": use_gpu,
        "enable_cuda_graph": enable_cuda_graph,
        "prefer_nhwc": config.prefer_nhwc,
        "use_tf32": config.use_tf32,
        "batch_size": args.batch_size,
        "height": args.height,
        "width": args.width,
        "multi_mask_output": args.multimask_output,
        "num_labels": config.num_labels,
        "num_points": config.num_points,
        "num_masks": config.num_masks,
        "intra_op_num_threads": args.intra_op_num_threads,
        "warm_up": config.warm_up,
        "repeats": repeats,
        "enable_nvtx_profile": args.enable_nvtx_profile,
        "torch_compile_mode": args.torch_compile_mode,
        "engine": engine,
        "average_latency": average_latency,
    }
    if csv_writer is not None:
        csv_writer.writerow(row)

    print(f"{vars(config)}")
    print(f"{row}")
def run_perf_test(args):
    """Run one benchmark and record its result in a new timestamped CSV file in the working directory.

    The file is named ``benchmark_sam_<cpu|gpu>_<engine>_<YYYYmmdd-HHMMSS>.csv``. It gets a header
    row, followed by the row written by ``run_test`` (if any).
    """
    features = "gpu" if args.use_gpu else "cpu"
    timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
    csv_filename = f"benchmark_sam_{features}_{args.engine}_{timestamp}.csv"

    # Column order of the CSV; must cover every key of the row built by run_test.
    column_names = [
        "model_type",
        "component",
        "dtype",
        "use_gpu",
        "enable_cuda_graph",
        "prefer_nhwc",
        "use_tf32",
        "batch_size",
        "height",
        "width",
        "multi_mask_output",
        "num_labels",
        "num_points",
        "num_masks",
        "intra_op_num_threads",
        "warm_up",
        "repeats",
        "enable_nvtx_profile",
        "torch_compile_mode",
        "engine",
        "average_latency",
    ]

    with open(csv_filename, mode="a", newline="") as csv_file:
        writer = csv.DictWriter(csv_file, fieldnames=column_names)
        writer.writeheader()
        run_test(args, writer)
def _parse_arguments(argv=None):
    """Parse command-line options for the SAM2 benchmark.

    Args:
        argv: argument strings to parse, without the program name. ``None`` (the default) parses
            ``sys.argv[1:]``, as before.

    Returns:
        argparse.Namespace with one attribute per option. ``torch_compile_mode`` defaults to None so
        the caller can pick a component-specific default.
    """
    parser = argparse.ArgumentParser(description="Benchmark SAM2 for ONNX Runtime and PyTorch.")

    parser.add_argument(
        "--component",
        required=False,
        choices=["image_encoder", "image_decoder"],
        default="image_encoder",
        help="component to benchmark. Choices are image_encoder and image_decoder.",
    )
    parser.add_argument(
        "--dtype", required=False, choices=["fp32", "fp16", "bf16"], default="fp32", help="Data type for inference."
    )
    parser.add_argument(
        "--use_gpu",
        required=False,
        default=False,
        action="store_true",
        help="Use GPU for inference.",
    )
    parser.add_argument(
        "--use_cuda_graph",
        required=False,
        default=False,
        action="store_true",
        help="Use cuda graph in onnxruntime.",
    )
    parser.add_argument(
        "--intra_op_num_threads",
        required=False,
        type=int,
        choices=[0, 1, 2, 4, 8, 16],
        default=0,
        help="intra_op_num_threads for onnxruntime.",
    )
    parser.add_argument(
        "--batch_size",
        required=False,
        type=int,
        default=1,
        help="batch size",
    )
    parser.add_argument(
        "--height",
        required=False,
        type=int,
        default=1024,
        help="image height",
    )
    parser.add_argument(
        "--width",
        required=False,
        type=int,
        default=1024,
        help="image width",
    )
    parser.add_argument(
        "--repeats",
        required=False,
        type=int,
        default=1000,
        help="number of repeats for performance test. Default is 1000.",
    )
    parser.add_argument(
        "--warm_up",
        required=False,
        type=int,
        default=5,
        help="number of runs for warm up. Default is 5.",
    )
    parser.add_argument(
        "--engine",
        required=False,
        type=str,
        default="ort",
        choices=["ort", "torch"],
        help="engine for inference",
    )
    parser.add_argument(
        "--multimask_output",
        required=False,
        default=False,
        action="store_true",
        help="Benchmark image_decoder with multimask_output.",
    )
    parser.add_argument(
        "--prefer_nhwc",
        required=False,
        default=False,
        action="store_true",
        help="Use prefer_nhwc=1 provider option for CUDAExecutionProvider",
    )
    parser.add_argument(
        "--enable_nvtx_profile",
        required=False,
        default=False,
        action="store_true",
        help="Enable nvtx profiling. It will add an extra run for profiling before performance test.",
    )
    parser.add_argument(
        "--enable_ort_profile",
        required=False,
        default=False,
        action="store_true",
        help="Enable ORT profiling.",
    )
    parser.add_argument(
        "--enable_torch_profile",
        required=False,
        default=False,
        action="store_true",
        help="Enable PyTorch profiling. It will add an extra run for profiling before performance test.",
    )
    parser.add_argument(
        "--model_type",
        required=False,
        type=str,
        default="sam2_hiera_large",
        choices=["sam2_hiera_tiny", "sam2_hiera_small", "sam2_hiera_large", "sam2_hiera_base_plus"],
        help="sam2 model name",
    )
    parser.add_argument(
        "--sam2_dir",
        required=False,
        type=str,
        default="./segment-anything-2",
        help="The directory of segment-anything-2 git root directory",
    )
    parser.add_argument(
        "--onnx_path",
        required=False,
        type=str,
        default="./sam2_onnx_models/sam2_hiera_large_image_encoder.onnx",
        help="path of onnx model",
    )
    parser.add_argument(
        "--torch_compile_mode",
        required=False,
        type=str,
        default=None,
        choices=["reduce-overhead", "max-autotune", "max-autotune-no-cudagraphs", "none"],
        help="torch compile mode. none will disable torch compile.",
    )

    return parser.parse_args(argv)
if __name__ == "__main__":
    args = _parse_arguments()

    if args.torch_compile_mode is None:
        # image decoder will fail with compile modes other than "none".
        args.torch_compile_mode = "max-autotune" if args.component == "image_encoder" else "none"

    # Printed after the default is resolved so the output shows the compile mode actually used.
    print(f"arguments:{args}")

    # Explicit checks instead of `assert`, which is stripped when Python runs with -O.
    if args.use_gpu:
        if not torch.cuda.is_available():
            raise RuntimeError("--use_gpu was given but torch.cuda.is_available() is False.")
        if args.engine == "ort":
            if "CUDAExecutionProvider" not in get_available_providers():
                raise RuntimeError("--use_gpu with --engine ort requires CUDAExecutionProvider in onnxruntime.")
            # PyTorch profiling does not apply to the ONNX Runtime engine.
            args.enable_torch_profile = False
    elif args.enable_nvtx_profile or args.enable_torch_profile:
        # Only support cuda profiling for now.
        raise ValueError("--enable_nvtx_profile and --enable_torch_profile require --use_gpu.")

    # Profiling runs are not recorded to CSV; plain benchmarks are.
    if args.enable_nvtx_profile or args.enable_torch_profile:
        run_test(args)
    else:
        run_perf_test(args)