Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Debian packages RPM packages NuGet packages

Repository URL to install this package:

Details    
pytorch3d / implicitron / models / model_dbir.py
Size: Mime:
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe


from typing import Any, Dict, List, Optional, Tuple

import torch
from pytorch3d.implicitron.dataset.utils import is_known_frame
from pytorch3d.implicitron.tools.config import registry
from pytorch3d.implicitron.tools.point_cloud_utils import (
    get_rgbd_point_cloud,
    render_point_cloud_pytorch3d,
)
from pytorch3d.renderer.cameras import CamerasBase
from pytorch3d.structures import Pointclouds

from .base_model import ImplicitronModelBase, ImplicitronRender
from .renderer.base import EvaluationMode


@registry.register
class ModelDBIR(ImplicitronModelBase):
    """
    A simple depth-based image rendering model.

    The depth maps of the known (source) views are unprojected into a colored
    point cloud, which is then rendered into every target view.

    Args:
        render_image_width: The width of the rendered rectangular images.
        render_image_height: The height of the rendered rectangular images.
        bg_color: The color of the background.
        max_points: Maximum number of points in the point cloud
            formed by unprojecting all source view depths.
            If more points are present, they are randomly subsampled
            to this number of points without replacement.
            Non-positive values disable the subsampling.
    """

    render_image_width: int = 256
    render_image_height: int = 256
    bg_color: Tuple[float, float, float] = (0.0, 0.0, 0.0)
    max_points: int = -1

    # pyre-fixme[14]: `forward` overrides method defined in `ImplicitronModelBase`
    #  inconsistently.
    def forward(
        self,
        *,  # force keyword-only arguments
        image_rgb: Optional[torch.Tensor],
        camera: CamerasBase,
        fg_probability: Optional[torch.Tensor],
        mask_crop: Optional[torch.Tensor],
        depth_map: Optional[torch.Tensor],
        sequence_name: Optional[List[str]],
        evaluation_mode: EvaluationMode = EvaluationMode.EVALUATION,
        frame_type: List[str],
        **kwargs,
    ) -> Dict[str, Any]:  # TODO: return a namedtuple or dataclass
        """
        Given a set of input source cameras images and depth maps, unprojects
        all RGBD maps to a colored point cloud and renders into the target views.

        Args:
            image_rgb: A batch of `N` images of shape `(N, 3, H, W)`. Required.
            camera: A batch of `N` PyTorch3D cameras.
            fg_probability: A batch of `N` foreground probability maps
                of shape `(N, 1, H, W)`. Pixels with probability > 0.5 are
                treated as foreground when unprojecting. Required.
            mask_crop: Not used by this model.
            depth_map: A batch of `N` depth maps of shape `(N, 1, H, W)`.
                Required.
            sequence_name: Not used by this model.
            evaluation_mode: Not used by this model.
            frame_type: A list of `N` strings containing frame type indicators
                which specify target and source views.
            **kwargs: Ignored.

        Returns:
            preds: A dict with the following fields:
                implicitron_render: The rendered colors, depth and mask
                    of the target views, concatenated along dim 0 in the
                    order in which the target frames appear in the batch.
                point_cloud: The point cloud of the scene. Its renders are
                    stored in `implicitron_render`.

        Raises:
            ValueError: If `image_rgb`, `fg_probability` or `depth_map` is
                None, or if `frame_type` marks no frame as a target frame.
        """

        if image_rgb is None:
            raise ValueError("ModelDBIR needs image input")

        if fg_probability is None:
            raise ValueError("ModelDBIR needs foreground mask input")

        if depth_map is None:
            raise ValueError("ModelDBIR needs depth map input")

        is_known = is_known_frame(frame_type)
        is_known_idx = torch.where(is_known)[0]
        is_target_idx = torch.where(~is_known)[0]

        # Fail fast with a clear message: without target frames there is
        # nothing to render, and `torch.cat` below would raise on an empty list
        # only after the (expensive) unprojection has already been done.
        if is_target_idx.numel() == 0:
            raise ValueError("ModelDBIR needs at least one target frame to render")

        # Binarize the foreground probabilities so only confident foreground
        # pixels are unprojected.
        mask_fg = (fg_probability > 0.5).type_as(image_rgb)

        point_cloud = get_rgbd_point_cloud(
            # pyre-fixme[6]: For 1st param expected `Union[List[int], int,
            #  LongTensor]` but got `Tensor`.
            camera[is_known_idx],
            image_rgb[is_known_idx],
            depth_map[is_known_idx],
            mask_fg[is_known_idx],
        )

        # `.item()` assumes the unprojection yields exactly one cloud.
        pcl_size = point_cloud.num_points_per_cloud().item()
        if (self.max_points > 0) and (pcl_size > self.max_points):
            # The permutation is drawn on the CPU with the default generator
            # (kept as-is so seeded runs reproduce the same subset); indexing
            # the padded tensors with a CPU index tensor works on any device.
            # pyre-fixme[6]: For 1st param expected `int` but got `Union[bool,
            #  float, int]`.
            prm = torch.randperm(pcl_size)[: self.max_points]
            point_cloud = Pointclouds(
                point_cloud.points_padded()[:, prm, :],
                # pyre-fixme[16]: Optional type has no attribute `__getitem__`.
                features=point_cloud.features_padded()[:, prm, :],
            )

        depth_renders, image_renders, mask_renders = [], [], []

        # render into target frames in a for loop to save memory
        for tgt_idx in is_target_idx:
            _image_render, _mask_render, _depth_render = render_point_cloud_pytorch3d(
                camera[int(tgt_idx)],
                point_cloud,
                render_size=(self.render_image_height, self.render_image_width),
                point_radius=1e-2,
                topk=10,
                bg_color=self.bg_color,
            )
            _image_render = _image_render.clamp(0.0, 1.0)
            # the mask is the set of pixels with opacity bigger than eps
            _mask_render = (_mask_render > 1e-4).float()

            depth_renders.append(_depth_render)
            image_renders.append(_image_render)
            mask_renders.append(_mask_render)

        implicitron_render = ImplicitronRender(
            depth_render=torch.cat(depth_renders, dim=0),
            image_render=torch.cat(image_renders, dim=0),
            mask_render=torch.cat(mask_renders, dim=0),
        )

        preds = {
            "implicitron_render": implicitron_render,
            "point_cloud": point_cloud,
        }

        return preds