Why Gemfury? Push, build, and install RubyGems, npm packages, Python packages, Maven artifacts, PHP packages, Go Modules, Bower components, Debian packages, RPM packages, and NuGet packages.

neilisaac / torch   python

Repository URL to install this package:

Version: 1.8.0 

/ fx / passes / shape_prop.py

import torch
import torch.fx
from torch.fx.node import Node
from typing import Any

class ShapeProp(torch.fx.Interpreter):
    """Run an FX graph node by node, recording tensor metadata on each node.

    Whenever a node evaluates to a ``torch.Tensor``, the tensor's ``shape``
    and ``dtype`` are stored as plain attributes on that ``Node``
    (``node.shape`` / ``node.dtype``). Nodes whose value is anything else are
    left untouched.

    NOTE(review): attributes are only ever assigned, never cleared, so a node
    annotated on an earlier run keeps its old values if it later produces a
    non-tensor.
    """

    def run_node(self, n : Node) -> Any:
        """Evaluate ``n`` via the base Interpreter and annotate it if it is a tensor.

        Returns the node's computed value unchanged.
        """
        value = super().run_node(n)
        if not isinstance(value, torch.Tensor):
            # Non-tensor values (ints, tuples, None, ...) carry no metadata.
            return value
        # Node declares no ``shape``/``dtype`` attributes, hence the ignore.
        n.shape, n.dtype = value.shape, value.dtype  # type: ignore
        return value

    def propagate(self, *args):
        """Execute the wrapped module on ``args`` and return its output.

        Each positional argument is passed through to the base
        ``Interpreter.run``; node annotation happens as a side effect.
        """
        outputs = super().run(*args)
        return outputs