import time
from concurrent import futures
from parameterized import parameterized
from .parallel import (
MultiprocessingJITGenerator,
PebbleJITGenerator,
ThreadsJITGenerator,
batched_generator,
)
from .testing.base import QSPRTestCase
class TestMultiProcGenerators(QSPRTestCase):
    """Test the process-pool based just-in-time generators.

    Covers ``MultiprocessingJITGenerator`` and ``PebbleJITGenerator`` on plain
    and batched inputs, the timeout handling of ``PebbleJITGenerator``, and the
    forwarding of extra positional/keyword arguments to the worker function.
    """

    @staticmethod
    def func(x):
        """Return the square of *x*."""
        return x ** 2

    @staticmethod
    def func_batched(x):
        """Return the square of every item in the batch *x*."""
        return [i ** 2 for i in x]

    @staticmethod
    def func_timeout(x):
        """Sleep for *x* seconds, then return the square of *x*."""
        time.sleep(x)
        return x ** 2

    @staticmethod
    def func_args(x, *args, **kwargs):
        """Echo the item back together with any extra arguments received."""
        return x, args, kwargs

    def _make_generator(self, pool_type, timeout=None):
        """Build a *pool_type* generator with ``self.nCPU`` workers.

        ``timeout`` is forwarded as a keyword argument only when one is given,
        so pool types that are used here without a timeout are constructed
        with ``self.nCPU`` alone.
        """
        if timeout is None:
            return pool_type(self.nCPU)
        return pool_type(self.nCPU, timeout=timeout)

    @parameterized.expand(
        [
            (None, MultiprocessingJITGenerator),
            (1, PebbleJITGenerator),
            (None, PebbleJITGenerator),
        ]
    )
    def testSimple(self, timeout, pool_type):
        """Check that every item of a plain generator is processed exactly once."""
        n = 100
        p_generator = self._make_generator(pool_type, timeout)
        # Sorted so the check does not depend on the order results are yielded in.
        self.assertListEqual(
            [x ** 2 for x in range(n)],
            sorted(p_generator((x for x in range(n)), self.func)),
        )

    @parameterized.expand(
        [
            (None, MultiprocessingJITGenerator),
            (1, PebbleJITGenerator),
            (None, PebbleJITGenerator),
        ]
    )
    def testBatched(self, timeout, pool_type):
        """Check that batches from ``batched_generator`` are processed as units."""
        p_generator = self._make_generator(pool_type, timeout)
        self.assertListEqual(
            [[0, 1], [4, 9], [16, 25], [36, 49], [64, 81]],
            sorted(p_generator(batched_generator(range(10), 2), self.func_batched)),
        )

    def testTimeout(self):
        """Check that a task exceeding the timeout yields a timeout error.

        The last task sleeps 10 s against a 4 s timeout; its slot in the output
        is expected to hold a ``futures.TimeoutError`` instance (returned, not
        raised) whose message mentions the timeout value.
        """
        timeout = 4
        p_generator = PebbleJITGenerator(self.nCPU, timeout=timeout)
        result = list(p_generator((x for x in [1, 2, 10]), self.func_timeout))
        # Guard the indexing below: a short result should fail, not IndexError.
        self.assertEqual(3, len(result))
        self.assertListEqual([1, 4], result[:-1])
        self.assertIsInstance(result[-1], futures.TimeoutError)
        self.assertIn(str(timeout), str(result[-1]))

    @parameterized.expand(
        [
            ((0,), {"A": 1}, MultiprocessingJITGenerator),
            (None, {"A": 1}, MultiprocessingJITGenerator),
            ((0,), None, MultiprocessingJITGenerator),
        ]
    )
    def testArgs(self, args, kwargs, pool_type):
        """Check that extra positional/keyword arguments reach the worker.

        ``None`` for *args* or *kwargs* means "pass none of that kind"; the
        worker is then expected to have received an empty tuple / dict.
        """
        n = 10
        args = args or ()
        kwargs = kwargs or {}
        p_generator = pool_type(self.nCPU)
        result = sorted(
            p_generator((x for x in range(n)), self.func_args, *args, **kwargs),
            key=lambda res: res[0],
        )
        # Without this, an empty result would pass the loop below vacuously.
        self.assertEqual(n, len(result))
        for idx, res in enumerate(result):
            self.assertEqual((idx, args, kwargs), res)
class TestThreadedGenerators(QSPRTestCase):
    """Test processing using a pool of threads."""

    @staticmethod
    def func(x):
        """Sleep for one second, then return the square of *x*."""
        time.sleep(1)
        return x ** 2

    @staticmethod
    def func_batched(x):
        """Sleep for one second, then return the square of every item in *x*."""
        time.sleep(1)
        return [i ** 2 for i in x]

    def testSimple(self):
        """Check that every item of a plain generator is processed exactly once."""
        n = 10
        p_generator = ThreadsJITGenerator(self.nCPU)
        # Sorted so the check does not depend on the order results are yielded in.
        self.assertListEqual(
            [x ** 2 for x in range(n)],
            sorted(p_generator((x for x in range(n)), self.func)),
        )

    def testBatched(self):
        """Check that batches from ``batched_generator`` are processed as units."""
        p_generator = ThreadsJITGenerator(self.nCPU)
        self.assertListEqual(
            [[0, 1], [4, 9], [16, 25], [36, 49], [64, 81]],
            sorted(p_generator(batched_generator(range(10), 2), self.func_batched)),
        )