# Copyright (c) 2019-present, Francesco Croce
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import torch
import time
try:
from torch import flip
except ImportError:
from advertorch.utils import torch_flip as flip
from advertorch.utils import replicate_input
from advertorch.attacks.utils import zero_gradients
from .base import Attack
from .base import LabelMixin
DEFAULT_EPS_DICT_BY_NORM = {'Linf': .3, 'L2': 1., 'L1': 5.0}
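# For example (illustrative): FABAttack(model, norm='L2') with eps=None uses
# DEFAULT_EPS_DICT_BY_NORM['L2'] == 1.0 as the radius of its random restarts.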
class FABAttack(Attack, LabelMixin):
"""
Fast Adaptive Boundary Attack (Linf, L2, L1)
https://arxiv.org/abs/1907.02044
:param predict: forward pass function
:param norm: Lp-norm to minimize ('Linf', 'L2', 'L1' supported)
:param n_restarts: number of random restarts
:param n_iter: number of iterations
:param eps: epsilon for the random restarts
:param alpha_max: alpha_max
:param eta: overshooting
:param beta: backward step
:param device: device to use ('cuda' or 'cpu')
"""
def __init__(
self,
predict,
norm='Linf',
n_restarts=1,
n_iter=100,
eps=None,
alpha_max=0.1,
eta=1.05,
beta=0.9,
loss_fn=None,
verbose=False,
):
""" FAB-attack implementation in pytorch """
super(FABAttack, self).__init__(
predict, loss_fn=None, clip_min=0., clip_max=1.)
self.norm = norm
self.n_restarts = n_restarts
self.n_iter = n_iter
self.eps = eps if eps is not None else DEFAULT_EPS_DICT_BY_NORM[norm]
self.alpha_max = alpha_max
self.eta = eta
self.beta = beta
self.targeted = False
self.verbose = verbose
    def check_shape(self, x):
        """Ensure x has at least one dimension (squeeze may drop them all)."""
        return x if len(x.shape) > 0 else x.unsqueeze(0)
    def get_diff_logits_grads_batch(self, imgs, la):
        """For each class j, compute the logit difference f_j - f_la and its
        gradient w.r.t. the input; the label's own entry of df is masked
        with 1e10 so it is never selected as the closest boundary."""
im = imgs.clone().requires_grad_()
with torch.enable_grad():
y = self.predict(im)
g2 = torch.zeros([y.shape[-1], *imgs.size()]).to(self.device)
grad_mask = torch.zeros_like(y)
for counter in range(y.shape[-1]):
zero_gradients(im)
grad_mask[:, counter] = 1.0
y.backward(grad_mask, retain_graph=True)
grad_mask[:, counter] = 0.0
g2[counter] = im.grad.data
g2 = torch.transpose(g2, 0, 1).detach()
y2 = self.predict(imgs).detach()
df = y2 - y2[torch.arange(imgs.shape[0]), la].unsqueeze(1)
dg = g2 - g2[torch.arange(imgs.shape[0]), la].unsqueeze(1)
df[torch.arange(imgs.shape[0]), la] = 1e10
return df, dg
    def projection_linf(self, points_to_project, w_hyperplane, b_hyperplane):
        """Batched Linf-minimal projection of each point t onto the
        hyperplane {x: w^T x = b} intersected with the box [0, 1]^d."""
t = points_to_project.clone()
w = w_hyperplane.clone()
b = b_hyperplane.clone()
ind2 = ((w * t).sum(1) - b < 0).nonzero().squeeze()
ind2 = self.check_shape(ind2)
w[ind2] *= -1
b[ind2] *= -1
c5 = (w < 0).float()
a = torch.ones(t.shape).to(self.device)
d = (a * c5 - t) * (w != 0).float()
a -= a * (1 - c5)
p = torch.ones(t.shape).to(self.device) * c5 - t * (2 * c5 - 1)
_, indp = torch.sort(p, dim=1)
b = b - (w * t).sum(1)
b0 = (w * d).sum(1)
b1 = b0.clone()
counter = 0
indp2 = flip(indp.unsqueeze(-1), dims=(1, 2)).squeeze()
u = torch.arange(0, w.shape[0])
ws = w[u.unsqueeze(1), indp2]
bs2 = - ws * d[u.unsqueeze(1), indp2]
s = torch.cumsum(ws.abs(), dim=1)
sb = torch.cumsum(bs2, dim=1) + b0.unsqueeze(1)
c = b - b1 > 0
b2 = sb[u, -1] - s[u, -1] * p[u, indp[u, 0]]
c_l = (b - b2 > 0).nonzero().squeeze()
c2 = ((b - b1 > 0) * (b - b2 <= 0)).nonzero().squeeze()
c_l = self.check_shape(c_l)
c2 = self.check_shape(c2)
lb = torch.zeros(c2.shape[0])
ub = torch.ones(c2.shape[0]) * (w.shape[1] - 1)
nitermax = torch.ceil(torch.log2(torch.tensor(w.shape[1]).float()))
counter2 = torch.zeros(lb.shape).long()
        # binary search over the sorted breakpoints (at most log2(d) steps)
        while counter < nitermax:
counter4 = torch.floor((lb + ub) / 2)
counter2 = counter4.long()
indcurr = indp[c2, -counter2 - 1]
b2 = sb[c2, counter2] - s[c2, counter2] * p[c2, indcurr]
c = b[c2] - b2 > 0
ind3 = c.nonzero().squeeze()
ind32 = (~c).nonzero().squeeze()
ind3 = self.check_shape(ind3)
ind32 = self.check_shape(ind32)
lb[ind3] = counter4[ind3]
ub[ind32] = counter4[ind32]
counter += 1
lb = lb.long()
counter2 = 0
if c_l.nelement() != 0:
lmbd_opt = (torch.max((b[c_l] - sb[c_l, -1]) / (-s[c_l, -1]),
torch.zeros(sb[c_l, -1].shape)
.to(self.device))).unsqueeze(-1)
d[c_l] = (2 * a[c_l] - 1) * lmbd_opt
lmbd_opt = (torch.max((b[c2] - sb[c2, lb]) / (-s[c2, lb]),
torch.zeros(sb[c2, lb].shape)
.to(self.device))).unsqueeze(-1)
d[c2] = torch.min(lmbd_opt, d[c2]) * c5[c2]\
+ torch.max(-lmbd_opt, d[c2]) * (1 - c5[c2])
return d * (w != 0).float()
    def projection_l2(self, points_to_project, w_hyperplane, b_hyperplane):
        """Batched L2-minimal projection of each point t onto the
        hyperplane {x: w^T x = b} intersected with the box [0, 1]^d."""
t = points_to_project.clone()
w = w_hyperplane.clone()
b = b_hyperplane.clone()
c = (w * t).sum(1) - b
ind2 = (c < 0).nonzero().squeeze()
ind2 = self.check_shape(ind2)
w[ind2] *= -1
c[ind2] *= -1
u = torch.arange(0, w.shape[0]).unsqueeze(1)
r = torch.max(t / w, (t - 1) / w)
u2 = torch.ones(r.shape).to(self.device)
r = torch.min(r, 1e12 * u2)
r = torch.max(r, -1e12 * u2)
r[w.abs() < 1e-8] = 1e12
r[r == -1e12] = -r[r == -1e12]
rs, indr = torch.sort(r, dim=1)
rs2 = torch.cat((rs[:, 1:],
torch.zeros(rs.shape[0], 1).to(self.device)), 1)
rs[rs == 1e12] = 0
rs2[rs2 == 1e12] = 0
w3 = w ** 2
w3s = w3[u, indr]
w5 = w3s.sum(dim=1, keepdim=True)
ws = w5 - torch.cumsum(w3s, dim=1)
d = -(r * w).clone()
d = d * (w.abs() > 1e-8).float()
s = torch.cat(((-w5.squeeze() * rs[:, 0]).unsqueeze(1),
torch.cumsum((-rs2 + rs) * ws, dim=1) -
w5 * rs[:, 0].unsqueeze(-1)), 1)
c4 = (s[:, 0] + c < 0)
c3 = ((d * w).sum(dim=1) + c > 0)
c6 = c4.nonzero().squeeze()
c2 = ((1 - c4.float()) * (1 - c3.float())).nonzero().squeeze()
c6 = self.check_shape(c6)
c2 = self.check_shape(c2)
counter = 0
lb = torch.zeros(c2.shape[0])
ub = torch.ones(c2.shape[0]) * (w.shape[1] - 1)
nitermax = torch.ceil(torch.log2(torch.tensor(w.shape[1]).float()))
counter2 = torch.zeros(lb.shape).long()
        # binary search over the sorted breakpoints (at most log2(d) steps)
        while counter < nitermax:
counter4 = torch.floor((lb + ub) / 2)
counter2 = counter4.long()
c3 = s[c2, counter2] + c[c2] > 0
ind3 = c3.nonzero().squeeze()
ind32 = (~c3).nonzero().squeeze()
ind3 = self.check_shape(ind3)
ind32 = self.check_shape(ind32)
lb[ind3] = counter4[ind3]
ub[ind32] = counter4[ind32]
counter += 1
lb = lb.long()
alpha = torch.zeros([1])
if c6.nelement() != 0:
alpha = c[c6] / w5[c6].squeeze(-1)
d[c6] = -alpha.unsqueeze(-1) * w[c6]
if c2.nelement() != 0:
alpha = (s[c2, lb] + c[c2]) / ws[c2, lb] + rs[c2, lb]
if torch.sum(ws[c2, lb] == 0) > 0:
ind = (ws[c2, lb] == 0).nonzero().squeeze().long()
ind = self.check_shape(ind)
alpha[ind] = 0
c5 = (alpha.unsqueeze(-1) > r[c2]).float()
d[c2] = d[c2] * c5 - alpha.unsqueeze(-1) * w[c2] * (1 - c5)
return d * (w.abs() > 1e-8).float()
    def projection_l1(self, points_to_project, w_hyperplane, b_hyperplane):
        """Batched L1-minimal projection of each point t onto the
        hyperplane {x: w^T x = b} intersected with the box [0, 1]^d."""
t = points_to_project.clone()
w = w_hyperplane.clone()
b = b_hyperplane.clone()
c = (w * t).sum(1) - b
ind2 = (c < 0).nonzero().squeeze()
ind2 = self.check_shape(ind2)
w[ind2] *= -1
c[ind2] *= -1
r = torch.max(1 / w, -1 / w)
r = torch.min(r, 1e12 * torch.ones(r.shape).to(self.device))
rs, indr = torch.sort(r, dim=1)
_, indr_rev = torch.sort(indr)
u = torch.arange(0, w.shape[0]).unsqueeze(1)
u2 = torch.arange(0, w.shape[1]).repeat(w.shape[0], 1)
c6 = (w < 0).float()
d = (-t + c6) * (w != 0).float()
d2 = torch.min(-w * t, w * (1 - t))
ds = d2[u, indr]
ds2 = torch.cat((c.unsqueeze(-1), ds), 1)
s = torch.cumsum(ds2, dim=1)
c4 = s[:, -1] < 0
c2 = c4.nonzero().squeeze(-1)
c2 = self.check_shape(c2)
counter = 0
lb = torch.zeros(c2.shape[0])
ub = torch.ones(c2.shape[0]) * (s.shape[1])
nitermax = torch.ceil(torch.log2(torch.tensor(s.shape[1]).float()))
counter2 = torch.zeros(lb.shape).long()
        # binary search over the sorted breakpoints (at most log2(d) steps)
        while counter < nitermax:
counter4 = torch.floor((lb + ub) / 2)
counter2 = counter4.long()
c3 = s[c2, counter2] > 0
ind3 = c3.nonzero().squeeze()
ind32 = (~c3).nonzero().squeeze()
ind3 = self.check_shape(ind3)
ind32 = self.check_shape(ind32)
lb[ind3] = counter4[ind3]
ub[ind32] = counter4[ind32]
counter += 1
lb2 = lb.long()
if c2.nelement() != 0:
alpha = -s[c2, lb2] / w[c2, indr[c2, lb2]]
c5 = u2[c2].float() < lb.unsqueeze(-1).float()
u3 = c5[u[:c5.shape[0]], indr_rev[c2]]
d[c2] = d[c2] * u3.float().to(self.device)
d[c2, indr[c2, lb2]] = alpha
return d * (w.abs() > 1e-8).float()
    def perturb(self, x, y=None):
        """
        :param x: clean images.
        :param y: clean labels; if None, the predicted labels are used.
        :return: adversarial examples (points are returned unchanged where
            no adversarial perturbation was found).
        """
self.device = x.device
self.orig_dim = list(x.shape[1:])
self.ndims = len(self.orig_dim)
x = x.detach().clone().float().to(self.device)
# assert next(self.predict.parameters()).device == x.device
y_pred = self._get_predicted_label(x)
if y is None:
y = y_pred.detach().clone().long().to(self.device)
else:
y = y.detach().clone().long().to(self.device)
pred = y_pred == y
corr_classified = pred.float().sum()
if self.verbose:
print('Clean accuracy: {:.2%}'.format(pred.float().mean()))
if pred.sum() == 0:
return x
pred = self.check_shape(pred.nonzero().squeeze())
startt = time.time()
# runs the attack only on correctly classified points
im2 = replicate_input(x[pred])
la2 = replicate_input(y[pred])
if len(im2.shape) == self.ndims:
im2 = im2.unsqueeze(0)
bs = im2.shape[0]
u1 = torch.arange(bs)
adv = im2.clone()
adv_c = x.clone()
res2 = 1e10 * torch.ones([bs]).to(self.device)
res_c = torch.zeros([x.shape[0]]).to(self.device)
x1 = im2.clone()
x0 = im2.clone().reshape([bs, -1])
counter_restarts = 0
while counter_restarts < self.n_restarts:
            if counter_restarts > 0:
                # random restart: sample inside the norm ball of radius
                # min(eps, best distance found so far) / 2 around im2
                if self.norm == 'Linf':
t = 2 * torch.rand(x1.shape).to(self.device) - 1
x1 = im2 + (
torch.min(
res2,
self.eps * torch.ones(res2.shape).to(self.device)
).reshape([-1, *([1] * self.ndims)])
) * t / (t.reshape([t.shape[0], -1]).abs()
.max(dim=1, keepdim=True)[0]
.reshape([-1, *([1] * self.ndims)])) * .5
elif self.norm == 'L2':
t = torch.randn(x1.shape).to(self.device)
x1 = im2 + (
torch.min(
res2,
self.eps * torch.ones(res2.shape).to(self.device)
).reshape([-1, *([1] * self.ndims)])
) * t / ((t ** 2)
.view(t.shape[0], -1)
.sum(dim=-1)
.sqrt()
.view(t.shape[0], *([1] * self.ndims))) * .5
elif self.norm == 'L1':
t = torch.randn(x1.shape).to(self.device)
x1 = im2 + (torch.min(
res2,
self.eps * torch.ones(res2.shape).to(self.device)
).reshape([-1, *([1] * self.ndims)])
) * t / (t.abs().view(t.shape[0], -1)
.sum(dim=-1)
.view(t.shape[0], *([1] * self.ndims))) / 2
x1 = x1.clamp(0.0, 1.0)
counter_iter = 0
while counter_iter < self.n_iter:
with torch.no_grad():
df, dg = self.get_diff_logits_grads_batch(x1, la2)
if self.norm == 'Linf':
dist1 = df.abs() / (1e-12 +
dg.abs()
.view(dg.shape[0], dg.shape[1], -1)
.sum(dim=-1))
elif self.norm == 'L2':
dist1 = df.abs() / (1e-12 + (dg ** 2)
.view(dg.shape[0], dg.shape[1], -1)
.sum(dim=-1).sqrt())
elif self.norm == 'L1':
dist1 = df.abs() / (1e-12 + dg.abs().reshape(
[df.shape[0], df.shape[1], -1]).max(dim=2)[0])
else:
raise ValueError('norm not supported')
                    # closest linearized decision boundary, written w^T x = b
                    ind = dist1.min(dim=1)[1]
                    dg2 = dg[u1, ind]
                    b = (- df[u1, ind] +
                         (dg2 * x1).view(x1.shape[0], -1).sum(dim=-1))
                    w = dg2.reshape([bs, -1])
if self.norm == 'Linf':
d3 = self.projection_linf(
torch.cat((x1.reshape([bs, -1]), x0), 0),
torch.cat((w, w), 0),
torch.cat((b, b), 0))
elif self.norm == 'L2':
d3 = self.projection_l2(
torch.cat((x1.reshape([bs, -1]), x0), 0),
torch.cat((w, w), 0),
torch.cat((b, b), 0))
elif self.norm == 'L1':
d3 = self.projection_l1(
torch.cat((x1.reshape([bs, -1]), x0), 0),
torch.cat((w, w), 0),
torch.cat((b, b), 0))
                    # d1: step from the current iterate; d2: from clean im2
                    d1 = torch.reshape(d3[:bs], x1.shape)
                    d2 = torch.reshape(d3[-bs:], x1.shape)
if self.norm == 'Linf':
a0 = d3.abs().max(dim=1, keepdim=True)[0]\
.view(-1, *([1] * self.ndims))
elif self.norm == 'L2':
a0 = (d3 ** 2).sum(dim=1, keepdim=True).sqrt()\
.view(-1, *([1] * self.ndims))
elif self.norm == 'L1':
a0 = d3.abs().sum(dim=1, keepdim=True)\
.view(-1, *([1] * self.ndims))
a0 = torch.max(a0, 1e-8 * torch.ones(
a0.shape).to(self.device))
a1 = a0[:bs]
a2 = a0[-bs:]
                    # convex combination of the two steps; the weight of the
                    # clean-point step is capped at alpha_max
                    alpha = torch.min(torch.max(a1 / (a1 + a2),
                                                torch.zeros(a1.shape)
                                                .to(self.device)),
                                      self.alpha_max * torch.ones(a1.shape)
                                      .to(self.device))
x1 = ((x1 + self.eta * d1) * (1 - alpha) +
(im2 + d2 * self.eta) * alpha).clamp(0.0, 1.0)
is_adv = self._get_predicted_label(x1) != la2
if is_adv.sum() > 0:
ind_adv = is_adv.nonzero().squeeze()
ind_adv = self.check_shape(ind_adv)
if self.norm == 'Linf':
t = (x1[ind_adv] - im2[ind_adv]).reshape(
[ind_adv.shape[0], -1]).abs().max(dim=1)[0]
elif self.norm == 'L2':
t = ((x1[ind_adv] - im2[ind_adv]) ** 2)\
.view(ind_adv.shape[0], -1).sum(dim=-1).sqrt()
elif self.norm == 'L1':
t = (x1[ind_adv] - im2[ind_adv])\
.abs().view(ind_adv.shape[0], -1).sum(dim=-1)
                        # keep the closer of x1 and the incumbent best
                        # adversarial example, then update the distances res2
                        adv[ind_adv] = x1[ind_adv] * (t < res2[ind_adv]).\
float().reshape([-1, *([1] * self.ndims)]) \
+ adv[ind_adv]\
* (t >= res2[ind_adv]).float().reshape(
[-1, *([1] * self.ndims)])
res2[ind_adv] = t * (t < res2[ind_adv]).float()\
+ res2[ind_adv] * (t >= res2[ind_adv]).float()
                        # backward step: pull the adversarial point a factor
                        # beta of the way back towards the clean image
                        x1[ind_adv] = im2[ind_adv] + (
x1[ind_adv] - im2[ind_adv]) * self.beta
counter_iter += 1
counter_restarts += 1
ind_succ = res2 < 1e10
if self.verbose:
print('success rate: {:.0f}/{:.0f}'
.format(ind_succ.float().sum(), corr_classified) +
' (on correctly classified points) in {:.1f} s'
.format(time.time() - startt))
res_c[pred] = res2 * ind_succ.float() + 1e10 * (1 - ind_succ.float())
ind_succ = self.check_shape(ind_succ.nonzero().squeeze())
adv_c[pred[ind_succ]] = adv[ind_succ].clone()
return adv_c
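# The three subclasses below are thin convenience wrappers that fix the norm;
# LinfFABAttack(predict) is equivalent to FABAttack(predict, norm='Linf').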
class LinfFABAttack(FABAttack):
"""
Linf - Fast Adaptive Boundary Attack
https://arxiv.org/abs/1907.02044
:param predict: forward pass function
:param n_restarts: number of random restarts
:param n_iter: number of iterations
:param eps: epsilon for the random restarts
:param alpha_max: alpha_max
:param eta: overshooting
:param beta: backward step
:param device: device to use ('cuda' or 'cpu')
"""
def __init__(
self,
predict,
n_restarts=1,
n_iter=100,
eps=None,
alpha_max=0.1,
eta=1.05,
beta=0.9,
loss_fn=None,
verbose=False,
):
norm = 'Linf'
super(LinfFABAttack, self).__init__(
predict=predict, norm=norm, n_restarts=n_restarts,
n_iter=n_iter, eps=eps, alpha_max=alpha_max, eta=eta, beta=beta,
loss_fn=loss_fn, verbose=verbose)
class L2FABAttack(FABAttack):
"""
L2 - Fast Adaptive Boundary Attack
https://arxiv.org/abs/1907.02044
:param predict: forward pass function
:param n_restarts: number of random restarts
:param n_iter: number of iterations
:param eps: epsilon for the random restarts
:param alpha_max: alpha_max
:param eta: overshooting
:param beta: backward step
:param device: device to use ('cuda' or 'cpu')
"""
def __init__(
self,
predict,
n_restarts=1,
n_iter=100,
eps=None,
alpha_max=0.1,
eta=1.05,
beta=0.9,
loss_fn=None,
verbose=False,
):
norm = 'L2'
super(L2FABAttack, self).__init__(
predict=predict, norm=norm, n_restarts=n_restarts,
n_iter=n_iter, eps=eps, alpha_max=alpha_max, eta=eta, beta=beta,
loss_fn=loss_fn, verbose=verbose)
class L1FABAttack(FABAttack):
"""
L1 - Fast Adaptive Boundary Attack
https://arxiv.org/abs/1907.02044
:param predict: forward pass function
:param n_restarts: number of random restarts
:param n_iter: number of iterations
:param eps: epsilon for the random restarts
:param alpha_max: alpha_max
:param eta: overshooting
:param beta: backward step
:param device: device to use ('cuda' or 'cpu')
"""
def __init__(
self,
predict,
n_restarts=1,
n_iter=100,
eps=None,
alpha_max=0.1,
eta=1.05,
beta=0.9,
loss_fn=None,
verbose=False,
):
norm = 'L1'
super(L1FABAttack, self).__init__(
predict=predict, norm=norm, n_restarts=n_restarts,
n_iter=n_iter, eps=eps, alpha_max=alpha_max, eta=eta, beta=beta,
loss_fn=loss_fn, verbose=verbose)
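
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original advertorch
# API): attack a randomly initialized toy convnet on random inputs, just to
# show how the classes above are called. The architecture, input shape, and
# the reduced n_iter are arbitrary assumptions for a quick smoke test. Since
# this module uses relative imports, run it as
# `python -m advertorch.attacks.fab`, not as a standalone script.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import torch.nn as nn

    torch.manual_seed(0)

    # hypothetical stand-in model: 10-class convnet on 3x32x32 inputs
    model = nn.Sequential(
        nn.Conv2d(3, 8, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.AdaptiveAvgPool2d(1),
        nn.Flatten(),
        nn.Linear(8, 10),
    ).eval()

    x = torch.rand(4, 3, 32, 32)    # inputs are assumed to lie in [0, 1]
    y = model(x).argmax(dim=1)      # attack the model's own predictions

    attack = LinfFABAttack(model, n_restarts=1, n_iter=10, verbose=True)
    x_adv = attack.perturb(x, y)
    linf_norms = (x_adv - x).abs().view(x.shape[0], -1).max(dim=1)[0]
    print('Linf perturbation norms:', linf_norms.tolist())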