import os
import tarfile
import collections
from .vision import VisionDataset
import xml.etree.ElementTree as ET
from PIL import Image
from typing import Any, Callable, Dict, Optional, Tuple, List
from .utils import download_and_extract_archive, verify_str_arg
import warnings
# Download metadata per Pascal VOC release, keyed by year string.
#   url:      archive location
#   filename: name the archive is stored under inside ``root``; must match the
#             basename of ``url`` so archives of different years never collide
#   md5:      checksum handed to ``download_and_extract_archive``
#   base_dir: path of the extracted dataset relative to ``root``
# The 2007 test split ships as a separate archive ("2007-test") that extracts
# into the same directory as the 2007 trainval archive.
DATASET_YEAR_DICT = {
    '2012': {
        'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar',
        'filename': 'VOCtrainval_11-May-2012.tar',
        'md5': '6cd6e144f989b92b3379bac3b3de84fd',
        'base_dir': os.path.join('VOCdevkit', 'VOC2012')
    },
    '2011': {
        'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2011/VOCtrainval_25-May-2011.tar',
        'filename': 'VOCtrainval_25-May-2011.tar',
        'md5': '6c3384ef61512963050cb5d687e5bf1e',
        'base_dir': os.path.join('TrainVal', 'VOCdevkit', 'VOC2011')
    },
    '2010': {
        'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2010/VOCtrainval_03-May-2010.tar',
        'filename': 'VOCtrainval_03-May-2010.tar',
        'md5': 'da459979d0c395079b5c75ee67908abb',
        'base_dir': os.path.join('VOCdevkit', 'VOC2010')
    },
    '2009': {
        'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2009/VOCtrainval_11-May-2009.tar',
        'filename': 'VOCtrainval_11-May-2009.tar',
        'md5': '59065e4b188729180974ef6572f6a212',
        'base_dir': os.path.join('VOCdevkit', 'VOC2009')
    },
    '2008': {
        'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2008/VOCtrainval_14-Jul-2008.tar',
        # Previously 'VOCtrainval_11-May-2012.tar' (copied from the 2012 entry),
        # which stored the 2008 archive under the 2012 archive's name.
        'filename': 'VOCtrainval_14-Jul-2008.tar',
        'md5': '2629fa636546599198acfcfbfcf1904a',
        'base_dir': os.path.join('VOCdevkit', 'VOC2008')
    },
    '2007': {
        'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar',
        'filename': 'VOCtrainval_06-Nov-2007.tar',
        'md5': 'c52e279531787c972589f7e41ab4ae64',
        'base_dir': os.path.join('VOCdevkit', 'VOC2007')
    },
    '2007-test': {
        'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtest_06-Nov-2007.tar',
        'filename': 'VOCtest_06-Nov-2007.tar',
        'md5': 'b6e924de25625d8de591ea690078ad9f',
        'base_dir': os.path.join('VOCdevkit', 'VOC2007')
    }
}
class _VOCBase(VisionDataset):
    """Shared loading logic for the Pascal VOC datasets.

    Subclasses choose which split lists and which per-image target files are
    used by setting the three class attributes below. This base class picks the
    archive for the requested year/split, optionally downloads it, and builds
    the parallel path lists ``self.images`` and ``self.targets``.
    """

    # Sub-directory of ``<voc_root>/ImageSets`` holding the split ``.txt`` files.
    _SPLITS_DIR: str
    # Sub-directory of ``<voc_root>`` holding the per-image target files.
    _TARGET_DIR: str
    # Extension (including the leading dot) appended to each target file name.
    _TARGET_FILE_EXT: str

    def __init__(
        self,
        root: str,
        year: str = "2012",
        image_set: str = "train",
        download: bool = False,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        transforms: Optional[Callable] = None,
    ):
        """
        Args:
            root (str): Directory the archives are downloaded to / extracted in.
            year (str): Dataset year key of ``DATASET_YEAR_DICT``. The legacy
                value ``"2007-test"`` is accepted only with ``image_set="test"``
                and is deprecated in favour of ``year="2007"``.
            image_set (str): ``"train"``, ``"trainval"`` or ``"val"``; for
                ``year="2007"`` also ``"test"``.
            download (bool): If True, call ``download_and_extract_archive``
                for the selected archive before loading.
            transform, target_transform, transforms: Forwarded to
                ``VisionDataset.__init__``.

        Raises:
            ValueError: ``year="2007-test"`` combined with a non-test split.
            KeyError: ``year`` is not a key of ``DATASET_YEAR_DICT``.
            RuntimeError: the extracted dataset directory does not exist.
            Whatever ``verify_str_arg`` raises for an invalid ``image_set``.
        """
        super().__init__(root, transforms, transform, target_transform)
        if year == "2007-test":
            if image_set == "test":
                warnings.warn(
                    "Accessing the test image set of the year 2007 with year='2007-test' is deprecated. "
                    "Please use the combination year='2007' and image_set='test' instead."
                )
                year = "2007"
            else:
                raise ValueError(
                    "In the test image set of the year 2007 only image_set='test' is allowed. "
                    "For all other image sets use year='2007' instead."
                )
        self.year = year

        valid_image_sets = ["train", "trainval", "val"]
        if year == "2007":
            valid_image_sets.append("test")
        self.image_set = verify_str_arg(image_set, "image_set", valid_image_sets)

        # Only the 2007 *test* split lives in the separate "2007-test" archive;
        # 2007 train/trainval/val come from the regular 2007 trainval archive.
        # (Previously every 2007 split was mapped to the test archive.)
        key = "2007-test" if year == "2007" and self.image_set == "test" else year
        dataset_year_dict = DATASET_YEAR_DICT[key]

        self.url = dataset_year_dict["url"]
        self.filename = dataset_year_dict["filename"]
        self.md5 = dataset_year_dict["md5"]
        base_dir = dataset_year_dict["base_dir"]
        voc_root = os.path.join(self.root, base_dir)

        if download:
            download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.md5)

        if not os.path.isdir(voc_root):
            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")

        # Each line of the split file is an image id without extension.
        splits_dir = os.path.join(voc_root, "ImageSets", self._SPLITS_DIR)
        split_f = os.path.join(splits_dir, self.image_set + ".txt")
        with open(split_f, "r") as f:
            file_names = [x.strip() for x in f.readlines()]

        image_dir = os.path.join(voc_root, "JPEGImages")
        self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names]

        target_dir = os.path.join(voc_root, self._TARGET_DIR)
        self.targets = [os.path.join(target_dir, x + self._TARGET_FILE_EXT) for x in file_names]

        # Both lists are derived from the same id list, so this always holds.
        assert len(self.images) == len(self.targets)

    def __len__(self) -> int:
        """Number of samples, i.e. ids listed in the split file."""
        return len(self.images)
class VOCSegmentation(_VOCBase):
    """`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Segmentation Dataset.

    Args:
        root (string): Root directory of the VOC Dataset.
        year (string, optional): The dataset year, supports years ``"2007"`` to ``"2012"``.
        image_set (string, optional): Select the image_set to use, ``"train"``, ``"trainval"`` or ``"val"``.
            If ``year=="2007"``, can also be ``"test"``.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        transforms (callable, optional): A function/transform that takes input sample and its target as entry
            and returns a transformed version.
    """

    # Split lists come from ImageSets/Segmentation; targets are the
    # SegmentationClass/<id>.png files.
    _SPLITS_DIR = "Segmentation"
    _TARGET_DIR = "SegmentationClass"
    _TARGET_FILE_EXT = ".png"

    @property
    def masks(self) -> List[str]:
        """Paths of the segmentation mask files; same order as ``images``."""
        return self.targets

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is the image segmentation.
        """
        # The photo is forced to RGB; the mask is returned exactly as opened.
        sample = Image.open(self.images[index]).convert("RGB")
        mask = Image.open(self.masks[index])

        if self.transforms is None:
            return sample, mask

        sample, mask = self.transforms(sample, mask)
        return sample, mask
class VOCDetection(_VOCBase):
    """`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Detection Dataset.

    Args:
        root (string): Root directory of the VOC Dataset.
        year (string, optional): The dataset year, supports years ``"2007"`` to ``"2012"``.
        image_set (string, optional): Select the image_set to use, ``"train"``, ``"trainval"`` or ``"val"``.
            If ``year=="2007"``, can also be ``"test"``.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        transforms (callable, optional): A function/transform that takes input sample and its target as entry
            and returns a transformed version.
    """

    # Split lists come from ImageSets/Main; targets are the Annotations/<id>.xml files.
    _SPLITS_DIR = "Main"
    _TARGET_DIR = "Annotations"
    _TARGET_FILE_EXT = ".xml"

    @property
    def annotations(self) -> List[str]:
        """Paths of the XML annotation files; same order as ``images``."""
        return self.targets

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is a dictionary of the XML tree.
        """
        sample = Image.open(self.images[index]).convert("RGB")
        xml_root = ET.parse(self.annotations[index]).getroot()
        annotation = self.parse_voc_xml(xml_root)

        if self.transforms is None:
            return sample, annotation

        sample, annotation = self.transforms(sample, annotation)
        return sample, annotation

    def parse_voc_xml(self, node: ET.Element) -> Dict[str, Any]:
        """Recursively convert an XML element into nested dictionaries.

        * A leaf element with non-empty ``text`` becomes ``{tag: text.strip()}``.
        * A leaf element whose ``text`` is ``None`` or ``""`` becomes ``{}``.
        * An element with children becomes ``{tag: {child_tag: value}}``, where a
          child tag occurring more than once maps to a list of its values.
        * Under an ``annotation`` element, ``"object"`` always maps to a list
          (possibly empty), even when there is a single object.
        """
        kids = list(node)

        if not kids:
            return {node.tag: node.text.strip()} if node.text else {}

        grouped: Dict[str, Any] = collections.defaultdict(list)
        for kid in kids:
            for tag, value in self.parse_voc_xml(kid).items():
                grouped[tag].append(value)

        if node.tag == "annotation":
            # Wrap once so that the single-element unwrap below yields the
            # full list of objects instead of a bare object dict.
            grouped["object"] = [grouped["object"]]

        collapsed = {}
        for tag, values in grouped.items():
            collapsed[tag] = values[0] if len(values) == 1 else values
        return {node.tag: collapsed}