import os
import os.path
import hashlib
import gzip
import errno
import tarfile
from typing import Any, Callable, List, Iterable, Optional, TypeVar
import zipfile
import torch
from torch.utils.model_zoo import tqdm
def gen_bar_updater() -> Callable[[int, int, int], None]:
pbar = tqdm(total=None)
def bar_update(count, block_size, total_size):
if pbar.total is None and total_size:
pbar.total = total_size
progress_bytes = count * block_size
pbar.update(progress_bytes - pbar.n)
return bar_update
def calculate_md5(fpath: str, chunk_size: int = 1024 * 1024) -> str:
md5 = hashlib.md5()
with open(fpath, 'rb') as f:
for chunk in iter(lambda: f.read(chunk_size), b''):
md5.update(chunk)
return md5.hexdigest()
def check_md5(fpath: str, md5: str, **kwargs: Any) -> bool:
return md5 == calculate_md5(fpath, **kwargs)
def check_integrity(fpath: str, md5: Optional[str] = None) -> bool:
if not os.path.isfile(fpath):
return False
if md5 is None:
return True
return check_md5(fpath, md5)
def download_url(url: str, root: str, filename: Optional[str] = None, md5: Optional[str] = None) -> None:
"""Download a file from a url and place it in root.
Args:
url (str): URL to download file from
root (str): Directory to place downloaded file in
filename (str, optional): Name to save the file under. If None, use the basename of the URL
md5 (str, optional): MD5 checksum of the download. If None, do not check
"""
import urllib
root = os.path.expanduser(root)
if not filename:
filename = os.path.basename(url)
fpath = os.path.join(root, filename)
os.makedirs(root, exist_ok=True)
# check if file is already present locally
if check_integrity(fpath, md5):
print('Using downloaded and verified file: ' + fpath)
else: # download the file
try:
print('Downloading ' + url + ' to ' + fpath)
urllib.request.urlretrieve(
url, fpath,
reporthook=gen_bar_updater()
)
except (urllib.error.URLError, IOError) as e: # type: ignore[attr-defined]
if url[:5] == 'https':
url = url.replace('https:', 'http:')
print('Failed download. Trying https -> http instead.'
' Downloading ' + url + ' to ' + fpath)
urllib.request.urlretrieve(
url, fpath,
reporthook=gen_bar_updater()
)
else:
raise e
# check integrity of downloaded file
if not check_integrity(fpath, md5):
raise RuntimeError("File not found or corrupted.")
def list_dir(root: str, prefix: bool = False) -> List[str]:
"""List all directories at a given root
Args:
root (str): Path to directory whose folders need to be listed
prefix (bool, optional): If true, prepends the path to each result, otherwise
only returns the name of the directories found
"""
root = os.path.expanduser(root)
directories = [p for p in os.listdir(root) if os.path.isdir(os.path.join(root, p))]
if prefix is True:
directories = [os.path.join(root, d) for d in directories]
return directories
def list_files(root: str, suffix: str, prefix: bool = False) -> List[str]:
"""List all files ending with a suffix at a given root
Args:
root (str): Path to directory whose folders need to be listed
suffix (str or tuple): Suffix of the files to match, e.g. '.png' or ('.jpg', '.png').
It uses the Python "str.endswith" method and is passed directly
prefix (bool, optional): If true, prepends the path to each result, otherwise
only returns the name of the files found
"""
root = os.path.expanduser(root)
files = [p for p in os.listdir(root) if os.path.isfile(os.path.join(root, p)) and p.endswith(suffix)]
if prefix is True:
files = [os.path.join(root, d) for d in files]
return files
def _quota_exceeded(response: "requests.models.Response") -> bool: # type: ignore[name-defined]
return "Google Drive - Quota exceeded" in response.text
def download_file_from_google_drive(file_id: str, root: str, filename: Optional[str] = None, md5: Optional[str] = None):
"""Download a Google Drive file from and place it in root.
Args:
file_id (str): id of file to be downloaded
root (str): Directory to place downloaded file in
filename (str, optional): Name to save the file under. If None, use the id of the file.
md5 (str, optional): MD5 checksum of the download. If None, do not check
"""
# Based on https://stackoverflow.com/questions/38511444/python-download-files-from-google-drive-using-url
import requests
url = "https://docs.google.com/uc?export=download"
root = os.path.expanduser(root)
if not filename:
filename = file_id
fpath = os.path.join(root, filename)
os.makedirs(root, exist_ok=True)
if os.path.isfile(fpath) and check_integrity(fpath, md5):
print('Using downloaded and verified file: ' + fpath)
else:
session = requests.Session()
response = session.get(url, params={'id': file_id}, stream=True)
token = _get_confirm_token(response)
if token:
params = {'id': file_id, 'confirm': token}
response = session.get(url, params=params, stream=True)
if _quota_exceeded(response):
msg = (
f"The daily quota of the file {filename} is exceeded and it "
f"can't be downloaded. This is a limitation of Google Drive "
f"and can only be overcome by trying again later."
)
raise RuntimeError(msg)
_save_response_content(response, fpath)
def _get_confirm_token(response: "requests.models.Response") -> Optional[str]: # type: ignore[name-defined]
for key, value in response.cookies.items():
if key.startswith('download_warning'):
return value
return None
def _save_response_content(
response: "requests.models.Response", destination: str, chunk_size: int = 32768, # type: ignore[name-defined]
) -> None:
with open(destination, "wb") as f:
pbar = tqdm(total=None)
progress = 0
for chunk in response.iter_content(chunk_size):
if chunk: # filter out keep-alive new chunks
f.write(chunk)
progress += len(chunk)
pbar.update(progress - pbar.n)
pbar.close()
def _is_tarxz(filename: str) -> bool:
return filename.endswith(".tar.xz")
def _is_tar(filename: str) -> bool:
return filename.endswith(".tar")
def _is_targz(filename: str) -> bool:
return filename.endswith(".tar.gz")
def _is_tgz(filename: str) -> bool:
return filename.endswith(".tgz")
def _is_gzip(filename: str) -> bool:
return filename.endswith(".gz") and not filename.endswith(".tar.gz")
def _is_zip(filename: str) -> bool:
return filename.endswith(".zip")
def extract_archive(from_path: str, to_path: Optional[str] = None, remove_finished: bool = False) -> None:
if to_path is None:
to_path = os.path.dirname(from_path)
if _is_tar(from_path):
with tarfile.open(from_path, 'r') as tar:
tar.extractall(path=to_path)
elif _is_targz(from_path) or _is_tgz(from_path):
with tarfile.open(from_path, 'r:gz') as tar:
tar.extractall(path=to_path)
elif _is_tarxz(from_path):
with tarfile.open(from_path, 'r:xz') as tar:
tar.extractall(path=to_path)
elif _is_gzip(from_path):
to_path = os.path.join(to_path, os.path.splitext(os.path.basename(from_path))[0])
with open(to_path, "wb") as out_f, gzip.GzipFile(from_path) as zip_f:
out_f.write(zip_f.read())
elif _is_zip(from_path):
with zipfile.ZipFile(from_path, 'r') as z:
z.extractall(to_path)
else:
raise ValueError("Extraction of {} not supported".format(from_path))
if remove_finished:
os.remove(from_path)
def download_and_extract_archive(
url: str,
download_root: str,
extract_root: Optional[str] = None,
filename: Optional[str] = None,
md5: Optional[str] = None,
remove_finished: bool = False,
) -> None:
download_root = os.path.expanduser(download_root)
if extract_root is None:
extract_root = download_root
if not filename:
filename = os.path.basename(url)
download_url(url, download_root, filename, md5)
archive = os.path.join(download_root, filename)
print("Extracting {} to {}".format(archive, extract_root))
extract_archive(archive, extract_root, remove_finished)
def iterable_to_str(iterable: Iterable) -> str:
return "'" + "', '".join([str(item) for item in iterable]) + "'"
T = TypeVar("T", str, bytes)
def verify_str_arg(
value: T, arg: Optional[str] = None, valid_values: Iterable[T] = None, custom_msg: Optional[str] = None,
) -> T:
if not isinstance(value, torch._six.string_classes):
if arg is None:
msg = "Expected type str, but got type {type}."
else:
msg = "Expected type str for argument {arg}, but got type {type}."
msg = msg.format(type=type(value), arg=arg)
raise ValueError(msg)
if valid_values is None:
return value
if value not in valid_values:
if custom_msg is not None:
msg = custom_msg
else:
msg = ("Unknown value '{value}' for argument {arg}. "
"Valid values are {{{valid_values}}}.")
msg = msg.format(value=value, arg=arg,
valid_values=iterable_to_str(valid_values))
raise ValueError(msg)
return value