# Repository URL to install this package: (none given)
# Version: 1.23.2
# -------------------------------------------------------------------------
# Copyright (R) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import logging
import torch
from sam2.modeling.sam2_base import SAM2Base
from sam2_utils import compare_tensors_with_tolerance
from torch import nn
logger = logging.getLogger(__name__)
class SAM2PromptEncoder(nn.Module):
    """Wrapper that runs the prompt encoder of a SAM2 model on point/box prompts and an optional mask.

    The embeddings are assembled with mask multiplications instead of branching on label values,
    and the "no mask" case is selected arithmetically through ``has_input_masks``.
    """

    def __init__(self, sam_model: SAM2Base):
        """Keep references to the model and its prompt encoder.

        Args:
            sam_model (SAM2Base): model providing ``sam_prompt_encoder`` and ``image_size``.
        """
        super().__init__()
        self.prompt_encoder = sam_model.sam_prompt_encoder
        self.model = sam_model

    @torch.no_grad()
    def forward(
        self,
        point_coords: torch.Tensor,
        point_labels: torch.Tensor,
        input_masks: torch.Tensor,
        has_input_masks: torch.Tensor,
    ):
        """Encode prompts.

        Args:
            point_coords (torch.Tensor): [L, P, 2] shape and float32 dtype and contains the absolute pixel
                                         coordinate in (x, y) format of the P input points in image of size 1024x1024.
            point_labels (torch.Tensor): shape [L, P] and int32 dtype, where 1 means
                                         positive (foreground), 0 means negative (background), -1 means padding,
                                         2 (box left upper corner), 3 (box right bottom corner).
            input_masks (torch.Tensor): [L, 1, H/4, W/4]. Low resolution mask input to the model.
                                        Typically coming from a previous iteration.
            has_input_masks (torch.Tensor): [1] (one flag shared by all labels) or [L] (one flag per label).
                                            1.0 if input_masks is used, 0.0 otherwise.

        Returns:
            sparse_embeddings (torch.Tensor): [L, P+1, 256], embedding for points and boxes.
            dense_embeddings (torch.Tensor): [L, 256, 64, 64]. embedding for input masks.
            image_pe (torch.Tensor, optional): [1, 256, 64, 64]. image positional encoding.
        """
        sparse_embeddings = self._embed_points(point_coords, point_labels)
        dense_embeddings = self._embed_masks(input_masks, has_input_masks)
        image_pe = self.prompt_encoder.get_dense_pe()

        return sparse_embeddings, dense_embeddings, image_pe

    def _embed_points(self, point_coords: torch.Tensor, point_labels: torch.Tensor) -> torch.Tensor:
        """Embed point prompts; one padding point (label -1) is appended to every label row.

        Args:
            point_coords (torch.Tensor): [L, P, 2] absolute pixel coordinates. Not modified in place.
            point_labels (torch.Tensor): [L, P] point labels.

        Returns:
            torch.Tensor: embeddings of the P input points followed by the padding point.
        """
        # Half-pixel shift (presumably to address pixel centres). This also creates a new tensor, so the
        # in-place normalization below never touches the caller's tensor.
        point_coords = point_coords + 0.5

        padding_point = torch.zeros((point_coords.shape[0], 1, 2), device=point_coords.device)
        padding_label = -torch.ones((point_labels.shape[0], 1), device=point_labels.device)
        point_coords = torch.cat([point_coords, padding_point], dim=1)
        point_labels = torch.cat([point_labels, padding_label], dim=1)

        # Note that the input coordinates are based on image size 1024x1024. Here we normalize it to [0.0, 1.0).
        point_coords[:, :, 0] = point_coords[:, :, 0] / self.model.image_size
        point_coords[:, :, 1] = point_coords[:, :, 1] / self.model.image_size

        point_embedding = self.prompt_encoder.pe_layer._pe_encoding(point_coords)
        point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding)

        # Padding points drop their positional encoding and receive the "not a point" embedding instead.
        point_embedding = point_embedding * (point_labels != -1)
        point_embedding = point_embedding + self.prompt_encoder.not_a_point_embed.weight * (point_labels == -1)

        # Every other label adds its own learned embedding; the boolean mask zeroes it for non-matching points.
        for i in range(self.prompt_encoder.num_point_embeddings):
            point_embedding = point_embedding + self.prompt_encoder.point_embeddings[i].weight * (point_labels == i)

        return point_embedding

    def _embed_masks(self, input_masks: torch.Tensor, has_input_masks: torch.Tensor) -> torch.Tensor:
        """Embed the input masks, falling back to the "no mask" embedding where the flag is 0.

        Args:
            input_masks (torch.Tensor): mask input passed to ``mask_downscaling``.
            has_input_masks (torch.Tensor): [1] or [L] flags; 1.0 selects the mask embedding,
                                            0.0 selects the "no mask" embedding.

        Returns:
            torch.Tensor: blended dense embedding with the shape of the downscaled mask embedding.
        """
        mask_embedding = self.prompt_encoder.mask_downscaling(input_masks)
        no_mask_embedding = self.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1)
        logger.info("no_mask_embedding.shape: %s", no_mask_embedding.shape)

        # Align the flag(s) with the leading (label) axis. A bare [L] tensor would otherwise broadcast
        # against the trailing width axis; a [1] flag still applies to every label as before.
        has_input_masks = has_input_masks.reshape(-1, 1, 1, 1)
        mask_embedding = has_input_masks * mask_embedding + (1.0 - has_input_masks) * no_mask_embedding
        logger.info("mask_embedding.shape: %s", mask_embedding.shape)
        return mask_embedding
def export_prompt_encoder_onnx(
    sam2_model: SAM2Base,
    onnx_model_path: str,
):
    """Export the prompt encoder of a SAM2 model to an ONNX file.

    The encoder is wrapped in ``SAM2PromptEncoder``, moved to CPU, run once on dummy inputs
    (the resulting shapes are logged), and then exported with ``torch.onnx.export``.

    Args:
        sam2_model (SAM2Base): model whose prompt encoder is exported.
        onnx_model_path (str): destination path of the ONNX model.
    """
    sam2_prompt_encoder = SAM2PromptEncoder(sam2_model).cpu()

    num_labels = 2
    num_points = 3
    point_coords = torch.randint(low=0, high=1024, size=(num_labels, num_points, 2), dtype=torch.float)
    # `high` is exclusive: 2 is needed for the dummy labels to contain both 0 (negative) and 1 (positive).
    point_labels = torch.randint(low=0, high=2, size=(num_labels, num_points), dtype=torch.int32)
    input_masks = torch.zeros(num_labels, 1, 256, 256, dtype=torch.float)
    has_input_masks = torch.ones(1, dtype=torch.float)

    # Eager run on the dummy inputs so the input/output shapes can be logged.
    sparse_embeddings, dense_embeddings, image_pe = sam2_prompt_encoder(
        point_coords, point_labels, input_masks, has_input_masks
    )

    logger.info("point_coords.shape: %s", point_coords.shape)
    logger.info("point_labels.shape: %s", point_labels.shape)
    logger.info("input_masks.shape: %s", input_masks.shape)
    logger.info("has_input_masks.shape: %s", has_input_masks.shape)

    logger.info("sparse_embeddings.shape: %s", sparse_embeddings.shape)
    logger.info("dense_embeddings.shape: %s", dense_embeddings.shape)
    logger.info("image_pe.shape: %s", image_pe.shape)

    torch.onnx.export(
        sam2_prompt_encoder,
        (point_coords, point_labels, input_masks, has_input_masks),
        onnx_model_path,
        export_params=True,
        opset_version=18,
        do_constant_folding=True,
        input_names=["point_coords", "point_labels", "input_masks", "has_input_masks"],
        output_names=["sparse_embeddings", "dense_embeddings", "image_pe"],
        # Label and point counts are dynamic. `has_input_masks` and `image_pe` are not listed,
        # so they keep the shapes of the dummy run.
        dynamic_axes={
            "point_coords": {0: "num_labels", 1: "num_points"},
            "point_labels": {0: "num_labels", 1: "num_points"},
            "input_masks": {0: "num_labels"},
            "sparse_embeddings": {0: "num_labels", 1: "num_points+1"},
            "dense_embeddings": {0: "num_labels"},
        },
    )

    print("prompt encoder onnx model saved to ", onnx_model_path)
def test_prompt_encoder_onnx(
    sam2_model: SAM2Base,
    onnx_model_path: str,
):
    """Compare the exported ONNX prompt encoder against the PyTorch implementation.

    Random prompts are run through ``SAM2PromptEncoder`` on CPU and through an ONNX Runtime
    CPU session; the three outputs are compared with ``compare_tensors_with_tolerance`` and
    the verdict is printed.

    Args:
        sam2_model (SAM2Base): model providing the reference prompt encoder.
        onnx_model_path (str): path of the ONNX model to verify.
    """
    sam2_prompt_encoder = SAM2PromptEncoder(sam2_model).cpu()

    num_labels = 1
    num_points = 5
    point_coords = torch.randint(low=0, high=1024, size=(num_labels, num_points, 2), dtype=torch.float)
    # `high` is exclusive: with high=1 every label would be 0 and the positive-point
    # embedding would never be exercised by this check.
    point_labels = torch.randint(low=0, high=2, size=(num_labels, num_points), dtype=torch.int32)
    input_masks = torch.rand(num_labels, 1, 256, 256, dtype=torch.float)
    has_input_masks = torch.ones(1, dtype=torch.float)

    # Reference outputs from PyTorch.
    sparse_embeddings, dense_embeddings, image_pe = sam2_prompt_encoder(
        point_coords, point_labels, input_masks, has_input_masks
    )

    import onnxruntime  # noqa: PLC0415

    ort_session = onnxruntime.InferenceSession(onnx_model_path, providers=["CPUExecutionProvider"])

    input_names = [model_input.name for model_input in ort_session.get_inputs()]
    logger.info("input_names: %s", input_names)

    output_names = [model_output.name for model_output in ort_session.get_outputs()]
    logger.info("output_names: %s", output_names)

    outputs = ort_session.run(
        output_names,
        {
            "point_coords": point_coords.numpy(),
            "point_labels": point_labels.numpy(),
            "input_masks": input_masks.numpy(),
            "has_input_masks": has_input_masks.numpy(),
        },
    )

    for output_name, output in zip(output_names, outputs):
        logger.info("output %s shape: %s", output_name, output.shape)

    ort_sparse_embeddings, ort_dense_embeddings, ort_image_pe = outputs

    comparisons = (
        ("sparse_embeddings", sparse_embeddings, ort_sparse_embeddings),
        ("dense_embeddings", dense_embeddings, ort_dense_embeddings),
        ("image_pe", image_pe, ort_image_pe),
    )
    # The generator keeps the original short-circuit: comparison stops at the first mismatch.
    verified = all(
        compare_tensors_with_tolerance(name, expected, torch.tensor(actual), mismatch_percentage_tolerance=0.2)
        for name, expected, actual in comparisons
    )

    if verified:
        print(f"onnx model has been verified: {onnx_model_path}")
    else:
        print(f"onnx model verification failed: {onnx_model_path}")