Why Gemfury? Push, build, and install RubyGems, npm packages, Python packages, Maven artifacts, PHP packages, Go Modules, Debian packages, RPM packages, and NuGet packages.

Repository URL to install this package:

Details    
pytorch3d / implicitron / models / visualization / render_flyaround.py
Size: Mime:
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe


import logging
import math
import os
import random
from typing import (
    Any,
    Dict,
    Iterable,
    List,
    Optional,
    Sequence,
    Tuple,
    TYPE_CHECKING,
    Union,
)

import numpy as np
import torch
import torch.nn.functional as Fu
from pytorch3d.implicitron.dataset.dataset_base import DatasetBase, FrameData
from pytorch3d.implicitron.dataset.utils import is_train_frame
from pytorch3d.implicitron.models.base_model import EvaluationMode
from pytorch3d.implicitron.tools.eval_video_trajectory import (
    generate_eval_video_cameras,
)
from pytorch3d.implicitron.tools.video_writer import VideoWriter
from pytorch3d.implicitron.tools.vis_utils import (
    get_visdom_connection,
    make_depth_image,
)
from tqdm import tqdm

if TYPE_CHECKING:
    from visdom import Visdom

logger = logging.getLogger(__name__)


def render_flyaround(
    dataset: DatasetBase,
    sequence_name: str,
    model: torch.nn.Module,
    output_video_path: str,
    n_flyaround_poses: int = 40,
    fps: int = 20,
    trajectory_type: str = "circular_lsq_fit",
    max_angle: float = 2 * math.pi,
    trajectory_scale: float = 1.1,
    scene_center: Tuple[float, float, float] = (0.0, 0.0, 0.0),
    up: Tuple[float, float, float] = (0.0, -1.0, 0.0),
    traj_offset: float = 0.0,
    n_source_views: int = 9,
    visdom_show_preds: bool = False,
    visdom_environment: str = "render_flyaround",
    visdom_server: str = "http://127.0.0.1",
    visdom_port: int = 8097,
    num_workers: int = 10,
    device: Union[str, torch.device] = "cuda",
    seed: Optional[int] = None,
    video_resize: Optional[Tuple[int, int]] = None,
    output_video_frames_dir: Optional[str] = None,
    visualize_preds_keys: Sequence[str] = (
        "images_render",
        "masks_render",
        "depths_render",
        "_all_source_images",
    ),
) -> None:
    """
    Uses `model` to generate a video consisting of renders of a scene imaged from
    a camera flying around the scene. The scene is specified with the `dataset` object and
    `sequence_name` which denotes the name of the scene whose frames are in `dataset`.

    Args:
        dataset: The dataset object containing frames from a sequence in `sequence_name`.
        sequence_name: Name of a sequence from `dataset`.
        model: The model whose predictions are going to be visualized.
        output_video_path: Path prefix of the output videos. One video per visualized
            prediction key is written (see `_generate_prediction_videos`).
        n_flyaround_poses: The number of camera poses of the flyaround trajectory.
        fps: Framerate of the output video.
        trajectory_type: The type of the camera trajectory. Can be one of:
            circular_lsq_fit: Camera centers follow a trajectory obtained
                by fitting a 3D circle to train_cameras centers.
                All cameras are looking towards scene_center.
            figure_eight: Figure-of-8 trajectory around the center of the
                central camera of the training dataset.
            trefoil_knot: Same as 'figure_eight', but the trajectory has a shape
                of a trefoil knot (https://en.wikipedia.org/wiki/Trefoil_knot).
            figure_eight_knot: Same as 'figure_eight', but the trajectory has a shape
                of a figure-eight knot
                (https://en.wikipedia.org/wiki/Figure-eight_knot_(mathematics)).
        max_angle: Defines the total length of the generated camera trajectory.
            All possible trajectories (set with the `trajectory_type` argument) are
            periodic with the period of `time==2pi`.
            E.g. setting `trajectory_type=circular_lsq_fit` and `max_angle=4pi` will
            generate a trajectory of camera poses rotating the total of 720 deg around
            the object.
        trajectory_scale: The extent of the trajectory.
        scene_center: The center of the scene in world coordinates which all
            the cameras from the generated trajectory look at.
        up: The "up" vector of the scene (=the normal of the scene floor).
            Presumably only used by `trajectory_type="circular_lsq_fit"`
            (TODO confirm in `generate_eval_video_cameras`).
        traj_offset: Scalar offset; passed to `generate_eval_video_cameras` as the
            third component of `traj_offset_canonical=(0, 0, traj_offset)`.
        n_source_views: The number of source views sampled from the known views of the
            training sequence added to each evaluation batch.
        visdom_show_preds: If `True`, exports the visualizations to visdom.
        visdom_environment: The name of the visdom environment.
        visdom_server: The address of the visdom server.
        visdom_port: The visdom port.
        num_workers: The number of workers used to load the training data.
        device: Device to which each evaluation batch is moved before calling `model`.
        seed: The random seed used for reproducible sampling of the source views.
            If `None`, a seed is derived deterministically from `sequence_name`
            (CRC32), so the same source views are chosen in every run/process.
        video_resize: Optionally, defines the size of the output video.
        output_video_frames_dir: If specified, the frames of the output video are going
            to be permanently stored in this directory.
        visualize_preds_keys: The names of the model predictions to visualize.
    """

    if seed is None:
        # The builtin `hash()` of a `str` is salted per interpreter process
        # (PYTHONHASHSEED), so it cannot give reproducible sampling across runs.
        # CRC32 of the name is stable and fits the accepted range of torch seeds.
        import zlib

        seed = zlib.crc32(sequence_name.encode("utf-8"))

    if visdom_show_preds:
        viz = get_visdom_connection(server=visdom_server, port=visdom_port)
    else:
        viz = None

    logger.info(f"Loading all data of sequence '{sequence_name}'.")
    seq_idx = list(dataset.sequence_indices_in_order(sequence_name))
    train_data = _load_whole_dataset(dataset, seq_idx, num_workers=num_workers)
    assert all(train_data.sequence_name[0] == sn for sn in train_data.sequence_name)
    # pyre-ignore[6]
    sequence_set_name = "train" if is_train_frame(train_data.frame_type)[0] else "test"
    logger.info(f"Sequence set = {sequence_set_name}.")
    train_cameras = train_data.camera
    # drop the linspace endpoint so that a full period (max_angle=2pi) does not
    # repeat the first pose at the end of the video
    time = torch.linspace(0, max_angle, n_flyaround_poses + 1)[:n_flyaround_poses]
    test_cameras = generate_eval_video_cameras(
        train_cameras,
        time=time,
        n_eval_cams=n_flyaround_poses,
        trajectory_type=trajectory_type,
        trajectory_scale=trajectory_scale,
        scene_center=scene_center,
        up=up,
        focal_length=None,
        principal_point=torch.zeros(n_flyaround_poses, 2),
        traj_offset_canonical=(0.0, 0.0, traj_offset),
    )

    # sample the source views reproducibly without disturbing the global RNG state
    with torch.random.fork_rng():
        torch.manual_seed(seed)
        source_views_i = torch.randperm(len(seq_idx))[:n_source_views]

    # add the first dummy view (index 0 of the sequence) that will get replaced
    # with the target camera
    source_views_i = Fu.pad(source_views_i, [1, 0])
    source_views = [seq_idx[i] for i in source_views_i.tolist()]
    batch = _load_whole_dataset(dataset, source_views, num_workers=num_workers)
    assert all(batch.sequence_name[0] == sn for sn in batch.sequence_name)

    preds_total = []
    for n in tqdm(range(n_flyaround_poses), total=n_flyaround_poses):
        # set the first batch camera to the target camera
        for k in ("R", "T", "focal_length", "principal_point"):
            getattr(batch.camera, k)[0] = getattr(test_cameras[n], k)

        # Move to cuda
        net_input = batch.to(device)
        with torch.no_grad():
            preds = model(**{**net_input, "evaluation_mode": EvaluationMode.EVALUATION})

            # make sure we dont overwrite something
            assert all(k not in preds for k in net_input.keys())
            preds.update(net_input)  # merge everything into one big dict

            # Render the predictions to images
            rendered_pred = _images_from_preds(preds, extract_keys=visualize_preds_keys)
            preds_total.append(rendered_pred)

            # show the preds every 5% of the export iterations
            if visdom_show_preds and (
                n % max(n_flyaround_poses // 20, 1) == 0 or n == n_flyaround_poses - 1
            ):
                assert viz is not None
                _show_predictions(
                    preds_total,
                    sequence_name=batch.sequence_name[0],
                    viz=viz,
                    viz_env=visdom_environment,
                    predicted_keys=visualize_preds_keys,
                )

    logger.info(f"Exporting videos for sequence {sequence_name} ...")
    _generate_prediction_videos(
        preds_total,
        sequence_name=batch.sequence_name[0],
        viz=viz,
        viz_env=visdom_environment,
        fps=fps,
        video_path=output_video_path,
        resize=video_resize,
        video_frames_dir=output_video_frames_dir,
        predicted_keys=visualize_preds_keys,
    )


def _load_whole_dataset(
    dataset: torch.utils.data.Dataset, idx: Sequence[int], num_workers: int = 10
) -> FrameData:
    """
    Load the frames of `dataset` at positions `idx` as one collated `FrameData`.

    A `DataLoader` over the subset with a batch size equal to `len(idx)` fetches
    all requested frames (using `num_workers` loader workers) and merges them
    with `FrameData.collate`. The order of `idx` is kept (`shuffle=False`), and
    only the first (single) batch is returned.
    """
    selected = torch.utils.data.Subset(dataset, idx)
    single_batch_loader = torch.utils.data.DataLoader(
        selected,
        batch_size=len(idx),
        shuffle=False,
        num_workers=num_workers,
        collate_fn=FrameData.collate,
    )
    batch_iter = iter(single_batch_loader)
    return next(batch_iter)


def _images_from_preds(
    preds: Dict[str, Any],
    extract_keys: Iterable[str] = (
        "image_rgb",
        "images_render",
        "fg_probability",
        "masks_render",
        "depths_render",
        "depth_map",
        "_all_source_images",
    ),
) -> Dict[str, torch.Tensor]:
    imout = {}
    for k in extract_keys:
        if k == "_all_source_images" and "image_rgb" in preds:
            src_ims = preds["image_rgb"][1:].cpu().detach().clone()
            v = _stack_images(src_ims, None)[None]
        else:
            if k not in preds or preds[k] is None:
                print(f"cant show {k}")
                continue
            v = preds[k].cpu().detach().clone()
        if k.startswith("depth"):
            mask_resize = Fu.interpolate(
                preds["masks_render"],
                size=preds[k].shape[2:],
                mode="nearest",
            )
            v = make_depth_image(preds[k], mask_resize)
        if v.shape[1] == 1:
            v = v.repeat(1, 3, 1, 1)
        imout[k] = v.detach().cpu()

    return imout


def _stack_images(ims: torch.Tensor, size: Optional[Tuple[int, int]]) -> torch.Tensor:
    ba = ims.shape[0]
    H = int(np.ceil(np.sqrt(ba)))
    W = H
    n_add = H * W - ba
    if n_add > 0:
        ims = torch.cat((ims, torch.zeros_like(ims[:1]).repeat(n_add, 1, 1, 1)))

    ims = ims.view(H, W, *ims.shape[1:])
    cated = torch.cat([torch.cat(list(row), dim=2) for row in ims], dim=1)
    if size is not None:
        cated = Fu.interpolate(cated[None], size=size, mode="bilinear")[0]
    return cated.clamp(0.0, 1.0)


def _show_predictions(
    preds: List[Dict[str, Any]],
    sequence_name: str,
    viz: "Visdom",
    viz_env: str = "visualizer",
    predicted_keys: Sequence[str] = (
        "images_render",
        "masks_render",
        "depths_render",
        "_all_source_images",
    ),
    n_samples=10,
    one_image_width=200,
) -> None:
    """Given a list of predictions visualize them into a single image using visdom."""
    assert isinstance(preds, list)

    pred_all = []
    # Randomly choose a subset of the rendered images, sort by ordr in the sequence
    n_samples = min(n_samples, len(preds))
    pred_idx = sorted(random.sample(list(range(len(preds))), n_samples))
    for predi in pred_idx:
        # Make the concatentation for the same camera vertically
        pred_all.append(
            torch.cat(
                [
                    torch.nn.functional.interpolate(
                        preds[predi][k].cpu(),
                        scale_factor=one_image_width / preds[predi][k].shape[3],
                        mode="bilinear",
                    ).clamp(0.0, 1.0)
                    for k in predicted_keys
                ],
                dim=2,
            )
        )
    # Concatenate the images horizontally
    pred_all_cat = torch.cat(pred_all, dim=3)[0]
    viz.image(
        pred_all_cat,
        win="show_predictions",
        env=viz_env,
        opts={"title": f"pred_{sequence_name}"},
    )


def _generate_prediction_videos(
    preds: List[Dict[str, Any]],
    sequence_name: str,
    viz: Optional["Visdom"] = None,
    viz_env: str = "visualizer",
    predicted_keys: Sequence[str] = (
        "images_render",
        "masks_render",
        "depths_render",
        "_all_source_images",
    ),
    fps: int = 20,
    video_path: str = "/tmp/video",
    video_frames_dir: Optional[str] = None,
    resize: Optional[Tuple[int, int]] = None,
) -> None:
    """Given a list of predictions create and visualize rotating videos of the
    objects using visdom.

    One video is written per key of `predicted_keys` that is present in `preds[0]`,
    to `f"{video_path}_{sequence_name}_{key}.mp4"`. Frame `i` of each video is
    batch element 0 of `preds[i][key]`, clipped to [0, 1].

    Args:
        preds: Per-frame dicts of image tensors (as produced by `_images_from_preds`).
        sequence_name: Sequence name; used in the output file names and titles.
        viz: Optional visdom connection; if given, each finished video is shown there.
        viz_env: Visdom environment for the videos.
        predicted_keys: Keys of `preds` to export as videos.
        fps: Framerate of the output videos.
        video_path: Path prefix of the output videos; its directory is created
            if it does not exist.
        video_frames_dir: If set, the frames of the video for key `k` are stored in
            `{video_frames_dir}/{sequence_name}_{k}` (passed to `VideoWriter` as
            `cache_dir`).
        resize: Optional frame size forwarded to `VideoWriter.write_frame`.
    """

    # make sure the target video directory exists; a bare prefix such as "video"
    # has dirname "", and `os.makedirs("")` raises FileNotFoundError
    video_dir = os.path.dirname(video_path)
    if video_dir:
        os.makedirs(video_dir, exist_ok=True)

    if not preds:
        # nothing was rendered; indexing `preds[0]` below would raise IndexError
        logger.warning(f"No predictions to export for sequence '{sequence_name}'.")
        return

    # init a video writer for each predicted key
    vws = {}
    for k in predicted_keys:
        if k not in preds[0]:
            logger.warning(f"Cannot generate video for prediction key '{k}'")
            continue
        cache_dir = (
            None
            if video_frames_dir is None
            else os.path.join(video_frames_dir, f"{sequence_name}_{k}")
        )
        vws[k] = VideoWriter(
            fps=fps,
            out_path=f"{video_path}_{sequence_name}_{k}.mp4",
            cache_dir=cache_dir,
        )

    for rendered_pred in tqdm(preds):
        for k in vws:
            vws[k].write_frame(
                rendered_pred[k][0].clip(0.0, 1.0).detach().cpu().numpy(),
                resize=resize,
            )

    # iterate the writers (insertion order follows `predicted_keys`) so that a
    # duplicated key does not finalize the same video twice
    for k, vw in vws.items():
        vw.get_video()
        logger.info(f"Generated {vw.out_path}.")
        if viz is not None:
            viz.video(
                videofile=vw.out_path,
                env=viz_env,
                win=k,  # we reuse the same window otherwise visdom dies
                opts={"title": sequence_name + " " + k},
            )