Why Gemfury? Push, build, and install RubyGems, npm packages, Python packages, Maven artifacts, PHP packages, Go Modules, Debian packages, RPM packages, and NuGet packages.

Repository URL to install this package:

Details    
gtsam / tests / test_DSFMap.py
Size: Mime:
"""
GTSAM Copyright 2010-2019, Georgia Tech Research Corporation,
Atlanta, Georgia 30332-0415
All Rights Reserved

See LICENSE for the license information

Unit tests for Disjoint Set Forest.
Author: Frank Dellaert & Varun Agrawal & John Lambert
"""
# pylint: disable=invalid-name, no-name-in-module, no-member

from __future__ import print_function

import unittest
from typing import Tuple

from gtsam import DSFMapIndexPair, IndexPair, IndexPairSetAsArray
from gtsam.utils.test_case import GtsamTestCase


class TestDSFMap(GtsamTestCase):
    """Tests for DSFMap (Disjoint Set Forest) specialized to IndexPair keys."""

    @staticmethod
    def _key(index_pair: IndexPair) -> Tuple[int, int]:
        """Return the (i, j) components of an IndexPair as a plain tuple.

        The tuple can be compared with ordinary Python equality and stored in
        Python sets, which keeps the assertions below simple.
        """
        return index_pair.i(), index_pair.j()

    def test_all(self) -> None:
        """Test everything in DSFMap."""
        dsf = DSFMapIndexPair()
        pair1 = IndexPair(1, 18)
        # Before any merge, find() on an element returns the element itself.
        self.assertEqual(self._key(dsf.find(pair1)), self._key(pair1))
        pair2 = IndexPair(2, 2)

        # testing the merge feature of dsf: merged elements share a root.
        dsf.merge(pair1, pair2)
        self.assertEqual(self._key(dsf.find(pair1)), self._key(dsf.find(pair2)))

    def test_sets(self) -> None:
        """Ensure that pairs are merged correctly during Union-Find.

        An IndexPair (i, j) representing a unique key might represent the
        j'th detected keypoint in image i. For the data below, merging such
        measurements into feature tracks across frames should create 2 distinct sets.
        """
        dsf = DSFMapIndexPair()
        dsf.merge(IndexPair(0, 1), IndexPair(1, 2))
        dsf.merge(IndexPair(0, 1), IndexPair(3, 4))
        dsf.merge(IndexPair(4, 5), IndexPair(6, 8))
        sets = dsf.sets()

        # Each merged set becomes a frozenset of (i, j) tuples. This makes the
        # comparison independent of the order in which members are returned.
        merged_sets = {
            frozenset(self._key(val) for val in IndexPairSetAsArray(sets[root]))
            for root in sets
        }

        expected_sets = {
            frozenset({(0, 1), (1, 2), (3, 4)}),  # set 1
            frozenset({(4, 5), (6, 8)}),  # set 2
        }
        # Use the TestCase assertion: it survives `python -O` and reports a diff.
        self.assertEqual(expected_sets, merged_sets)


if "__main__" == __name__:
    # Executed directly rather than imported: run the tests in this module.
    unittest.main()