Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

neilisaac / torch   python

Repository URL to install this package:

/ python / operator_test / video_input_op_test.py



import os
import shutil
import sys
import tempfile
import unittest

import numpy as np
from caffe2.proto import caffe2_pb2
from caffe2.python import model_helper, workspace


try:
    import lmdb
except ImportError:
    raise unittest.SkipTest("python-lmdb is not installed")


class VideoInputOpTest(unittest.TestCase):
    def create_a_list(self, output_file, line, n):
        # create a list that repeat a line n times
        # used for creating a list file for simple test input
        with open(output_file, "w") as file:
            for _i in range(n):
                file.write(line)

    def create_video_db(self, list_file, output_file, use_list=False):
        # Write to lmdb database...
        LMDB_MAP_SIZE = 1 << 40  # MODIFY
        env = lmdb.open(output_file, map_size=LMDB_MAP_SIZE)
        total_size = 0

        file_name = []
        start_frame = []
        label = []
        index = 0

        with env.begin(write=True) as txn:
            with open(list_file, "r") as data:
                for line in data:
                    p = line.split()
                    file_name = p[0]
                    start_frame = int(p[1])
                    label = int(p[2])

                    if not use_list:
                        with open(file_name, mode="rb") as file:
                            video_data = file.read()
                    else:
                        video_data = file_name

                    tensor_protos = caffe2_pb2.TensorProtos()
                    video_tensor = tensor_protos.protos.add()
                    video_tensor.data_type = 4  # string data
                    video_tensor.string_data.append(video_data)

                    label_tensor = tensor_protos.protos.add()
                    label_tensor.data_type = 2
                    label_tensor.int32_data.append(label)

                    start_frame_tensor = tensor_protos.protos.add()
                    start_frame_tensor.data_type = 2
                    start_frame_tensor.int32_data.append(start_frame)

                    txn.put(
                        "{}".format(index).encode("ascii"),
                        tensor_protos.SerializeToString(),
                    )
                    index = index + 1
                    total_size = total_size + len(video_data) + sys.getsizeof(int)
        return total_size

    # sample one clip randomly from the video
    def test_rgb_with_temporal_jittering(self):
        random_label = np.random.randint(0, 100)
        VIDEO = "/mnt/vol/gfsdataswarm-oregon/users/trandu/sample.avi"
        if not os.path.exists(VIDEO):
            raise unittest.SkipTest("Missing data")
        temp_list = tempfile.NamedTemporaryFile(delete=False).name
        line_str = "{} 0 {}\n".format(VIDEO, random_label)
        self.create_a_list(temp_list, line_str, 16)
        video_db_dir = tempfile.mkdtemp()

        self.create_video_db(temp_list, video_db_dir)
        model = model_helper.ModelHelper(name="Video Loader from LMDB")
        reader = model.CreateDB("sample", db=video_db_dir, db_type="lmdb")

        # build the model
        model.net.VideoInput(
            reader,
            ["data", "label"],
            name="data",
            batch_size=16,
            clip_per_video=1,
            crop_size=112,
            scale_w=171,
            scale_h=128,
            length_rgb=8,
            sampling_rate_rgb=1,
            decode_type=0,
            video_res_type=0,  # scale by scale_h and scale_w
        )

        workspace.RunNetOnce(model.param_init_net)
        workspace.RunNetOnce(model.net)
        data = workspace.FetchBlob("data")
        label = workspace.FetchBlob("label")

        np.testing.assert_equal(label, random_label)
        np.testing.assert_equal(data.shape, [16, 3, 8, 112, 112])
        os.remove(temp_list)
        shutil.rmtree(video_db_dir)

    # sample multiple clips uniformly from the video
    def test_rgb_with_uniform_sampling(self):
        random_label = np.random.randint(0, 100)
        clip_per_video = np.random.randint(2, 11)
        VIDEO = "/mnt/vol/gfsdataswarm-oregon/users/trandu/sample.avi"
        if not os.path.exists(VIDEO):
            raise unittest.SkipTest("Missing data")
        temp_list = tempfile.NamedTemporaryFile(delete=False).name
        line_str = "{} 0 {}\n".format(VIDEO, random_label)
        self.create_a_list(temp_list, line_str, 16)
        video_db_dir = tempfile.mkdtemp()

        self.create_video_db(temp_list, video_db_dir)
        model = model_helper.ModelHelper(name="Video Loader from LMDB")
        reader = model.CreateDB("sample", db=video_db_dir, db_type="lmdb")

        # build the model
        model.net.VideoInput(
            reader,
            ["data", "label"],
            name="data",
            batch_size=3,
            clip_per_video=clip_per_video,
            crop_size=112,
            scale_w=171,
            scale_h=128,
            length_rgb=8,
            sampling_rate_rgb=1,
            decode_type=1,
            video_res_type=0,
        )

        workspace.RunNetOnce(model.param_init_net)
        workspace.RunNetOnce(model.net)
        data = workspace.FetchBlob("data")
        label = workspace.FetchBlob("label")

        np.testing.assert_equal(label, random_label)
        np.testing.assert_equal(data.shape, [3 * clip_per_video, 3, 8, 112, 112])
        os.remove(temp_list)
        shutil.rmtree(video_db_dir)

    # test optical flow
    def test_optical_flow_with_temporal_jittering(self):
        random_label = np.random.randint(0, 100)
        VIDEO = "/mnt/vol/gfsdataswarm-oregon/users/trandu/sample.avi"
        if not os.path.exists(VIDEO):
            raise unittest.SkipTest("Missing data")
        temp_list = tempfile.NamedTemporaryFile(delete=False).name
        line_str = "{} 0 {}\n".format(VIDEO, random_label)
        self.create_a_list(temp_list, line_str, 16)
        video_db_dir = tempfile.mkdtemp()

        self.create_video_db(temp_list, video_db_dir)
        model = model_helper.ModelHelper(name="Video Loader from LMDB")
        reader = model.CreateDB("sample", db=video_db_dir, db_type="lmdb")
        model.net.VideoInput(
            reader,
            ["data", "label"],
            name="data",
            batch_size=16,
            clip_per_video=1,
            crop_size=112,
            scale_w=171,
            scale_h=128,
            length_of=8,
            sampling_rate_of=1,
            frame_gap_of=1,
            decode_type=0,
            video_res_type=0,
            get_rgb=False,
            get_optical_flow=True,
        )

        workspace.RunNetOnce(model.param_init_net)
        workspace.RunNetOnce(model.net)
        data = workspace.FetchBlob("data")
        label = workspace.FetchBlob("label")

        np.testing.assert_equal(label, random_label)
        np.testing.assert_equal(data.shape, [16, 2, 8, 112, 112])
        os.remove(temp_list)
        shutil.rmtree(video_db_dir)

    # test rgb output VideoResType is
    # USE_SHORTER_EDGE
    def test_rgb_use_shorter_edge(self):
        batch_size = 16
        random_label = np.random.randint(0, 100)
        VIDEO = "/mnt/vol/gfsdataswarm-oregon/users/trandu/sample.avi"
        if not os.path.exists(VIDEO):
            raise unittest.SkipTest("Missing data")
        temp_list = tempfile.NamedTemporaryFile(delete=False).name
        line_str = "{} 0 {}\n".format(VIDEO, random_label)
        self.create_a_list(temp_list, line_str, batch_size)
        video_db_dir = tempfile.mkdtemp()

        self.create_video_db(temp_list, video_db_dir)
        model = model_helper.ModelHelper(name="Video Loader from LMDB")
        reader = model.CreateDB("sample", db=video_db_dir, db_type="lmdb")
        model.net.VideoInput(
            reader,
            ["data", "label"],
            name="data",
            batch_size=batch_size,
            clip_per_video=1,
            crop_size=112,
            scale_w=171,
            scale_h=128,
            length_of=8,
            frame_gap_of=1,
            decode_type=0,
            video_res_type=1,  # use shorter edge
            get_rgb=True,
            length_rgb=8,
            short_edge=112,
        )

        workspace.RunNetOnce(model.param_init_net)
        workspace.RunNetOnce(model.net)
        data = workspace.FetchBlob("data")
        label = workspace.FetchBlob("label")

        np.testing.assert_equal(label.shape, [batch_size])
        for i in range(batch_size):
            np.testing.assert_equal(label[i], random_label)
        np.testing.assert_equal(data.shape, [batch_size, 3, 8, 112, 112])
        os.remove(temp_list)
        shutil.rmtree(video_db_dir)

    # test optical flow output VideoResType is
    # USE_SHORTER_EDGE
    def test_optical_flow_use_shorter_edge(self):
        batch_size = 16
        random_label = np.random.randint(0, 100)
        VIDEO = "/mnt/vol/gfsdataswarm-oregon/users/trandu/sample.avi"
        if not os.path.exists(VIDEO):
            raise unittest.SkipTest("Missing data")
        temp_list = tempfile.NamedTemporaryFile(delete=False).name
        line_str = "{} 0 {}\n".format(VIDEO, random_label)
        self.create_a_list(temp_list, line_str, batch_size)
        video_db_dir = tempfile.mkdtemp()

        self.create_video_db(temp_list, video_db_dir)
        model = model_helper.ModelHelper(name="Video Loader from LMDB")
        reader = model.CreateDB("sample", db=video_db_dir, db_type="lmdb")
        model.net.VideoInput(
            reader,
            ["data", "label"],
            name="data",
            batch_size=batch_size,
            clip_per_video=1,
            crop_size=112,
            scale_w=171,
            scale_h=128,
            length_of=8,
            sampling_rate_of=1,
            frame_gap_of=1,
            decode_type=0,
            video_res_type=1,  # use shorter edge
            get_rgb=False,
            get_optical_flow=True,
            short_edge=112,
        )

        workspace.RunNetOnce(model.param_init_net)
        workspace.RunNetOnce(model.net)
        data = workspace.FetchBlob("data")
        label = workspace.FetchBlob("label")

        np.testing.assert_equal(label.shape, [batch_size])
        for i in range(batch_size):
            np.testing.assert_equal(label[i], random_label)
        np.testing.assert_equal(data.shape, [batch_size, 2, 8, 112, 112])
        os.remove(temp_list)
        shutil.rmtree(video_db_dir)


if __name__ == "__main__":
    unittest.main()