# Repository URL to install this package:
# Version: 1.23.2
# -------------------------------------------------------------------------
# Copyright (R) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import logging
import os
import sys
from collections.abc import Mapping
import torch
from sam2.build_sam import build_sam2
from sam2.modeling.sam2_base import SAM2Base
# Module-level logger named after this module; compare_tensors_with_tolerance
# reports its verification results through it.
logger = logging.getLogger(__name__)
def _get_model_cfg(model_type) -> str:
assert model_type in ["sam2_hiera_tiny", "sam2_hiera_small", "sam2_hiera_large", "sam2_hiera_base_plus"]
if model_type == "sam2_hiera_tiny":
model_cfg = "sam2_hiera_t.yaml"
elif model_type == "sam2_hiera_small":
model_cfg = "sam2_hiera_s.yaml"
elif model_type == "sam2_hiera_base_plus":
model_cfg = "sam2_hiera_b+.yaml"
else:
model_cfg = "sam2_hiera_l.yaml"
return model_cfg
def load_sam2_model(sam2_dir, model_type, device: str | torch.device = "cpu") -> SAM2Base:
    """Build a SAM2 model from a local SAM2 directory.

    Expects ``sam2_dir`` to contain a ``checkpoints`` sub-directory holding
    ``<model_type>.pt`` and a ``sam2_configs`` sub-directory.

    Args:
        sam2_dir: Root directory of the SAM2 files.
        model_type: Model name; used both for the checkpoint file name and to
            select the config file.
        device: Device passed through to ``build_sam2``.

    Returns:
        The model returned by ``build_sam2``.

    Raises:
        FileNotFoundError: If ``sam2_dir``, one of its required sub-directories,
            or the checkpoint file does not exist.
    """
    checkpoints_dir = os.path.join(sam2_dir, "checkpoints")
    sam2_config_dir = os.path.join(sam2_dir, "sam2_configs")

    # Checked outermost-first so the error names the first missing directory.
    for required_dir in (sam2_dir, checkpoints_dir, sam2_config_dir):
        if not os.path.exists(required_dir):
            raise FileNotFoundError(f"{required_dir} does not exist. Please specify --sam2_dir correctly.")

    checkpoint_path = os.path.join(checkpoints_dir, f"{model_type}.pt")
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"{checkpoint_path} does not exist. Please download checkpoints under the directory.")

    # Make the SAM2 directory importable (added once only).
    # NOTE(review): presumably needed so build_sam2 can resolve its configs/modules - confirm.
    if sam2_dir not in sys.path:
        sys.path.append(sam2_dir)

    return build_sam2(_get_model_cfg(model_type), checkpoint_path, device=device)
def sam2_onnx_path(output_dir, model_type, component, multimask_output=False, suffix=""):
    """Return the ONNX file path for one exported SAM2 component.

    The file name is ``<model_type>_<component>[_multi]<suffix>.onnx`` inside
    ``output_dir``.

    Args:
        output_dir: Directory the ONNX file lives in.
        model_type: Model name used as the file-name prefix.
        component: One of "image_encoder", "mask_decoder", "prompt_encoder"
            or "image_decoder".
        multimask_output: Adds "_multi" to the name; only applies to the
            "image_decoder" component and is ignored for the others.
        suffix: Extra text inserted just before the ".onnx" extension.

    Returns:
        The joined path as a string.

    Raises:
        ValueError: If ``component`` is not one of the supported names.
    """
    supported_components = ("image_encoder", "mask_decoder", "prompt_encoder", "image_decoder")
    # Raise instead of assert: with ``python -O`` an assert is stripped and an
    # unknown component would silently be given an image_decoder path.
    if component not in supported_components:
        raise ValueError(f"Unsupported component {component!r}. Expected one of {list(supported_components)}.")

    multi_tag = "_multi" if (component == "image_decoder" and multimask_output) else ""
    return os.path.join(output_dir, f"{model_type}_{component}{multi_tag}{suffix}.onnx")
def encoder_shape_dict(batch_size: int, height: int, width: int) -> Mapping[str, list[int]]:
    """Return the tensor shapes of the image encoder's input and outputs.

    Args:
        batch_size: Leading dimension of every shape.
        height: Input image height; must be 1024.
        width: Input image width; must be 1024.

    Returns:
        A mapping from tensor name ("image", "image_features_0",
        "image_features_1", "image_embeddings") to its shape. The feature
        shapes are the image size divided by 4, 8 and 16 respectively.

    Raises:
        ValueError: If the image size is not 1024x1024.
    """
    # Raise instead of assert so the size check survives ``python -O``.
    if height != 1024 or width != 1024:
        raise ValueError("Only 1024x1024 images are supported.")

    return {
        "image": [batch_size, 3, height, width],
        "image_features_0": [batch_size, 32, height // 4, width // 4],
        "image_features_1": [batch_size, 64, height // 8, width // 8],
        "image_embeddings": [batch_size, 256, height // 16, width // 16],
    }
def decoder_shape_dict(
    original_image_height: int,
    original_image_width: int,
    num_labels: int = 1,
    max_points: int = 16,
    num_masks: int = 1,
) -> dict:
    """Return the tensor shapes of the decoder's inputs and outputs.

    The encoder-side feature shapes are derived from a fixed 1024x1024 model
    input with batch size 1; only the "masks" output uses the original image
    size.

    Args:
        original_image_height: Height used for the "masks" output.
        original_image_width: Width used for the "masks" output.
        num_labels: Leading dimension of the prompt tensors and the outputs.
        max_points: Number of points per label.
        num_masks: Number of masks per label in the outputs.

    Returns:
        A dict mapping each tensor name to its shape (a list of ints).
    """
    model_input_size = 1024
    # Spatial sizes at 1/4, 1/8 and 1/16 of the (square) model input.
    quarter = model_input_size // 4
    eighth = model_input_size // 8
    sixteenth = model_input_size // 16

    shapes = {}
    # Image features (batch size 1).
    shapes["image_features_0"] = [1, 32, quarter, quarter]
    shapes["image_features_1"] = [1, 64, eighth, eighth]
    shapes["image_embeddings"] = [1, 256, sixteenth, sixteenth]
    # Prompt inputs.
    shapes["point_coords"] = [num_labels, max_points, 2]
    shapes["point_labels"] = [num_labels, max_points]
    shapes["input_masks"] = [num_labels, 1, quarter, quarter]
    shapes["has_input_masks"] = [num_labels]
    shapes["original_image_size"] = [2]
    # Outputs.
    shapes["masks"] = [num_labels, num_masks, original_image_height, original_image_width]
    shapes["iou_predictions"] = [num_labels, num_masks]
    shapes["low_res_masks"] = [num_labels, num_masks, quarter, quarter]
    return shapes
def compare_tensors_with_tolerance(
    name: str,
    tensor1: torch.Tensor,
    tensor2: torch.Tensor,
    atol=5e-3,
    rtol=1e-4,
    mismatch_percentage_tolerance=0.1,
) -> bool:
    """Check that two tensors agree on all but a small percentage of elements.

    Both tensors are compared as float32. An element matches when the values
    are exactly equal, or when ``|a - b| <= rtol * max(|a|, |b|) + atol`` with a
    finite difference. NaN values and non-equal infinities therefore always
    count as mismatches.

    The result is logged through the module logger: at INFO level when the
    check passes and at ERROR level when it fails.

    Args:
        name: Label used in the log message.
        tensor1: First tensor.
        tensor2: Second tensor; must have the same shape as ``tensor1``.
        atol: Absolute tolerance.
        rtol: Relative tolerance.
        mismatch_percentage_tolerance: The check passes only when the
            percentage of mismatched elements is strictly below this value.

    Returns:
        True if the mismatch percentage is below the threshold, else False.

    Raises:
        ValueError: If the two tensors have different shapes.
    """
    # Raise instead of assert: with ``python -O`` a stripped assert would let
    # broadcasting hide a genuine shape mismatch.
    if tensor1.shape != tensor2.shape:
        raise ValueError(
            f"{name}: shape mismatch {tuple(tensor1.shape)} vs {tuple(tensor2.shape)}."
        )

    # No clone needed: nothing below mutates the inputs.
    a = tensor1.float()
    b = tensor2.float()

    differences = torch.abs(a - b)
    tolerance = rtol * torch.max(torch.abs(a), torch.abs(b)) + atol

    # "differences > tolerance" is False for NaN (and for inf vs. inf tolerance),
    # which would count broken outputs as matches. Count matches explicitly and
    # require a finite difference; exact equality keeps inf == inf a match.
    matched = (a == b) | (torch.isfinite(differences) & (differences <= tolerance))

    total_elements = a.numel()
    mismatch_count = total_elements - int(matched.sum().item())
    # Empty tensors have nothing to mismatch; avoid dividing by zero.
    mismatch_percentage = (mismatch_count / total_elements) * 100 if total_elements else 0.0

    passed = mismatch_percentage < mismatch_percentage_tolerance
    log_func = logger.info if passed else logger.error
    log_func(
        "%s: mismatched elements percentage %.2f (%d/%d). Verification %s (threshold=%.2f).",
        name,
        mismatch_percentage,
        mismatch_count,
        total_elements,
        "passed" if passed else "failed",
        mismatch_percentage_tolerance,
    )
    return passed
def random_sam2_input_image(batch_size=1, image_height=1024, image_width=1024) -> torch.Tensor:
    """Create a random float32 image batch on the CPU.

    Args:
        batch_size: Number of images in the batch.
        image_height: Image height.
        image_width: Image width.

    Returns:
        A tensor of shape (batch_size, 3, image_height, image_width) filled
        with values drawn by ``torch.randn``.
    """
    shape = (batch_size, 3, image_height, image_width)
    # .cpu() is kept so the result is on the CPU even if the default device is not.
    return torch.randn(shape, dtype=torch.float32).cpu()
def setup_logger(verbose=True):
    """Configure the root logger's format and level.

    Args:
        verbose: When true, log at INFO level with a
            ``[file:line - function()]`` prefix on every message. Otherwise
            log at WARNING level with the bare message only.

    Note:
        ``logging.basicConfig`` only installs the format when the root logger
        has no handlers yet; the level is always applied.
    """
    if verbose:
        log_format = "[%(filename)s:%(lineno)s - %(funcName)20s()] %(message)s"
        level = logging.INFO
    else:
        # Plain message only. The previous format string carried a stray,
        # unbalanced "[" that was printed in front of every message.
        log_format = "%(message)s"
        level = logging.WARNING

    logging.basicConfig(format=log_format)
    logging.getLogger().setLevel(level)