Skip to content
Snippets Groups Projects
Select Git revision
  • ab1788292912fff7689d0be1d91fd5f36a2e1d85
  • master default protected
2 results

index.html

Blame
  • test_util.py 7.68 KiB
    import unittest
    import os
    import logging
    
    import numpy as np
    import numpy.testing
    
    from pycgtool.util import tuple_equivalent, extend_graph_chain, stat_moments, transpose_and_sample
    from pycgtool.util import dir_up, backup_file, sliding, r_squared, dist_with_pbc
    from pycgtool.util import SimpleEnum, FixedFormatUnpacker
    
    
    class UtilTest(unittest.TestCase):
        def test_tuple_equivalent(self):
            t1 = (0, 1, 2)
            t2 = (0, 1, 2)
            self.assertTrue(tuple_equivalent(t1, t2))
            t2 = (2, 1, 0)
            self.assertTrue(tuple_equivalent(t1, t2))
            t2 = (2, 1, 3)
            self.assertFalse(tuple_equivalent(t1, t2))
    
        def test_dist_with_pbc(self):
            pos_a = np.array([1., 1., 1.])
            pos_b = np.array([9., 9., 9.])
            numpy.testing.assert_equal(np.array([8., 8., 8.]),
                                       dist_with_pbc(pos_a, pos_b, np.array([0., 0., 0.])))
            numpy.testing.assert_equal(np.array([8., 8., 8.]),
                                       dist_with_pbc(pos_a, pos_b, np.array([20., 20., 20.])))
            numpy.testing.assert_equal(np.array([-2., -2., -2.]),
                                       dist_with_pbc(pos_a, pos_b, np.array([10., 10., 10.])))
    
        def test_triplets_from_pairs(self):
            pairs = [(0, 1), (1, 2), (2, 3)]
            result = [(0, 1, 2), (1, 2, 3)]
            self.assertEqual(result, sorted(extend_graph_chain(pairs, pairs)))
            pairs = [(0, 1), (1, 2), (2, 3), (3, 0)]
            result = [(0, 1, 2), (1, 0, 3), (1, 2, 3), (2, 3, 0)]
            self.assertEqual(result, sorted(extend_graph_chain(pairs, pairs)))
    
        def test_triplets_from_pairs_multires(self):
            pairs = [("a", "b"), ("b", "c"), ("c", "d"), ("d", "+a")]
            result = [("a", "b", "c"), ("b", "c", "d"), ("c", "d", "+a"), ("d", "+a", "+b")]
            self.assertEqual(result, sorted(extend_graph_chain(pairs, pairs)))
    
        def test_quadruplets_from_pairs(self):
            pairs = [(0, 1), (1, 2), (2, 3)]
            result = [(0, 1, 2, 3)]
            triplets = extend_graph_chain(pairs, pairs)
            self.assertEqual(result, sorted(extend_graph_chain(triplets, pairs)))
            pairs = [(0, 1), (1, 2), (2, 3), (3, 0)]
            triplets = extend_graph_chain(pairs, pairs)
            result = [(0, 1, 2, 3), (1, 0, 3, 2), (1, 2, 3, 0), (2, 1, 0, 3)]
            self.assertEqual(result, sorted(extend_graph_chain(triplets, pairs)))
    
        def test_stat_moments(self):
            t1 = [3, 3, 3, 3, 3]
            t2 = [1, 2, 3, 4, 5]
            np.testing.assert_allclose(np.array([3, 0]), stat_moments(t1))
            np.testing.assert_allclose(np.array([3, 2]), stat_moments(t2))
    
        def test_dir_up(self):
            path = os.path.realpath(__file__)
            self.assertEqual(path, dir_up(path, 0))
            self.assertEqual(os.path.dirname(path), dir_up(path))
            self.assertEqual(os.path.dirname(os.path.dirname(path)), dir_up(path, 2))
    
        def test_backup_file(self):
            try:
                os.remove("testfile")
                os.remove("#testfile.1#")
                os.remove("#testfile.2#")
            except OSError:
                pass
    
            logging.disable(logging.WARNING)
            open("testfile", "a").close()
            self.assertEqual("#testfile.1#", backup_file("testfile"))
            open("testfile", "a").close()
            self.assertTrue(os.path.exists("#testfile.1#"))
            self.assertEqual("#testfile.2#", backup_file("testfile"))
            open("testfile", "a").close()
            self.assertTrue(os.path.exists("#testfile.2#"))
            logging.disable(logging.NOTSET)
    
            os.remove("testfile")
            os.remove("#testfile.1#")
            os.remove("#testfile.2#")
    
        def test_sliding(self):
            l = [0, 1, 2, 3, 4]
            res = [(None, 0, 1), (0, 1, 2), (1, 2, 3), (2, 3, 4), (3, 4, None)]
            for res, pair in zip(res, sliding(l)):
                self.assertEqual(res, pair)
    
        def test_r_squared(self):
            ref = [i for i in range(5)]
            fit = ref
            self.assertEqual(1, r_squared(ref, fit))
            fit = [2 for _ in range(5)]
            self.assertEqual(0, r_squared(ref, fit))
            fit = [i for i in range(1, 6)]
            self.assertEqual(0.5, r_squared(ref, fit))
    
        def test_transpose_and_sample_no_sample(self):
            l = [(1, 2), (3, 4), (5, 6)]
            l_t = [(1, 3, 5), (2, 4, 6)]
            self.assertEqual(l_t, transpose_and_sample(l, None))
    
        def test_transpose_and_sample(self):
            l = [(1, 2), (3, 4), (5, 6)]
            l_t = [(1, 3, 5), (2, 4, 6)]
    
            l_t_test = transpose_and_sample(l, n=1)
            self.assertEqual(1, len(l_t_test))
            self.assertIn(l_t_test[0], l_t)
    
        def test_simple_enum(self):
            enum = SimpleEnum.enum("enum", ["one", "two", "three"])
            self.assertTrue(enum.one == enum.one)
            self.assertTrue(enum.one.compare_value(enum.one))
    
            self.assertFalse(enum.two == enum.three)
            self.assertFalse(enum.two.compare_value(enum.three))
    
            with self.assertRaises(AttributeError):
                _ = enum.four
            with self.assertRaises(AttributeError):
                enum.one = 2
    
            enum2 = SimpleEnum.enum("enum2", ["one", "two", "three"])
            with self.assertRaises(TypeError):
                assert enum2.one == enum.one
    
            self.assertTrue("one" in enum)
            self.assertFalse("four" in enum)
    
        def test_simple_enum_values(self):
            enum = SimpleEnum.enum_from_dict("enum", {"one": 111,
                                                      "two": 111,
                                                      "three": 333})
            self.assertTrue(enum.one == enum.one)
            self.assertTrue(enum.one.compare_value(enum.one))
    
            self.assertFalse(enum.one == enum.two)
            self.assertTrue(enum.one.compare_value(enum.two))
    
            self.assertFalse(enum.two == enum.three)
            self.assertFalse(enum.two.compare_value(enum.three))
    
            with self.assertRaises(AttributeError):
                _ = enum.four
            with self.assertRaises(AttributeError):
                enum.one = 2
    
            enum2 = SimpleEnum.enum("enum2", ["one", "two", "three"])
            with self.assertRaises(TypeError):
                assert enum2.one == enum.one
    
            self.assertTrue("one" in enum)
            self.assertEqual(111, enum.one.value)
    
            self.assertFalse("four" in enum)
    
        def test_fixed_format_unpacker_c(self):
            unpacker = FixedFormatUnpacker("%-4d%5s%4.1f")
            toks = unpacker.unpack("1234hello12.3")
            self.assertEqual(3, len(toks))
            self.assertEqual(1234, toks[0])
            self.assertEqual("hello", toks[1])
            self.assertAlmostEqual(12.3, toks[2])
    
        def test_fixed_format_unpacker_fortran(self):
            unpacker = FixedFormatUnpacker("I4,A5,F4.1",
                                           FixedFormatUnpacker.FormatStyle.Fortran)
            toks = unpacker.unpack("1234hello12.3")
            self.assertEqual(3, len(toks))
            self.assertEqual(1234, toks[0])
            self.assertEqual("hello", toks[1])
            self.assertAlmostEqual(12.3, toks[2])
    
        def test_fixed_format_unpacker_fortran_space(self):
            unpacker = FixedFormatUnpacker("I4,X3,A5,X2,F4.1",
                                           FixedFormatUnpacker.FormatStyle.Fortran)
            toks = unpacker.unpack("1234 x hello x12.3")
            self.assertEqual(3, len(toks))
            self.assertEqual(1234, toks[0])
            self.assertEqual("hello", toks[1])
            self.assertAlmostEqual(12.3, toks[2])
    
        def test_fixed_format_unpacker_fortran_repeat(self):
            unpacker = FixedFormatUnpacker("2I2,X3,A5,2X,F4.1",
                                           FixedFormatUnpacker.FormatStyle.Fortran)
            toks = unpacker.unpack("1234 x hello x12.3")
            self.assertEqual(4, len(toks))
            self.assertEqual(12, toks[0])
            self.assertEqual(34, toks[1])
            self.assertEqual("hello", toks[2])
            self.assertAlmostEqual(12.3, toks[3])
    
    
    if __name__ == '__main__':
        unittest.main()