# Source code for qsprpred.utils.tests

import time
from concurrent import futures

from parameterized import parameterized

from .parallel import batched_generator, MultiprocessingJITGenerator, \
    PebbleJITGenerator, ThreadsJITGenerator
from .testing.base import QSPRTestCase


class TestMultiProcGenerators(QSPRTestCase):
    """Test the process-based just-in-time generators.

    Covers ``MultiprocessingJITGenerator`` and ``PebbleJITGenerator`` on plain
    and batched input, the Pebble timeout path, and forwarding of extra
    positional/keyword arguments to the worker function.
    """

    @staticmethod
    def func(x):
        """Return the square of ``x``."""
        return x ** 2

    @staticmethod
    def func_batched(x):
        """Return the element-wise squares of the batch ``x`` as a list."""
        return [i ** 2 for i in x]

    @staticmethod
    def func_timeout(x):
        """Sleep for ``x`` seconds, then return the square of ``x``."""
        time.sleep(x)
        return x ** 2

    @staticmethod
    def func_args(x, *args, **kwargs):
        """Echo the item together with the extra arguments it received."""
        return x, args, kwargs

    def _makeGenerator(self, pool_type, timeout=None):
        """Instantiate ``pool_type`` with ``self.nCPU`` workers.

        ``timeout`` is only passed on when it is not ``None``, so generator
        classes constructed without a ``timeout`` keyword keep working. The
        explicit ``is None`` check (instead of truthiness) ensures an
        explicit ``0`` is not silently discarded.
        """
        if timeout is None:
            return pool_type(self.nCPU)
        return pool_type(self.nCPU, timeout=timeout)

    # NOTE: each (timeout, pool_type) combination is listed once; a duplicated
    # entry would only start another process pool without adding coverage.
    @parameterized.expand([
        (None, MultiprocessingJITGenerator),
        (1, PebbleJITGenerator),
    ])
    def testSimple(self, timeout, pool_type):
        """Every item of a plain generator is processed exactly once."""
        generator = (x for x in range(10))
        p_generator = self._makeGenerator(pool_type, timeout)
        # Results may arrive in completion order, hence the sort.
        self.assertListEqual(
            [x ** 2 for x in range(10)],
            sorted(p_generator(generator, self.func)),
        )

    @parameterized.expand([
        (None, MultiprocessingJITGenerator),
        (1, PebbleJITGenerator),
    ])
    def testBatched(self, timeout, pool_type):
        """Batches from ``batched_generator`` are processed as whole units."""
        generator = batched_generator(range(10), 2)
        p_generator = self._makeGenerator(pool_type, timeout)
        self.assertListEqual(
            [[0, 1], [4, 9], [16, 25], [36, 49], [64, 81]],
            sorted(p_generator(generator, self.func_batched)),
        )

    def testTimeout(self):
        """A task that exceeds the Pebble timeout yields a TimeoutError."""
        timeout = 4
        generator = (x for x in [1, 2, 10])
        p_generator = PebbleJITGenerator(self.nCPU, timeout=timeout)
        result = list(p_generator(generator, self.func_timeout))
        # The 1 s and 2 s tasks finish within the 4 s limit; the 10 s task
        # does not and is reported as the last element.
        self.assertListEqual([1, 4], result[:-1])
        self.assertIsInstance(result[-1], futures.TimeoutError)
        self.assertIn(str(timeout), str(result[-1]))

    @parameterized.expand([
        ((0,), {"A": 1}, MultiprocessingJITGenerator),
        (None, {"A": 1}, MultiprocessingJITGenerator),
        ((0,), None, MultiprocessingJITGenerator),
    ])
    def testArgs(self, args, kwargs, pool_type):
        """Extra positional and keyword arguments reach the worker function."""
        args = args or ()
        kwargs = kwargs or {}
        generator = (x for x in range(10))
        p_generator = self._makeGenerator(pool_type)
        result = sorted(
            p_generator(generator, self.func_args, *args, **kwargs),
            key=lambda res: res[0],
        )
        # Compare the whole list so that missing or extra results fail the
        # test instead of passing vacuously.
        self.assertListEqual(
            [(idx, args, kwargs) for idx in range(10)],
            result,
        )
class TestThreadedGenerators(QSPRTestCase):
    """Exercise ``ThreadsJITGenerator`` on plain and batched workloads.

    Each worker function pauses for one second before returning its result.
    """

    @staticmethod
    def func(x):
        """Square ``x`` after a one-second pause."""
        time.sleep(1)
        return x ** 2

    @staticmethod
    def func_batched(x):
        """Square every element of the batch ``x`` after a one-second pause."""
        time.sleep(1)
        return [value ** 2 for value in x]

    def testSimple(self):
        """All ten items come back squared (order-independent)."""
        items = (n for n in range(10))
        pool = ThreadsJITGenerator(self.nCPU)
        expected = [n ** 2 for n in range(10)]
        outputs = pool(items, self.func)
        self.assertListEqual(expected, sorted(outputs))

    def testBatched(self):
        """Pairs from ``batched_generator`` come back squared as whole batches."""
        batches = batched_generator(range(10), 2)
        pool = ThreadsJITGenerator(self.nCPU)
        expected = [[0, 1], [4, 9], [16, 25], [36, 49], [64, 81]]
        outputs = pool(batches, self.func_batched)
        self.assertListEqual(expected, sorted(outputs))