# Copyright (c) 2018-present, Royal Bank of Canada.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import numpy as np
import torch
import torch.nn.functional as F

from advertorch.utils import calc_l2distsq

from .base import Attack
from .base import LabelMixin

L2DIST_UPPER = 1e10
COEFF_UPPER = 1e10
INVALID_LABEL = -1
UPPER_CHECK = 1e9


class LBFGSAttack(Attack, LabelMixin):
"""
The attack that uses L-BFGS to minimize the distance of the original
and perturbed images
:param predict: forward pass function.
:param num_classes: number of clasess.
:param batch_size: number of samples in the batch
:param binary_search_steps: number of binary search times to find the
optimum
:param max_iterations: the maximum number of iterations
:param initial_const: initial value of the constant c
:param clip_min: mininum value per input dimension.
:param clip_max: maximum value per input dimension.
:param loss_fn: loss function
:param targeted: if the attack is targeted.
"""
def __init__(self, predict, num_classes, batch_size=1,
binary_search_steps=9, max_iterations=100,
initial_const=1e-2,
clip_min=0, clip_max=1, loss_fn=None, targeted=False):
super(LBFGSAttack, self).__init__(
predict, loss_fn, clip_min, clip_max)
# XXX: should combine the input loss function with other things
self.num_classes = num_classes
self.batch_size = batch_size
self.binary_search_steps = binary_search_steps
self.max_iterations = max_iterations
self.initial_const = initial_const
self.targeted = targeted
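
    # Record, per example, the adversarial image with the smallest squared
    # L2 distance among those whose prediction matches `labs` (i.e. the
    # target label in the targeted setting).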
def _update_if_better(
self, adv_img, labs, output, dist, batch_size,
final_l2dists, final_labels, final_advs):
for ii in range(batch_size):
target_label = labs[ii]
output_logits = output[ii]
_, output_label = torch.max(output_logits, 0)
di = dist[ii]
if (di < final_l2dists[ii] and
output_label.item() == target_label):
final_l2dists[ii] = di
final_labels[ii] = output_label
final_advs[ii] = adv_img[ii]
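
    # Binary-search update of the per-example loss coefficients: when the
    # prediction matches `labs`, tighten the upper bound and bisect;
    # otherwise raise the lower bound and bisect, or grow the coefficient
    # tenfold while no upper bound has been found yet.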
def _update_loss_coeffs(
self, labs, batch_size,
loss_coeffs, coeff_upper_bound, coeff_lower_bound, output):
for ii in range(batch_size):
_, cur_label = torch.max(output[ii], 0)
if cur_label.item() == int(labs[ii]):
coeff_upper_bound[ii] = min(
coeff_upper_bound[ii], loss_coeffs[ii])
if coeff_upper_bound[ii] < UPPER_CHECK:
loss_coeffs[ii] = (coeff_lower_bound[ii] +
coeff_upper_bound[ii]) / 2
else:
coeff_lower_bound[ii] = max(
coeff_lower_bound[ii], loss_coeffs[ii])
if coeff_upper_bound[ii] < UPPER_CHECK:
loss_coeffs[ii] = (coeff_lower_bound[ii] +
coeff_upper_bound[ii]) / 2
else:
loss_coeffs[ii] *= 10
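
    # Binary search over the loss coefficient; each outer step solves one
    # box-constrained subproblem with scipy's L-BFGS-B.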
def perturb(self, x, y=None):
from scipy.optimize import fmin_l_bfgs_b
def _loss_fn(adv_x_np, self, x, target, const):
adv_x = torch.from_numpy(
adv_x_np.reshape(x.shape)).float().to(
x.device).requires_grad_()
output = self.predict(adv_x)
            # squared L2 distance between the original and the candidate
            loss2 = torch.sum((x - adv_x) ** 2)
            ce = F.cross_entropy(output, target, reduction='none')
            loss1 = torch.sum(const * ce)
            if not self.targeted:
                # untargeted: maximize the cross-entropy w.r.t. the true
                # label, i.e. minimize its negation; negate before
                # backward so the returned loss and gradient agree in sign
                loss1 = -loss1
            loss = loss1 + loss2
            loss.backward()
            grad_ret = adv_x.grad.data.cpu().numpy().flatten().astype(float)
            loss = loss.data.cpu().numpy().flatten().astype(float)
            return loss, grad_ret
x, y = self._verify_and_process_inputs(x, y)
batch_size = len(x)
coeff_lower_bound = x.new_zeros(batch_size)
coeff_upper_bound = x.new_ones(batch_size) * COEFF_UPPER
loss_coeffs = x.new_ones(batch_size) * self.initial_const
final_l2dists = [L2DIST_UPPER] * batch_size
final_labels = [INVALID_LABEL] * batch_size
final_advs = x.clone()
        # per-dimension box constraints handed to L-BFGS-B
        clip_min = self.clip_min * np.ones(x.shape).astype(float)
        clip_max = self.clip_max * np.ones(x.shape).astype(float)
        clip_bound = list(zip(clip_min.flatten(), clip_max.flatten()))
for outer_step in range(self.binary_search_steps):
init_guess = x.clone().cpu().numpy().flatten().astype(float)
            adv_x, f, _ = fmin_l_bfgs_b(
                _loss_fn, init_guess,
                args=(self, x.clone(), y, loss_coeffs),
                bounds=clip_bound,
                maxiter=self.max_iterations,
                iprint=0)
adv_x = torch.from_numpy(
adv_x.reshape(x.shape)).float().to(x.device)
l2s = calc_l2distsq(x, adv_x)
output = self.predict(adv_x)
self._update_if_better(
adv_x, y, output.data, l2s, batch_size,
final_l2dists, final_labels, final_advs)
self._update_loss_coeffs(
y, batch_size,
loss_coeffs, coeff_upper_bound, coeff_lower_bound,
output.data)
return final_advs
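

# ---------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module). It assumes
# a classifier mapping inputs in [0, 1] to logits; the toy model, input
# shapes, and labels below are stand-ins invented for this example.
# ---------------------------------------------------------------------
if __name__ == "__main__":
    import torch.nn as nn

    # toy stand-in classifier; any nn.Module returning logits would do
    model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10)).eval()

    adversary = LBFGSAttack(
        predict=model, num_classes=10, batch_size=4,
        binary_search_steps=5, max_iterations=50,
        initial_const=1e-2, clip_min=0., clip_max=1., targeted=True)

    x = torch.rand(4, 1, 28, 28)    # batch of inputs in [0, 1]
    y = torch.randint(0, 10, (4,))  # target labels to steer toward
    adv = adversary.perturb(x, y)
    print(adv.shape, calc_l2distsq(x, adv))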