# Repository URL to install this package:
# Version: 4.2.0
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Licensed under the GNU LGPL v2.1 - http://www.gnu.org/licenses/lgpl.html
"""
Automated tests for checking that KeyedVectors.load_word2vec_format honours the requested datatype.
"""
import logging
import os
import tempfile
import unittest

import numpy as np

from gensim.test.utils import datapath
from gensim.models.keyedvectors import KeyedVectors
class TestDataType(unittest.TestCase):
    """Check that ``KeyedVectors.load_word2vec_format`` honours its ``datatype`` argument.

    Every test loads the ``high_precision.kv.txt`` data file and inspects the
    first component of the ``'horse.n.01'`` vector, both for its value and for
    its NumPy dtype.
    """

    def load_model(self, datatype):
        """Load ``high_precision.kv.txt`` in text format using ``datatype``.

        Parameters
        ----------
        datatype : numpy dtype
            Passed through unchanged as the ``datatype`` keyword of
            ``KeyedVectors.load_word2vec_format``.

        Returns
        -------
        The object returned by ``KeyedVectors.load_word2vec_format``.
        """
        path = datapath('high_precision.kv.txt')
        return KeyedVectors.load_word2vec_format(path, binary=False, datatype=datatype)

    def test_high_precision(self):
        """Loading as float64 yields the expected value and a float64 element."""
        kv = self.load_model(np.float64)
        self.assertAlmostEqual(kv['horse.n.01'][0], -0.0008546282343595379)
        self.assertEqual(kv['horse.n.01'][0].dtype, np.float64)

    def test_medium_precision(self):
        """Loading as float32 yields the expected value and a float32 element."""
        kv = self.load_model(np.float32)
        self.assertAlmostEqual(kv['horse.n.01'][0], -0.00085462822)
        self.assertEqual(kv['horse.n.01'][0].dtype, np.float32)

    def test_low_precision(self):
        """Loading as float16 yields the expected value and a float16 element."""
        kv = self.load_model(np.float16)
        self.assertAlmostEqual(kv['horse.n.01'][0], -0.00085449)
        self.assertEqual(kv['horse.n.01'][0].dtype, np.float16)

    def test_type_conversion(self):
        """Round-trip float16 vectors through the binary format and reload as float64.

        The text file is loaded as float16, saved in binary word2vec format,
        and read back as float64. The reloaded value, cast down to float16,
        must match the float16 original, and each model must keep the dtype it
        was loaded with.
        """
        path = datapath('high_precision.kv.txt')
        model1 = KeyedVectors.load_word2vec_format(path, datatype=np.float16)

        # Write the intermediate binary file into a scratch directory instead
        # of datapath(): datapath() presumably resolves into the shared
        # test-data directory (TODO confirm), and a test run should not create
        # or overwrite files next to its input data.
        with tempfile.TemporaryDirectory() as tmp_dir:
            binary_path = os.path.join(tmp_dir, 'high_precision.kv.bin')
            model1.save_word2vec_format(binary_path, binary=True)
            model2 = KeyedVectors.load_word2vec_format(binary_path, datatype=np.float64, binary=True)

        self.assertAlmostEqual(model1["horse.n.01"][0], np.float16(model2["horse.n.01"][0]))
        self.assertEqual(model1["horse.n.01"][0].dtype, np.float16)
        self.assertEqual(model2["horse.n.01"][0].dtype, np.float64)
if __name__ == '__main__':
    # Script entry point: set the root logger's level to WARNING, then hand
    # control to the unittest command-line runner.
    root_logger = logging.getLogger()
    root_logger.setLevel(logging.WARNING)
    unittest.main()