# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import copy
from typing import (
Any,
Callable,
Dict,
Iterable,
List,
NoReturn,
Sequence,
Tuple,
Type,
Union,
)
import torch
import torch.nn as nn
from torch import Tensor
from torch.nn.utils._named_member_accessor import NamedMemberAccessor
# Utilities to make nn.Module "functional".
# In particular, the goal is to provide a function that takes the parameters
# as input and evaluates the nn.Module on fixed inputs.
def raise_parameter_tying_error() -> NoReturn:
    """Raise the error used to reject models that tie (share) parameters.

    Raises:
        RuntimeError: always.
    """
    message = (
        "make_functional(module): we don't yet support models that "
        "do parameter tying (also sometimes known as weight sharing). "
        "Please try to rewrite your model by replacing all instances of the "
        "tied parameter with another and/or comment your support in "
        "https://github.com/pytorch/functorch/issues/446"
    )
    raise RuntimeError(message)
def create_names_map(
    named_params: Union[Dict[str, Tensor], Iterable[Tuple[str, Tensor]]],
    tied_named_params: Union[Dict[str, Tensor], Iterable[Tuple[str, Tensor]]],
) -> Dict[str, List[str]]:
    """
    Map each canonical tensor name to every name that refers to the same tensor.

    `named_params` holds one name per tensor, e.g. {'A': A, 'B': B}.
    `tied_named_params` may hold several names for one tensor,
    e.g. {'A': A, 'B': B, 'B_tied': B}.
    The result for those inputs is {'A': ['A'], 'B': ['B', 'B_tied']}.

    Both arguments may be dicts or iterables of (name, tensor) pairs.
    Every name in `named_params` must also appear in `tied_named_params`, and
    every tensor in `tied_named_params` must appear in `named_params`
    (both are checked with `assert`).
    """
    canonical = dict(named_params)
    everything = dict(tied_named_params)
    assert set(canonical).issubset(set(everything))

    # Tensors are used directly as dict keys, so aliases are matched by
    # object identity rather than by value.
    by_tensor: Dict[Tensor, Tuple[str, List[str]]] = {
        tensor: (name, []) for name, tensor in canonical.items()
    }
    for alias, tensor in everything.items():
        assert tensor in by_tensor
        by_tensor[tensor][1].append(alias)

    return {name: aliases for name, aliases in by_tensor.values()}
def _extract_members(
    mod: nn.Module,
    named_members: Callable[..., Iterable[Tuple[str, Tensor]]],
    subclass: Callable[[Tensor], Tensor],
) -> Tuple[Tuple[Tensor, ...], Tuple[str, ...], Dict[str, List[str]]]:
    """
    Strip one kind of tensor member (parameters or buffers) out of `mod`.

    Args:
        mod: module to strip; it is modified in place.
        named_members: a bound `named_parameters`/`named_buffers`-style method
            that accepts a `remove_duplicate=` keyword.
        subclass: applied to each meta-device placeholder before it is
            installed (`nn.Parameter` for parameters, identity for buffers).

    Returns:
        `(tensors, names, names_map)`: `tensors` and `names` list each unique
        member once, and `names_map` maps each unique name to every attribute
        name that aliases the same tensor.
    """
    every_member = tuple(named_members(remove_duplicate=False))
    distinct_members = tuple(named_members(remove_duplicate=True))
    names_map = create_names_map(distinct_members, every_member)

    # Replace each member (aliases included) with an empty tensor on the
    # "meta" device. All aliases of one tensor receive the same placeholder.
    placeholders: Dict[Tensor, Tensor] = {}
    accessor = NamedMemberAccessor(mod)
    for member_name, member in every_member:
        if member not in placeholders:
            placeholders[member] = subclass(torch.empty_like(member, device="meta"))
        accessor.set_tensor(member_name, placeholders[member])

    if not distinct_members:
        return (), (), names_map
    names, tensors = zip(*distinct_members)
    return tensors, names, names_map  # type: ignore[return-value]
def extract_weights(
    mod: nn.Module,
) -> Tuple[Tuple[Tensor, ...], Tuple[str, ...], Dict[str, List[str]]]:
    """
    Remove all the Parameters from `mod` and return them with their names.

    Every parameter of `mod` (tied aliases included) is replaced in place by
    an `nn.Parameter` placeholder on the "meta" device. The weights must be
    re-loaded (e.g. with `load_weights`) before the model can be used again.

    Note that `mod.parameters()` is NOT empty after this call: the
    placeholders are themselves `nn.Parameter` objects.

    Returns:
        A tuple `(params, names, names_map)`:
        - params: the original parameter tensors, one per unique tensor.
        - names: the attribute name of each entry of `params`.
        - names_map: maps each name in `names` to every attribute name that
          refers to the same tensor, so tied parameters share one entry.
    """
    return _extract_members(mod, mod.named_parameters, nn.Parameter)
def extract_buffers(
    mod: nn.Module,
) -> Tuple[Tuple[Tensor, ...], Tuple[str, ...], Dict[str, List[str]]]:
    """
    Remove all the buffers from `mod` and return them with their names.

    This is the buffer counterpart of `extract_weights`. Each buffer is
    replaced in place by a plain (not `nn.Parameter`) empty tensor on the
    "meta" device.

    Returns:
        `(buffers, names, names_map)`, with the same layout as
        `extract_weights`.
    """

    def _keep_as_is(placeholder: Tensor) -> Tensor:
        # Buffers are installed unwrapped, unlike parameters.
        return placeholder

    return _extract_members(mod, mod.named_buffers, _keep_as_is)
def load_weights(
    mod: nn.Module,
    names: Sequence[str],
    params: Sequence[Tensor],
    as_params: bool = False,
) -> None:
    """
    Reload a set of weights so that `mod` can be used again to perform a forward pass.

    Args:
        mod: module to load into; it is modified in place.
        names: attribute names to assign, such as those returned by
            `extract_weights` (nested names are dotted, e.g. "0.weight").
        params: tensors to assign, matched to `names` by position.
        as_params: if True, wrap each tensor in `nn.Parameter` before
            assigning it. Otherwise (the default) the tensors are installed
            unchanged, so they may be ordinary tensors that carry autograd
            history.
    """
    accessor = NamedMemberAccessor(mod)
    if as_params:
        params = [nn.Parameter(p) for p in params]
    accessor.set_tensors(names, params)
def _swap_state(
mod: nn.Module, names_map: Dict[str, List[str]], elems: Iterable[Tensor]
) -> List[Tensor]:
result: List[Tensor] = []
accessor = NamedMemberAccessor(mod)
for (_, attr_names), elem in zip(names_map.items(), elems):
for i, attr_name in enumerate(attr_names):
if i == 0:
result.append(accessor.swap_tensor(attr_name, elem))
else:
accessor.set_tensor(attr_name, elem)
return result
def load_buffers(
    mod: nn.Module,
    names: Sequence[str],
    buffers: Sequence[Tensor],
    as_params: bool = False,
) -> None:
    """
    Assign `buffers` to the attributes `names` of `mod`, which is modified in place.

    Tensors are matched to names by position and installed unchanged.
    `as_params` is accepted but currently ignored; buffers are never wrapped
    in `nn.Parameter`.
    """
    NamedMemberAccessor(mod).set_tensors(names, buffers)
def load_state(
    model: nn.Module,
    weights: Sequence[Tensor],
    weight_names: Sequence[str],
    buffers: Sequence[Tensor] = (),
    buffer_names: Sequence[str] = (),
) -> nn.Module:
    """load_state(model, weights, weight_names, buffers=(), buffer_names=()) -> model

    Assign `weights` (and, if any are given, `buffers`) to `model` in place,
    then return the same `model` object. This is the inverse operation of
    `make_functional_deprecated_v1`.

    Each tensor sequence must have the same length as its name sequence; this
    is checked with `assert`. `buffer_names` is only checked and used when
    `buffers` is non-empty.
    """
    assert len(weight_names) == len(weights)
    load_weights(model, weight_names, weights)

    if len(buffers) == 0:
        return model

    assert len(buffer_names) == len(buffers)
    load_buffers(model, buffer_names, buffers)
    return model
def make_functional_deprecated_v1(model: nn.Module):
    """make_functional_deprecated_v1(model) -> weights, func, weight_names

    Given an nn.Module, make_functional_deprecated_v1 extracts the state (weights)
    and returns a functional version of the model, `func`. This makes
    it possible to use transforms over the parameters of `model`.

    The weights are extracted from `model` itself, in place, so after this call
    `model` holds only meta-device placeholders. Each call to `func` deep-copies
    that stripped model, loads the given weights into the copy and runs it, so
    `model` is never modified by `func`.

    `func` can be invoked as follows:
    ```
    x = torch.randn(4, 3)
    model = nn.Linear(3, 3)
    weights, func, _ = make_functional_deprecated_v1(model)
    func(weights, (x,))
    ```

    Here is an example of applying the grad transform. `grad` needs a scalar
    output, so the model output is reduced with `sum`:
    ```
    x = torch.randn(4, 3)
    model = nn.Linear(3, 3)
    weights, func, _ = make_functional_deprecated_v1(model)
    grad_weights = grad(lambda w: func(w, (x,)).sum())(weights)
    ```

    To put the state back into a model, use `load_state`.

    Raises:
        RuntimeError: if `model` has any buffers; use
            `make_functional_with_buffers_deprecated_v1` for such models.
    """
    # Only existence matters here, so stop at the first buffer instead of
    # materializing the whole list.
    if any(True for _ in model.buffers()):
        raise RuntimeError(
            "make_functional_deprecated_v1(model): `model` has buffers. Please use "
            "make_functional_with_buffers_deprecated_v1(model) instead."
        )
    weights, descriptors, _ = extract_weights(model)

    def fun(weights, data):
        # Load into a fresh copy so the stripped `model` is left untouched.
        mutable_model = copy.deepcopy(model)
        load_weights(mutable_model, descriptors, weights)
        return mutable_model(*data)

    return weights, fun, descriptors
def make_functional_with_buffers_deprecated_v1(model: nn.Module):
    """make_functional_with_buffers_deprecated_v1(model) -> weights, buffers, func, weight_names, buffer_names

    Given an nn.Module, make_functional_with_buffers_deprecated_v1 extracts the
    state (weights and buffers) and returns a functional version of the model,
    `func`.

    Both weights and buffers are extracted from `model` itself, in place. Each
    call to `func(weights, buffers, data)` deep-copies the stripped model,
    loads the given weights and buffers into the copy, and calls it with
    `*data`.

    `func` can be invoked as follows:
    ```
    x = torch.randn(4, 3)
    model = nn.Linear(3, 3)
    weights, buffers, func, _, _ = make_functional_with_buffers_deprecated_v1(model)
    func(weights, buffers, (x,))
    ```

    Here is an example of applying the grad transform. `grad` needs a scalar
    output, so the model output is reduced with `sum`:
    ```
    x = torch.randn(4, 3)
    model = nn.Linear(3, 3)
    weights, buffers, func, _, _ = make_functional_with_buffers_deprecated_v1(model)
    grad_weights = grad(lambda w: func(w, buffers, (x,)).sum())(weights)
    ```

    To put the state back into a model, use `load_state`.
    """
    weights, param_names, _ = extract_weights(model)
    buffers, buffer_names, _ = extract_buffers(model)

    def fun(weights, buffers, data):
        replica = copy.deepcopy(model)
        load_weights(replica, param_names, weights)
        load_buffers(replica, buffer_names, buffers)
        return replica(*data)

    return weights, buffers, fun, param_names, buffer_names
class FunctionalModuleWithBuffers(nn.Module):
    """
    This is the callable object returned by :func:`make_functional_with_buffers`.

    It holds a stateless copy of a model whose parameters and buffers have been
    replaced by meta-device placeholders. On every call it temporarily installs
    the caller-supplied parameters and buffers, runs the model, and then puts
    the placeholders back.
    """

    def __init__(
        self,
        stateless_model: nn.Module,
        param_names: Tuple[str, ...],
        buffer_names: Tuple[str, ...],
        param_names_map: Dict[str, List[str]],
        buffer_names_map: Dict[str, List[str]],
    ) -> None:
        super().__init__()
        self.stateless_model = stateless_model
        self.param_names = param_names
        self.buffer_names = buffer_names
        # Parameter entries come first, then buffer entries. forward() relies on
        # this order to pair them with `tuple(params) + tuple(buffers)`.
        merged = dict(param_names_map)
        merged.update(buffer_names_map)
        self.all_names_map = merged

    @staticmethod
    def _create_from(
        model: nn.Module, disable_autograd_tracking: bool = False
    ) -> Tuple["FunctionalModuleWithBuffers", Tuple[Tensor, ...], Tuple[Tensor, ...]]:
        """Build the wrapper from a deep copy of `model`, which is left unmodified.

        Returns the wrapper together with the extracted parameters and buffers.
        If `disable_autograd_tracking` is set, the returned parameters have
        `requires_grad` turned off.
        """
        # TODO: We don't need to copy the model to create a stateless copy
        stateless = copy.deepcopy(model)
        params, param_names, param_names_map = extract_weights(stateless)
        buffers, buffer_names, buffer_names_map = extract_buffers(stateless)
        if disable_autograd_tracking:
            for tensor in params:
                tensor.requires_grad_(False)
        wrapper = FunctionalModuleWithBuffers(
            stateless, param_names, buffer_names, param_names_map, buffer_names_map
        )
        return wrapper, params, buffers

    def forward(
        self, params: Iterable[Tensor], buffers: Iterable[Tensor], *args, **kwargs
    ) -> Any:
        """Run the stateless model with `params` and `buffers` temporarily installed."""
        incoming = tuple(params) + tuple(buffers)
        saved = _swap_state(self.stateless_model, self.all_names_map, incoming)
        try:
            return self.stateless_model(*args, **kwargs)
        finally:
            # Put the placeholders back even if the forward pass raised.
            _swap_state(self.stateless_model, self.all_names_map, saved)
class FunctionalModule(nn.Module):
    """
    This is the callable object returned by :func:`make_functional`.

    It holds a stateless copy of a model whose parameters have been replaced by
    meta-device placeholders. On every call it temporarily installs the
    caller-supplied parameters, runs the model, and then puts the placeholders
    back.
    """

    def __init__(
        self,
        stateless_model: nn.Module,
        param_names: Tuple[str, ...],
        names_map: Dict[str, List[str]],
    ) -> None:
        super().__init__()
        self.stateless_model = stateless_model
        self.param_names = param_names
        # Stored as given (not copied): canonical name -> all tied names.
        self.names_map = names_map

    @staticmethod
    def _create_from(
        model: nn.Module, disable_autograd_tracking: bool = False
    ) -> Tuple["FunctionalModule", Tuple[Tensor, ...]]:
        """Build the wrapper from a deep copy of `model`, which is left unmodified.

        Returns the wrapper together with the extracted parameters. If
        `disable_autograd_tracking` is set, the returned parameters have
        `requires_grad` turned off.
        """
        # TODO: We don't need to copy the model to create a stateless copy
        stateless = copy.deepcopy(model)
        params, param_names, names_map = extract_weights(stateless)
        if disable_autograd_tracking:
            for tensor in params:
                tensor.requires_grad_(False)
        wrapper = FunctionalModule(stateless, param_names, names_map)
        return wrapper, params

    def forward(self, params: Iterable[Tensor], *args, **kwargs) -> Any:
        """Run the stateless model with `params` temporarily installed."""
        saved = _swap_state(self.stateless_model, self.names_map, params)
        try:
            return self.stateless_model(*args, **kwargs)
        finally:
            # Put the placeholders back even if the forward pass raised.
            _swap_state(self.stateless_model, self.names_map, saved)
Loading ...