Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Debian packages RPM packages NuGet packages

Repository URL to install this package:

Details    
pytorch3d / ops / vert_align.py
Size: Mime:
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe


import torch
import torch.nn.functional as F


def vert_align(
    feats,
    verts,
    return_packed: bool = False,
    interp_mode: str = "bilinear",
    padding_mode: str = "zeros",
    align_corners: bool = True,
) -> torch.Tensor:
    """
    Sample vertex features from a feature map. This operation is called
    "perceptual feature pooling" in [1] or "vert align" in [2].

    [1] Wang et al, "Pixel2Mesh: Generating 3D Mesh Models from Single
        RGB Images", ECCV 2018.
    [2] Gkioxari et al, "Mesh R-CNN", ICCV 2019

    Args:
        feats: FloatTensor of shape (N, C, H, W) representing image features
            from which to sample or a list of features each with potentially
            different C, H or W dimensions.
        verts: FloatTensor of shape (N, V, 3) or an object (e.g. Meshes or Pointclouds)
            with `verts_padded' or `points_padded' as an attribute giving the (x, y, z)
            vertex positions for which to sample. (x, y) verts should be normalized such
            that (-1, -1) corresponds to top-left and (+1, +1) to bottom-right
            location in the input feature map.
        return_packed: (bool) Indicates whether to return packed features
        interp_mode: (str) Specifies how to interpolate features.
            ('bilinear' or 'nearest')
        padding_mode: (str) Specifies how to handle vertices outside of the
            [-1, 1] range. ('zeros', 'reflection', or 'border')
        align_corners (bool): Geometrically, we consider the pixels of the
            input  as squares rather than points.
            If set to ``True``, the extrema (``-1`` and ``1``) are considered as
            referring to the center points of the input's corner pixels. If set
            to ``False``, they are instead considered as referring to the corner
            points of the input's corner pixels, making the sampling more
            resolution agnostic. Default: ``True``

    Returns:
        feats_sampled: FloatTensor of shape (N, V, C) giving sampled features for each
            vertex. If feats is a list, we return concatenated features in axis=2 of
            shape (N, V, sum(C_n)) where C_n = feats[n].shape[1].
            If return_packed = True, the features are transformed to a packed
            representation of shape (sum(V), C)
    """
    if torch.is_tensor(verts):
        if verts.dim() != 3:
            raise ValueError("verts tensor should be 3 dimensional")
        grid = verts
    elif hasattr(verts, "verts_padded"):
        grid = verts.verts_padded()
    elif hasattr(verts, "points_padded"):
        grid = verts.points_padded()
    else:
        raise ValueError(
            "verts must be a tensor or have a "
            + "`points_padded' or`verts_padded` attribute."
        )

    grid = grid[:, None, :, :2]  # (N, 1, V, 2)

    if torch.is_tensor(feats):
        feats = [feats]
    for feat in feats:
        if feat.dim() != 4:
            raise ValueError("feats must have shape (N, C, H, W)")
        if grid.shape[0] != feat.shape[0]:
            raise ValueError("inconsistent batch dimension")

    feats_sampled = []
    for feat in feats:
        feat_sampled = F.grid_sample(
            feat,
            grid,
            mode=interp_mode,
            padding_mode=padding_mode,
            align_corners=align_corners,
        )  # (N, C, 1, V)
        feat_sampled = feat_sampled.squeeze(dim=2).transpose(1, 2)  # (N, V, C)
        feats_sampled.append(feat_sampled)
    feats_sampled = torch.cat(feats_sampled, dim=2)  # (N, V, sum(C))

    if return_packed:
        # flatten the first two dimensions: (N*V, C)
        feats_sampled = feats_sampled.view(-1, feats_sampled.shape[-1])
        if hasattr(verts, "verts_padded_to_packed_idx"):
            idx = (
                verts.verts_padded_to_packed_idx()
                .view(-1, 1)
                .expand(-1, feats_sampled.shape[-1])
            )
            feats_sampled = feats_sampled.gather(0, idx)  # (sum(V), C)

    return feats_sampled