Why Gemfury? Push, build, and install RubyGems, npm packages, Python packages, Maven artifacts, PHP packages, Go Modules, Debian packages, RPM packages, and NuGet packages.

Repository URL to install this package:

Details    
pytorch3d / implicitron / models / global_encoder / global_encoder.py
Size: Mime:
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

from typing import List, Optional, Union

import torch
from pytorch3d.implicitron.tools.config import (
    registry,
    ReplaceableBase,
    run_auto_creation,
)
from pytorch3d.renderer.implicit import HarmonicEmbedding

from .autodecoder import Autodecoder


class GlobalEncoderBase(ReplaceableBase):
    """
    Base class for encoders of global, frame-specific quantities.

    Concrete implementations include the harmonic encoding of a frame
    timestamp (`HarmonicTimeEncoder`) and an autodecoder encoding of the
    frame's sequence identifier (`SequenceAutodecoder`).
    """

    def get_encoding_dim(self):
        """
        Return the dimensionality of the encoding produced by `forward`.
        """
        raise NotImplementedError()

    def calculate_squared_encoding_norm(self) -> Optional[torch.Tensor]:
        """
        Return the squared norm of the encoding, as a zero-dimensional
        tensor, to be reported as the model's `autodecoder_norm` loss.
        """
        raise NotImplementedError()

    def forward(
        self,
        *,
        frame_timestamp: Optional[torch.Tensor] = None,
        sequence_name: Optional[Union[torch.LongTensor, List[str]]] = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        Produce the global encoding for the given inputs.

        Returns:
            encoding: A tensor holding the global encoding.
        """
        raise NotImplementedError()


# TODO: probabilistic embeddings?
@registry.register
class SequenceAutodecoder(GlobalEncoderBase, torch.nn.Module):
    """
    Global encoder that embeds a frame's sequence identifier by looking it
    up in an `Autodecoder` table.
    """

    # pyre-fixme[13]: Attribute `autodecoder` is never initialized.
    autodecoder: Autodecoder

    def __post_init__(self):
        # Instantiate `self.autodecoder` via the pluggable-config machinery.
        run_auto_creation(self)

    def get_encoding_dim(self):
        # The encoding is exactly the autodecoder's embedding vector.
        return self.autodecoder.get_encoding_dim()

    def forward(
        self,
        *,
        frame_timestamp: Optional[torch.Tensor] = None,
        sequence_name: Optional[Union[torch.LongTensor, List[str]]] = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        Encode `sequence_name` with the autodecoder. `frame_timestamp` is
        accepted for interface compatibility but not used here.
        """
        if sequence_name is None:
            raise ValueError("sequence_name must be provided.")
        # run dtype checks and pass sequence_name to self.autodecoder
        return self.autodecoder(sequence_name)

    def calculate_squared_encoding_norm(self) -> Optional[torch.Tensor]:
        # The autodecoder owns the embedding weights, so it computes the norm.
        return self.autodecoder.calculate_squared_encoding_norm()


@registry.register
class HarmonicTimeEncoder(GlobalEncoderBase, torch.nn.Module):
    """
    A global encoder implementation which provides harmonic embeddings
    of each frame's timestamp.
    """

    # Number of harmonic frequencies used by the embedding.
    n_harmonic_functions: int = 10
    # If True, the (scaled) input timestamp is appended to the embedding.
    append_input: bool = True
    # Timestamps are divided by this value before being embedded.
    time_divisor: float = 1.0

    def __post_init__(self):
        self._harmonic_embedding = HarmonicEmbedding(
            n_harmonic_functions=self.n_harmonic_functions,
            append_input=self.append_input,
        )

    def get_encoding_dim(self):
        # The embedding acts on a single scalar (the timestamp),
        # hence input dimension 1.
        return self._harmonic_embedding.get_output_dim(1)

    def forward(
        self,
        *,
        frame_timestamp: Optional[torch.Tensor] = None,
        sequence_name: Optional[Union[torch.LongTensor, List[str]]] = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        Harmonically embed `frame_timestamp`. `sequence_name` is accepted
        for interface compatibility but not used here.

        Args:
            frame_timestamp: A tensor of timestamps whose last dimension
                must be 1.

        Returns:
            The harmonic embedding of `frame_timestamp / time_divisor`.

        Raises:
            ValueError: If `frame_timestamp` is None or its last dimension
                is not 1.
        """
        if frame_timestamp is None:
            raise ValueError("frame_timestamp must be provided.")
        if frame_timestamp.shape[-1] != 1:
            # Fixed grammar of the message and report the offending shape
            # to make the failure diagnosable.
            raise ValueError(
                "frame_timestamp's last dimension should be 1; "
                f"got shape {tuple(frame_timestamp.shape)}."
            )
        time = frame_timestamp / self.time_divisor
        # pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
        return self._harmonic_embedding(time)

    def calculate_squared_encoding_norm(self) -> Optional[torch.Tensor]:
        # The harmonic embedding has no learnable parameters, so there is
        # no encoding norm to regularize.
        return None