# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# -------------------------------------------------------------------------
import logging
import warnings
import torch
from sam2.modeling.sam2_base import SAM2Base
from sam2_utils import compare_tensors_with_tolerance, random_sam2_input_image
from torch import nn
import onnxruntime
logger = logging.getLogger(__name__)
class SAM2ImageEncoder(nn.Module):
def __init__(self, sam_model: SAM2Base) -> None:
super().__init__()
self.model = sam_model
self.image_encoder = sam_model.image_encoder
self.no_mem_embed = sam_model.no_mem_embed
def forward(
self,
image: torch.Tensor,
enable_nvtx_profile: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Encodes images into features.
Only supports H=W=1024. If you want to use different image sizes like 512x512,
see https://github.com/facebookresearch/segment-anything-2/issues/138.
Args:
image (torch.Tensor): images of shape [B, 3, H, W], B is batch size, H and W are height and width.
enable_nvtx_profile (bool): enable NVTX profiling.
Returns:
image_features_0: image features of shape [B, 32, H/4, W/4] - high resolution features of level 0
image_features_1: image features of shape [B, 64, H/8, W/8] - high resolution features of level 1
image_embeddings: image features of shape [B, 256, H/16, W/16] - 16 is the backbone_stride
"""
nvtx_helper = None
if enable_nvtx_profile:
from nvtx_helper import NvtxHelper # noqa: PLC0415
nvtx_helper = NvtxHelper(["image_encoder", "post_process"])
if nvtx_helper is not None:
nvtx_helper.start_profile("image_encoder")
backbone_out = self.image_encoder(image)
if nvtx_helper is not None:
nvtx_helper.stop_profile("image_encoder")
nvtx_helper.start_profile("post_process")
        # Precompute the projected level 0 and level 1 features for the SAM mask decoder
        # so the projection does not need to be recomputed on every click.
backbone_out["backbone_fpn"][0] = self.model.sam_mask_decoder.conv_s0(backbone_out["backbone_fpn"][0])
backbone_out["backbone_fpn"][1] = self.model.sam_mask_decoder.conv_s1(backbone_out["backbone_fpn"][1])
# Prepare and flatten visual features.
feature_maps = backbone_out["backbone_fpn"][-self.model.num_feature_levels :]
vision_pos_embeds = backbone_out["vision_pos_enc"][-self.model.num_feature_levels :]
feat_sizes = [(x.shape[-2], x.shape[-1]) for x in vision_pos_embeds]
# flatten NxCxHxW to HWxNxC
# TODO: we should avoid this transpose since it will be transposed back to NCHW later.
vision_feats = [x.flatten(2).permute(2, 0, 1) for x in feature_maps]
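        # Add the learned no-memory embedding to the lowest-resolution features,
        # mirroring SAM2's behavior when no memory from previous frames is available.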
vision_feats[-1] = vision_feats[-1] + self.no_mem_embed
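        # Reshape the flattened features back to NCHW maps; after reversing twice,
        # feats is ordered from the highest-resolution level down to the image embedding.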
feats = [
feat.permute(1, 2, 0).reshape(1, -1, *feat_size)
for feat, feat_size in zip(vision_feats[::-1], feat_sizes[::-1], strict=False)
][::-1]
if nvtx_helper is not None:
nvtx_helper.stop_profile("post_process")
nvtx_helper.print_latency()
return feats[0], feats[1], feats[2]
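# Illustrative usage (a sketch, not executed on import). It assumes a SAM2Base model
# has already been constructed, e.g. with build_sam2 from the sam2 package:
#
#   encoder = SAM2ImageEncoder(sam2_model).cpu()
#   image_features_0, image_features_1, image_embeddings = encoder(random_sam2_input_image())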
def export_image_encoder_onnx(
sam2_model: SAM2Base,
onnx_model_path: str,
dynamic_batch_axes: bool = False,
verbose: bool = False,
dynamo: bool = False,
clear_dynamo_metadata: bool = False,
):
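    # Export SAM2ImageEncoder to ONNX. With dynamo=False the legacy TorchScript-based
    # exporter is used; with dynamo=True the model is first captured via torch.export
    # and then converted with the dynamo-based ONNX exporter and post-processed below.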
image = random_sam2_input_image()
sam2_encoder = SAM2ImageEncoder(sam2_model).cpu()
image_features_0, image_features_1, image_embeddings = sam2_encoder(image)
logger.info("image.shape: %s", image.shape)
logger.info("image_features_0.shape: %s", image_features_0.shape)
logger.info("image_features_1.shape: %s", image_features_1.shape)
logger.info("image_embeddings.shape: %s", image_embeddings.shape)
dynamic_axes = None
if dynamic_batch_axes:
dynamic_axes = {
"image": {0: "batch_size"},
"image_features_0": {0: "batch_size"},
"image_features_1": {0: "batch_size"},
"image_embeddings": {0: "batch_size"},
}
with warnings.catch_warnings():
if not verbose:
warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
warnings.filterwarnings("ignore", category=UserWarning)
if not dynamo:
torch.onnx.export(
sam2_encoder,
image,
onnx_model_path,
export_params=True,
opset_version=17,
do_constant_folding=True,
input_names=["image"],
output_names=["image_features_0", "image_features_1", "image_embeddings"],
dynamic_axes=dynamic_axes,
)
else:
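            # Dynamo path: allow scalar outputs to be captured during tracing, export an
            # ExportedProgram with torch.export (batch dimension left to Dim.AUTO), then
            # convert it to ONNX with the dynamo-based exporter.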
torch._dynamo.config.capture_scalar_outputs = True
ep = torch.export.export(
sam2_encoder,
args=(image,),
strict=False,
dynamic_shapes=[
{0: torch.export.Dim.AUTO},
],
)
onnx_program = torch.onnx.export(
ep,
(),
opset_version=17,
input_names=["image"],
output_names=["image_features_0", "image_features_1", "image_embeddings"],
dynamo=True,
)
onnx_program.optimize()
onnx_program.save(onnx_model_path + ".dynamo.onnx", external_data=False)
import onnx # noqa: PLC0415
from onnxruntime.transformers.dynamo_onnx_helper import DynamoOnnxHelper # noqa: PLC0415
onnx_model = onnx.load_model(onnx_model_path + ".dynamo.onnx", load_external_data=True)
if dynamic_batch_axes:
# Fix labels of dynamic axes since they can't be specified during Dynamo export currently
onnx_model.graph.input[0].type.tensor_type.shape.dim[0].dim_param = "batch_size"
for i in range(3):
onnx_model.graph.output[i].type.tensor_type.shape.dim[0].dim_param = "batch_size"
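            # Convert Constant nodes into graph initializers (via DynamoOnnxHelper) before
            # optionally clearing metadata and re-saving the model with external data.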
onnx_model_helper = DynamoOnnxHelper(onnx_model)
onnx_model_helper.convert_constants_to_initializers()
if clear_dynamo_metadata:
onnx_model_helper.clear_metadata()
import os # noqa: PLC0415
if os.path.exists(onnx_model_path):
os.remove(onnx_model_path)
if os.path.exists(onnx_model_path + ".data"):
os.remove(onnx_model_path + ".data")
onnx_model_helper.model.save_model_to_file(
onnx_model_path, use_external_data_format=True, all_tensors_to_one_file=True, convert_attribute=True
)
print("encoder onnx model saved to", onnx_model_path)
def test_image_encoder_onnx(
sam2_model: SAM2Base,
onnx_model_path: str,
    dynamic_batch_axes: bool = False,
):
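    # Compare ONNX Runtime (CPU execution provider) outputs against the PyTorch
    # SAM2ImageEncoder on random input; verification allows up to 1% mismatched
    # elements per output.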
ort_session = onnxruntime.InferenceSession(onnx_model_path, providers=["CPUExecutionProvider"])
model_inputs = ort_session.get_inputs()
input_names = [model_inputs[i].name for i in range(len(model_inputs))]
logger.info("input_names: %s", input_names)
model_outputs = ort_session.get_outputs()
output_names = [model_outputs[i].name for i in range(len(model_outputs))]
logger.info("output_names: %s", output_names)
batch_sizes = [1, 2] if dynamic_batch_axes else [1]
for batch_size in batch_sizes:
image = random_sam2_input_image(batch_size)
sam2_encoder = SAM2ImageEncoder(sam2_model).cpu()
image_features_0, image_features_1, image_embeddings = sam2_encoder(image.clone())
logger.info("image.shape: %s", image.shape)
logger.info("image_features_0.shape: %s", image_features_0.shape)
logger.info("image_features_1.shape: %s", image_features_1.shape)
logger.info("image_embeddings.shape: %s", image_embeddings.shape)
outputs = ort_session.run(output_names, {"image": image.numpy()})
for i, output_name in enumerate(output_names):
logger.info("output %s shape %s", output_name, outputs[i].shape)
ort_image_features_0, ort_image_features_1, ort_image_embeddings = outputs
        # ONNX Runtime and PyTorch outputs have about 0.75% mismatched elements,
        # which does not appear to affect the segmentation results.
if (
compare_tensors_with_tolerance(
"image_features_0",
image_features_0,
torch.tensor(ort_image_features_0),
mismatch_percentage_tolerance=1,
)
and compare_tensors_with_tolerance(
"image_features_1",
image_features_1,
torch.tensor(ort_image_features_1),
mismatch_percentage_tolerance=1,
)
and compare_tensors_with_tolerance(
"image_embeddings",
image_embeddings,
torch.tensor(ort_image_embeddings),
mismatch_percentage_tolerance=1,
)
):
print(f"onnx model has been verified for batch_size={batch_size}: {onnx_model_path}")
else:
print(f"onnx model verification failed for batch_size={batch_size}: {onnx_model_path}")