Why Gemfury? Push, build, and install RubyGems, npm packages, Python packages, Maven artifacts, PHP packages, Go Modules, Debian packages, RPM packages, and NuGet packages.

Repository URL to install this package:

Details    
pytorch3d / implicitron / tools / depth_cleanup.py
Size: Mime:
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import torch
import torch.nn.functional as Fu
from pytorch3d.ops import wmean
from pytorch3d.renderer.cameras import CamerasBase
from pytorch3d.structures import Pointclouds


def cleanup_eval_depth(
    point_cloud: Pointclouds,
    camera: CamerasBase,
    depth: torch.Tensor,
    mask: torch.Tensor,
    sigma: float = 0.01,
    image=None,
):
    """
    Mark the pixels of a depth map that agree with a reference point cloud.

    Every point of `point_cloud` is projected with `camera` and the depth map
    is sampled (nearest neighbor) at the projected location. A point is
    consistent when the sampled depth differs from the point's own view-space
    z coordinate by at most `sigma` times the `mask`-weighted standard
    deviation of `depth`. A pixel is kept if at least one consistent,
    non-padding point sampled it.

    Args:
        point_cloud: Reference point clouds; the padded points and the
            per-cloud point counts are read from it. Presumably one cloud
            per batch element of `depth` -- TODO confirm against callers.
        camera: Cameras used to project the points. The projected xy
            coordinates are negated and fed to `grid_sample` as a sampling
            grid, so they are presumably expected in PyTorch3D NDC
            coordinates spanning [-1, 1] -- TODO confirm (in particular for
            non-square images).
        depth: Depth maps of shape `(B, 1, H, W)`.
        mask: Per-pixel weights for the depth statistics, holding
            `B * H * W` elements (e.g. shape `(B, 1, H, W)`).
        sigma: Multiplier of the depth standard deviation that gives the
            maximum tolerated absolute depth difference.
        image: Unused by the computation; kept so existing callers that pass
            it keep working.

    Returns:
        A float32 tensor of shape `(B, 1, H, W)` that is 1.0 at pixels whose
        depth is supported by at least one point and 0.0 elsewhere.
    """
    ba, _, H, W = depth.shape
    n_pix = H * W

    pcl = point_cloud.points_padded()
    n_pts = point_cloud.num_points_per_cloud()
    # 1 for the real points of each cloud, 0 for the padding entries
    pcl_mask = (
        torch.arange(pcl.shape[1], dtype=torch.int64, device=pcl.device)[None]
        < n_pts[:, None]
    ).type_as(pcl)

    # xy of the projected points and their z coordinate in the view frame
    pcl_proj = camera.transform_points(pcl, eps=1e-2)[..., :-1]
    pcl_depth = camera.get_world_to_view_transform().transform_points(pcl)[..., -1]

    # The flat pixel index rides along as a second channel so that a single
    # nearest-neighbor lookup yields both the depth and the pixel it came from.
    # NOTE(review): the index is stored in the dtype of `depth`, hence it is
    # exact only while H * W is representable there (2**24 for float32).
    pix_idx = (
        torch.arange(n_pix, device=depth.device)
        .to(depth.dtype)
        .view(1, 1, H, W)
        .expand(ba, 1, H, W)
    )
    depth_and_idx = torch.cat((depth, pix_idx), dim=1)

    # `align_corners=False` is the grid_sample default; spelling it out keeps
    # the result unchanged and avoids the warning raised when it is omitted.
    # NOTE(review): with the default zero padding, points projecting outside
    # of the image sample depth 0 and pixel index 0.
    depth_and_idx_sampled = Fu.grid_sample(
        depth_and_idx,
        -pcl_proj[:, None],
        mode="nearest",
        align_corners=False,
    )[:, :, 0].reshape(ba, 2, -1)

    depth_sampled = depth_and_idx_sampled[:, 0]
    idx_sampled = depth_and_idx_sampled[:, 1]
    df = (depth_sampled - pcl_depth).abs()

    # the threshold is a sigma-multiple of the standard deviation of the depth
    # (`reshape` instead of `view` so that non-contiguous inputs are accepted)
    depth_flat = depth.reshape(ba, -1)
    mask_flat = mask.reshape(ba, -1)
    mu = wmean(depth_flat[..., None], mask_flat).view(ba, 1)
    std = (
        # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`.
        wmean((depth_flat - mu)[..., None] ** 2, mask_flat)
        .clamp(1e-4)  # lower bound on the variance, i.e. std >= 1e-2
        .sqrt()
        .view(ba, -1)
    )
    good_df_thr = std * sigma
    good_depth = (df <= good_df_thr).type_as(pcl_mask) * pcl_mask

    # Count, for every pixel, the consistent points that sampled it.
    good_depth_raster = depth.new_zeros(ba, n_pix)
    good_depth_raster.scatter_add_(
        1,
        torch.round(idx_sampled).long(),
        good_depth.to(good_depth_raster.dtype),
    )

    good_depth_mask = (good_depth_raster.view(ba, 1, H, W) > 0).float()

    return good_depth_mask