# Copyright (c) 2018-present, Royal Bank of Canada and other authors.
# See the AUTHORS.txt file for a list of contributors.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
# BPDA stands for Backward Pass Differentiable Approximation
# See:
# Athalye, A., Carlini, N. & Wagner, D. (2018). Obfuscated Gradients Give a
# False Sense of Security: Circumventing Defenses to Adversarial Examples.
# Proceedings of the 35th International Conference on Machine Learning,
# in PMLR 80:274-283
import torch
import torch.nn as nn
__all__ = ['BPDAWrapper']
class FunctionWrapper(nn.Module):
"""`nn.Module` wrapping a `torch.autograd.Function`."""
def __init__(self, func):
"""Wraps the provided function `func`.
:param func: the `torch.autograd.Function` to be wrapped.
"""
super(FunctionWrapper, self).__init__()
self.func = func
def forward(self, *inputs):
"""Wraps the `forward` method of `func`."""
return self.func.apply(*inputs)
class BPDAWrapper(FunctionWrapper):
"""Backward Pass Differentiable Approximation.
The module should be provided a `forward` method and a `backward`
method that approximates the derivatives of `forward`.
The `forward` function is called in the forward pass, and the
`backward` function is used to find gradients in the backward pass.
    The `backward` function can be provided implicitly by passing
    `forwardsub` - an alternative forward pass function whose
    gradient will be used in the backward pass.
    If neither `backward` nor `forwardsub` is provided, the `backward`
    function is assumed to be the identity.
:param forward: `forward(*inputs)` - the forward function for BPDA.
    :param forwardsub: (Optional) a substitute forward function, used to
        approximate the gradients of `forward`.
    :param backward: (Optional) `backward(inputs, grad_outputs)`, the
        backward pass function for BPDA.
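
    Example (an illustrative sketch - the thresholding step and its sigmoid
    substitute below are made up for demonstration and are not part of
    advertorch)::

        >>> threshold = BPDAWrapper(
        ...     lambda x: (x > 0.5).float(),
        ...     forwardsub=lambda x: torch.sigmoid(10. * (x - 0.5)))
        >>> x = torch.rand(2, 3, requires_grad=True)
        >>> threshold(x).sum().backward()
        >>> x.grad.shape
        torch.Size([2, 3])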
"""
def __init__(self, forward, forwardsub=None, backward=None):
func = self._create_func(forward, backward, forwardsub)
super(BPDAWrapper, self).__init__(func)
@classmethod
def _create_func(cls, forward_fn, backward_fn, forwardsub_fn):
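        # Dispatch: an explicit `backward` takes precedence over `forwardsub`;
        # with neither, fall back to the identity backward pass.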
if backward_fn is not None:
return cls._create_func_backward(forward_fn, backward_fn)
if forwardsub_fn is not None:
return cls._create_func_forwardsub(forward_fn, forwardsub_fn)
return cls._create_func_forward_only(forward_fn)
@classmethod
def _create_func_forward_only(cls, forward_fn):
"""Creates a differentiable `Function` given the forward function,
and the identity as backward function."""
class Func(torch.autograd.Function):
@staticmethod
def forward(ctx, *inputs, **kwargs):
ctx.save_for_backward(*inputs)
return forward_fn(*inputs, **kwargs)
@staticmethod
def backward(ctx, *grad_outputs):
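                # Identity approximation: pass the incoming gradients straight
                # through, or replicate a single gradient across all inputs.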
inputs = ctx.saved_tensors
if len(grad_outputs) == len(inputs):
return grad_outputs
elif len(grad_outputs) == 1:
return tuple([grad_outputs[0] for _ in inputs])
raise ValueError("Expected %d gradients but got %d" %
(len(inputs), len(grad_outputs)))
return Func
@classmethod
def _create_func_forwardsub(cls, forward_fn, forwardsub_fn):
"""Creates a differentiable `Function` given the forward function,
and a substitute forward function.
The substitute forward function is used to approximate the gradients
in the backward pass.
"""
class Func(torch.autograd.Function):
@staticmethod
def forward(ctx, *inputs, **kwargs):
ctx.save_for_backward(*inputs)
return forward_fn(*inputs, **kwargs)
@staticmethod
@torch.enable_grad() # enables grad in the method's scope
def backward(ctx, *grad_outputs):
inputs = ctx.saved_tensors
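                # Differentiate the substitute forward pass on detached copies
                # of the saved inputs to approximate the true gradients.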
inputs = [x.detach().clone().requires_grad_() for x in inputs]
outputs = forwardsub_fn(*inputs)
return torch.autograd.grad(outputs, inputs, grad_outputs)
return Func
@classmethod
def _create_func_backward(cls, forward_fn, backward_fn):
"""Creates a differentiable `Function` given the forward and backward
functions."""
class Func(torch.autograd.Function):
@staticmethod
def forward(ctx, *inputs, **kwargs):
ctx.save_for_backward(*inputs)
return forward_fn(*inputs, **kwargs)
@staticmethod
def backward(ctx, *grad_outputs):
inputs = ctx.saved_tensors
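                # Delegate to the user-supplied backward function, passing the
                # saved input tensors and the incoming gradients.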
return backward_fn(inputs, grad_outputs)
return Func
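

if __name__ == "__main__":
    # Illustrative sketch (not part of the original module): wrap a made-up,
    # non-differentiable quantization step with BPDAWrapper so that gradients
    # still reach the input. With neither `backward` nor `forwardsub` given,
    # the backward pass defaults to the identity.
    def quantize(x):
        # Hypothetical non-differentiable preprocessing step.
        return torch.round(x * 255.) / 255.

    model = nn.Sequential(BPDAWrapper(quantize), nn.Linear(4, 2))
    x = torch.rand(8, 4, requires_grad=True)
    model(x).sum().backward()
    print(x.grad.shape)  # identity gradients flowed through `quantize`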