Repository URL to install this package:
|
Version:
0.2.4 ▾
|
# Copyright (c) 2018-present, Royal Bank of Canada and other authors.
# See the AUTHORS.txt file for a list of contributors.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import collections
import numpy as np
import torch
from torch.distributions import laplace
from torch.distributions import uniform
from torch.nn.modules.loss import _Loss
from advertorch.utils import clamp
from advertorch.utils import clamp_by_pnorm
from advertorch.utils import batch_multiply
from advertorch.utils import normalize_by_pnorm
from advertorch.utils import predict_from_logits
from advertorch.loss import ZeroOneLoss
from advertorch.attacks import Attack, LabelMixin
def zero_gradients(x):
if isinstance(x, torch.Tensor):
if x.grad is not None:
x.grad.detach_()
x.grad.zero_()
elif isinstance(x, collections.abc.Iterable):
for elem in x:
zero_gradients(elem)
def rand_init_delta(delta, x, ord, eps, clip_min, clip_max):
# TODO: Currently only considered one way of "uniform" sampling
# for Linf, there are 3 ways:
# 1) true uniform sampling by first calculate the rectangle then sample
# 2) uniform in eps box then truncate using data domain (implemented)
# 3) uniform sample in data domain then truncate with eps box
# for L2, true uniform sampling is hard, since it requires uniform sampling
# inside a intersection of cube and ball, so there are 2 ways:
# 1) uniform sample in the data domain, then truncate using the L2 ball
# (implemented)
# 2) uniform sample in the L2 ball, then truncate using the data domain
# for L1: uniform l1 ball init, then truncate using the data domain
if isinstance(eps, torch.Tensor):
assert len(eps) == len(delta)
if ord == np.inf:
delta.data.uniform_(-1, 1)
delta.data = batch_multiply(eps, delta.data)
elif ord == 2:
delta.data.uniform_(clip_min, clip_max)
delta.data = delta.data - x
delta.data = clamp_by_pnorm(delta.data, ord, eps)
elif ord == 1:
ini = laplace.Laplace(
loc=delta.new_tensor(0), scale=delta.new_tensor(1)
)
delta.data = ini.sample(delta.data.shape)
delta.data = normalize_by_pnorm(delta.data, p=1)
ray = uniform.Uniform(0, eps).sample()
delta.data *= ray
delta.data = clamp(x.data + delta.data, clip_min, clip_max) - x.data
else:
error = "Only ord = inf, ord = 1 and ord = 2 have been implemented"
raise NotImplementedError(error)
delta.data = clamp(x + delta.data, min=clip_min, max=clip_max) - x
return delta.data
def is_successful(y1, y2, targeted):
if targeted is True:
return y1 == y2
else:
return y1 != y2
class AttackConfig(object):
# a convenient class for generate an attack/adversary instance
def __init__(self):
self.kwargs = {}
for mro in reversed(self.__class__.__mro__):
if mro in (AttackConfig, object):
continue
for kwarg in mro.__dict__:
if kwarg in self.AttackClass.__init__.__code__.co_varnames:
self.kwargs[kwarg] = mro.__dict__[kwarg]
else:
# make sure we don't specify wrong kwargs
assert kwarg in ["__module__", "AttackClass", "__doc__"]
def __call__(self, *args):
adversary = self.AttackClass(*args, **self.kwargs)
print(self.AttackClass, args, self.kwargs)
return adversary
def multiple_mini_batch_attack(
adversary, loader, device="cuda", save_adv=False, norm=None, num_batch=None
):
lst_label = []
lst_pred = []
lst_advpred = []
lst_dist = []
_norm_convert_dict = {"Linf": "inf", "L2": 2, "L1": 1}
if norm in _norm_convert_dict:
norm = _norm_convert_dict[norm]
if norm == "inf":
def dist_func(x, y):
return (x - y).view(x.size(0), -1).max(dim=1)[0]
elif norm == 1 or norm == 2:
from advertorch.utils import _get_norm_batch
def dist_func(x, y):
return _get_norm_batch(x - y, norm)
else:
assert norm is None
idx_batch = 0
for data, label in loader:
data, label = data.to(device), label.to(device)
adv = adversary.perturb(data, label)
advpred = predict_from_logits(adversary.predict(adv))
pred = predict_from_logits(adversary.predict(data))
lst_label.append(label)
lst_pred.append(pred)
lst_advpred.append(advpred)
if norm is not None:
lst_dist.append(dist_func(data, adv))
idx_batch += 1
if idx_batch == num_batch:
break
return (
torch.cat(lst_label),
torch.cat(lst_pred),
torch.cat(lst_advpred),
torch.cat(lst_dist) if norm is not None else None,
)
class MarginalLoss(_Loss):
# TODO: move this to advertorch.loss
def forward(self, logits, targets): # pylint: disable=arguments-differ
assert logits.shape[-1] >= 2
top_logits, top_classes = torch.topk(logits, 2, dim=-1)
target_logits = logits[torch.arange(logits.shape[0]), targets]
max_nontarget_logits = torch.where(
top_classes[..., 0] == targets,
top_logits[..., 1],
top_logits[..., 0],
)
loss = max_nontarget_logits - target_logits
if self.reduction == "none":
pass
elif self.reduction == "sum":
loss = loss.sum()
elif self.reduction == "mean":
loss = loss.mean()
else:
raise ValueError("unknown reduction: '%s'" % (self.recution,))
return loss
class ChooseBestAttack(Attack, LabelMixin):
def __init__(
self, predict, base_adversaries, loss_fn=None, targeted=False
):
self.predict = predict
self.base_adversaries = base_adversaries
self.loss_fn = loss_fn
self.targeted = targeted
if self.loss_fn is None:
self.loss_fn = ZeroOneLoss(reduction="none")
else:
assert self.loss_fn.reduction == "none"
for adversary in self.base_adversaries:
assert self.targeted == adversary.targeted
def perturb(self, x, y=None):
# TODO: might want to also retain the list of all attacks
x, y = self._verify_and_process_inputs(x, y)
with torch.no_grad():
maxloss = self.loss_fn(self.predict(x), y)
final_adv = torch.zeros_like(x)
for adversary in self.base_adversaries:
adv = adversary.perturb(x, y)
loss = self.loss_fn(self.predict(adv), y)
to_replace = maxloss < loss
final_adv[to_replace] = adv[to_replace]
maxloss[to_replace] = loss[to_replace]
return final_adv
def attack_whole_dataset(adversary, loader, device="cuda"):
lst_adv = []
lst_label = []
lst_pred = []
lst_advpred = []
for data, label in loader:
data, label = data.to(device), label.to(device)
pred = predict_from_logits(adversary.predict(data))
adv = adversary.perturb(data, label)
advpred = predict_from_logits(adversary.predict(adv))
lst_label.append(label)
lst_pred.append(pred)
lst_advpred.append(advpred)
lst_adv.append(adv)
return (
torch.cat(lst_adv),
torch.cat(lst_label),
torch.cat(lst_pred),
torch.cat(lst_advpred),
)