# Copyright (c) 2020-present, Pouya Bashivan.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
"""
This code is an adaptation of the DeepFool implementation from the foolbox package
https://github.com/bethgelab/foolbox
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import torch

from advertorch.utils import batch_clamp, clamp
from advertorch.utils import replicate_input, replicate_input_withgrad

from .base import Attack, LabelMixin


class DeepfoolLinfAttack(Attack, LabelMixin):
"""
A simple and fast gradient-based adversarial attack.
Seyed-Mohsen Moosavi-Dezfooli, Alhussein Fawzi, Pascal Frossard,
"DeepFool: a simple and accurate method to fool deep neural
networks", https://arxiv.org/abs/1511.04599
:param predict: forward pass function.
:param num_classes: number of classes considered
:param nb_iter: number of iterations.
:eps=0.1
:param clip_min: mininum value per input dimension.
:param clip_max: maximum value per input dimension.
"""
def __init__(self, predict, num_classes=None, nb_iter=50, eps=0.1,
overshoot=0.02, clip_min=0., clip_max=1., loss_fn=None,
targeted=False):
"""
        DeepFool Linf attack implementation in PyTorch.
"""
super(DeepfoolLinfAttack, self).__init__(predict, loss_fn=loss_fn,
clip_min=clip_min,
clip_max=clip_max)
self.predict = predict
self.num_classes = num_classes
self.nb_iter = nb_iter
self.eps = eps
self.overshoot = overshoot
        self.targeted = targeted

    def is_adv(self, logits, y):
        # an example counts as adversarial once the predicted label
        # differs from y
y_hat = logits.argmax(-1)
is_adv = y_hat != y
        return is_adv

    def get_deltas_logits(self, x, k, classes):
        # logit gap between the k-th most likely class and the originally
        # predicted class (f_k - f_0 in the DeepFool paper)
N = len(classes)
rows = range(N)
i0 = classes[:, 0]
logits = self.predict(x)
ik = classes[:, k]
l0 = logits[rows, i0]
lk = logits[rows, ik]
delta_logits = lk - l0
return {'sum_deltas': delta_logits.sum(),
'deltas': delta_logits,
                'logits': logits}

    def get_grads(self, x, k, classes):
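        # a single backward pass on the summed logit gaps yields
        # per-example gradients, since examples in a batch don't interact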
deltas_logits = self.get_deltas_logits(x, k, classes)
deltas_logits['sum_deltas'].backward()
deltas_logits['grads'] = x.grad.clone()
x.grad.data.zero_()
return deltas_logits

    def get_distances(self, deltas, grads):
        # distance to the k-th decision boundary under the Linf norm:
        # |f_k - f_0| / ||w_k||_1 (L1 is the dual norm of Linf)
        return deltas.abs() / (
            grads.flatten(start_dim=2, end_dim=-1).abs().sum(dim=-1) + 1e-8)

def get_perturbations(self, distances, grads):
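        # under the Linf norm, the optimal step moves every input
        # dimension by `distances` in the direction of the gradient sign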
        return self.atleast_kd(distances, grads.ndim) * grads.sign()

    def atleast_kd(self, x, k):
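        # append trailing singleton dimensions so x broadcasts against
        # a k-dimensional tensor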
shape = x.shape + (1,) * (k - x.ndim)
        return x.reshape(shape)

    def _verify_and_process_inputs(self, x, y):
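        # work on detached copies so the caller's tensors stay untouched;
        # x is replicated with requires_grad=True for the gradient steps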
if self.targeted:
assert y is not None
if not self.targeted:
if y is None:
y = self._get_predicted_label(x)
x = replicate_input_withgrad(x)
y = replicate_input(y)
        return x, y

    def perturb(self, x, y=None):
"""
Given examples (x, y), returns their adversarial counterparts with
an attack length of eps.
:param x: input tensor.
        :param y: label tensor.
                  - if None and self.targeted=False, compute y as predicted
                    labels.
:return: tensor containing perturbed inputs.
"""
x, y = self._verify_and_process_inputs(x, y)
x.requires_grad_()
logits = self.predict(x)
# get the classes
classes = logits.argsort(axis=-1).flip(-1).detach()
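        # classes[:, 0] holds the originally predicted label; the remaining
        # columns are candidate classes in order of decreasing logit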
if self.num_classes is None:
self.num_classes = logits.shape[-1]
else:
self.num_classes = min(self.num_classes, logits.shape[-1])
N = len(x)
rows = range(N)
x0 = x
p_total = torch.zeros_like(x)
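        # p_total accumulates the DeepFool steps taken so far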
for _ in range(self.nb_iter):
# let's first get the logits using k = 1 to see if we are done
diffs = [self.get_grads(x, 1, classes)]
is_adv = self.is_adv(diffs[0]['logits'], y)
if is_adv.all():
break
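            # logit gaps and gradients for the remaining candidate classes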
diffs += [self.get_grads(x, k, classes) for k in range(2, self.num_classes)] # noqa
deltas = torch.stack([d['deltas'] for d in diffs], dim=-1)
grads = torch.stack([d['grads'] for d in diffs], dim=1)
assert deltas.shape == (N, self.num_classes - 1)
assert grads.shape == (N, self.num_classes - 1) + x0.shape[1:]
# calculate the distances
# compute f_k / ||w_k||
distances = self.get_distances(deltas, grads)
assert distances.shape == (N, self.num_classes - 1)
# determine the best directions
best = distances.argmin(axis=1) # compute \hat{l}
distances = distances[rows, best]
deltas = deltas[rows, best]
grads = grads[rows, best]
assert distances.shape == (N,)
assert deltas.shape == (N,)
assert grads.shape == x0.shape
# apply perturbation
distances = distances + 1e-4 # for numerical stability
p_step = self.get_perturbations(distances, grads) # =r_i
assert p_step.shape == x0.shape
p_total += p_step
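            # project the accumulated perturbation back onto the Linf ball
            # of radius eps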
p_total = batch_clamp(self.eps, p_total)
# don't do anything for those that are already adversarial
x = torch.where(
self.atleast_kd(is_adv, x.ndim),
x,
x0 + (1.0 + self.overshoot) * p_total,
) # =x_{i+1}
x = clamp(x, min=self.clip_min, max=self.clip_max).clone().detach().requires_grad_() # noqa
return x.detach()
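

# ---------------------------------------------------------------------
# Minimal usage sketch (illustrative; not part of the original file).
# It builds an untrained toy classifier just to show the calling
# convention; any model mapping inputs to logits works the same way.
# Because of the relative import above, run it as a module, e.g.
# `python -m advertorch.attacks.deepfool` (module path assumed).
# ---------------------------------------------------------------------
if __name__ == "__main__":
    import torch.nn as nn

    # toy 10-class classifier over 1x28x28 inputs in [0, 1]
    model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
    model.eval()

    x = torch.rand(8, 1, 28, 28)    # batch of inputs
    y = torch.randint(0, 10, (8,))  # ground-truth labels

    adversary = DeepfoolLinfAttack(
        model, num_classes=10, nb_iter=50, eps=0.1,
        overshoot=0.02, clip_min=0., clip_max=1.)
    # y is optional: if omitted, the model's own predictions are used
    x_adv = adversary.perturb(x, y)
    print("max Linf distortion:", (x_adv - x).abs().max().item())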