Repository URL to install this package:
|
Version:
0.2.4 ▾
|
advertorch
/
functional.py
|
|---|
# Copyright (c) 2018-present, Royal Bank of Canada.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
try:
from cStringIO import StringIO as BytesIO
except ImportError:
from io import BytesIO
import torch
from torchvision import transforms
from PIL import Image
_to_pil_image = transforms.ToPILImage()
_to_tensor = transforms.ToTensor()
class FloatToIntSqueezing(torch.autograd.Function):
    """Quantize values onto a uniform grid spanning ``[vmin, vmax]``.

    Invoke as ``FloatToIntSqueezing.apply(x, max_int, vmin, vmax)``.
    Each value is mapped to the nearest point of the grid
    ``vmin + k * (vmax - vmin) / max_int`` (``k`` an integer). Values are
    NOT clipped, so inputs outside ``[vmin, vmax]`` land on grid points
    outside that range as well.

    Differentiating through this op is not supported: ``backward`` raises
    ``NotImplementedError``.
    """

    @staticmethod
    def forward(ctx, x, max_int, vmin, vmax):
        span = vmax - vmin
        # Rescale so that vmin -> 0 and vmax -> 1; the grid step is then
        # simply 1 / max_int.
        unit = (x - vmin) / span
        levels = torch.round(unit * max_int)
        # Undo the rescaling to return to the caller's value range.
        return levels / max_int * span + vmin

    @staticmethod
    def backward(ctx, grad_output):
        # Rounding is piecewise constant, so no gradient is provided here.
        raise NotImplementedError(
            "backward not implemented", FloatToIntSqueezing)
class JPEGEncodingDecoding(torch.autograd.Function):
    """Round-trip every image in a batch through lossy JPEG compression.

    Invoke as ``JPEGEncodingDecoding.apply(x, quality)``. Each element of
    the leading dimension of ``x`` is converted to a PIL image, encoded as
    JPEG in memory at the given ``quality``, decoded again and converted
    back to a tensor.

    Differentiating through this op is not supported: ``backward`` raises
    ``NotImplementedError``.
    """

    @staticmethod
    def forward(ctx, x, quality):
        """Return the JPEG-compressed version of the batch ``x``.

        :param x: batch of images; each ``x[i]`` must be accepted by
            ``transforms.ToPILImage`` (presumably a ``(C, H, W)`` float
            tensor in ``[0, 1]`` -- TODO confirm with callers).
        :param quality: JPEG quality forwarded to PIL's ``Image.save``.
        :return: the decoded batch, with the same dtype and device as ``x``.
        """
        if len(x) == 0:
            # torch.stack() rejects an empty list; an empty batch has nothing
            # to compress, so return an (empty) copy of the input instead.
            return x.clone()
        decoded = []
        for img in x:
            pil_img = _to_pil_image(img.detach().clone().cpu())
            with BytesIO() as buf:
                pil_img.save(buf, 'JPEG', quality=quality)
                buf.seek(0)
                # Image.open decodes lazily, so convert to a tensor while the
                # buffer is still open; both handles are closed afterwards.
                with Image.open(buf) as jpeg_img:
                    decoded.append(_to_tensor(jpeg_img))
        # .to() matches x's dtype/device without the copy-construct
        # UserWarning that x.new_tensor(<tensor>) emits on every call.
        return torch.stack(decoded).to(device=x.device, dtype=x.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        # JPEG encoding/decoding is not differentiable; no gradient provided.
        raise NotImplementedError(
            "backward not implemented", JPEGEncodingDecoding)