Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

neilisaac / torch   python

Repository URL to install this package:

/ package / _mock_zipreader.py

import torch
from glob import glob
import os.path
from typing import List, Any

_storages : List[Any] = [
    torch.DoubleStorage,
    torch.FloatStorage,
    torch.LongStorage,
    torch.IntStorage,
    torch.ShortStorage,
    torch.CharStorage,
    torch.ByteStorage,
    torch.BoolStorage,
]
_dtype_to_storage = {
    data_type(0).dtype: data_type for data_type in _storages
}

# because get_storage_from_record returns a tensor!?
class _HasStorage(object):
    def __init__(self, storage):
        self._storage = storage

    def storage(self):
        return self._storage


class MockZipReader(object):
    """Directory-backed stand-in for a zip-archive reader.

    Every regular file under ``directory`` is exposed as a "record".  A record
    name is the file's path relative to ``directory``, always written with
    ``/`` separators, which is how entries are named inside a zip archive.
    """

    def __init__(self, directory):
        # Root directory whose files are exposed as records.
        self.directory = directory

    def get_record(self, name):
        """Return the raw bytes of record ``name``.

        Raises:
            OSError: e.g. ``FileNotFoundError`` if the file cannot be opened.
        """
        filename = f'{self.directory}/{name}'
        with open(filename, 'rb') as f:
            return f.read()

    def get_storage_from_record(self, name, numel, dtype):
        """Load record ``name`` into a storage of ``numel`` ``dtype`` elements.

        The storage class is looked up in ``_dtype_to_storage`` and filled via
        its ``from_file`` method, with ``numel`` passed as ``size``.  The result
        is wrapped in ``_HasStorage`` so the caller reaches it through
        ``.storage()``.  Presumably this mirrors the real reader, which returns
        a tensor (TODO confirm).

        Raises:
            KeyError: if ``dtype`` has no entry in ``_dtype_to_storage``.
        """
        storage = _dtype_to_storage[dtype]
        filename = f'{self.directory}/{name}'
        return _HasStorage(storage.from_file(filename=filename, size=numel))

    def get_all_records(self):
        """Return the sorted names of every file under ``directory``.

        Names are relative to ``directory`` and use ``/`` separators.
        Directories are not returned.  A missing ``directory`` yields ``[]``.
        """
        root = self.directory
        records = []
        # os.walk is used instead of glob('**') for three reasons:
        #  - glob skips dot-files and dot-dirs (e.g. ".data/..."), which a real
        #    archive would list;
        #  - glob treats '[', '*' and '?' in the directory name as patterns;
        #  - relpath copes with a trailing separator on ``root``, where slicing
        #    by len(root) + 1 would chop a character off every name.
        # followlinks=True keeps glob's behaviour of descending into
        # symlinked directories.
        for dirpath, _dirnames, filenames in os.walk(root, followlinks=True):
            for fname in filenames:
                rel = os.path.relpath(os.path.join(dirpath, fname), root)
                records.append(rel.replace(os.sep, '/'))
        # Sort so the result does not depend on filesystem listing order.
        records.sort()
        return records