"""Source code for qsprpred.extra.gpu.utils.tests."""

import time
from unittest import skipIf

import torch
from parameterized import parameterized

from qsprpred.extra.gpu.utils.parallel import TorchJITGenerator
from qsprpred.utils.parallel import batched_generator, ThreadsJITGenerator
from qsprpred.utils.testing.base import QSPRTestCase


@skipIf(not torch.cuda.is_available(), "CUDA not available. Skipping...")
class TestMultiGPUGenerators(QSPRTestCase):
    """Test parallel processing over multiple GPUs via `TorchJITGenerator`.

    Each worker receives a ``gpu`` keyword identifying the CUDA device it
    should run on; results are sorted before comparison because the
    parallel generator does not guarantee output order.
    """

    @staticmethod
    def func(x, gpu=None):
        """Square a single value on the assigned GPU and return it as a scalar."""
        assert gpu is not None
        device = torch.device(f"cuda:{gpu}")
        result = (torch.tensor([x], device=device) ** 2).item()
        # Simulate a non-trivial workload so jobs actually overlap in time.
        time.sleep(1)
        return result

    @staticmethod
    def func_batched(x, gpu=None):
        """Square a batch of values on the assigned GPU and return a list."""
        assert gpu is not None
        device = torch.device(f"cuda:{gpu}")
        time.sleep(1)
        return (torch.tensor(x, device=device) ** 2).tolist()

    @parameterized.expand([
        (1,),
        (2,),
    ])
    def testSimple(self, jobs_per_gpu):
        """Squares 0..9 one item at a time across all available GPUs."""
        items = (x for x in range(10))
        parallel_gen = TorchJITGenerator(
            len(self.GPUs),
            use_gpus=self.GPUs,
            jobs_per_gpu=jobs_per_gpu,
            worker_type="gpu"
        )
        expected = [x ** 2 for x in range(10)]
        self.assertListEqual(
            expected,
            sorted(parallel_gen(
                items,
                self.func,
            ))
        )

    @parameterized.expand([
        (1,),
        (2,),
    ])
    def testBatched(self, jobs_per_gpu):
        """Squares 0..9 in batches of two across all available GPUs."""
        batches = batched_generator(range(10), 2)
        parallel_gen = TorchJITGenerator(
            len(self.GPUs),
            use_gpus=self.GPUs,
            jobs_per_gpu=jobs_per_gpu,
            worker_type="gpu"
        )
        expected = [[0, 1], [4, 9], [16, 25], [36, 49], [64, 81]]
        self.assertListEqual(
            expected,
            sorted(parallel_gen(
                batches,
                self.func_batched
            ))
        )
class TestThreadedGeneratorsGPU(QSPRTestCase):
    """Test processing using a pool of threads."""

    @staticmethod
    def gpu_func(x, gpu=None):
        """Square one value on the assigned GPU and return it as a scalar."""
        assert gpu is not None
        device = torch.device(f"cuda:{gpu}")
        result = (torch.tensor([x], device=device) ** 2).item()
        # Simulate work so that threads genuinely run concurrently.
        time.sleep(1)
        return result

    @staticmethod
    def gpu_func_batched(x, gpu=None):
        """Square a batch of values on the assigned GPU and return a list."""
        assert gpu is not None
        device = torch.device(f"cuda:{gpu}")
        time.sleep(1)
        return (torch.tensor(x, device=device) ** 2).tolist()

    @skipIf(not torch.cuda.is_available(), "CUDA not available. Skipping...")
    def testSimpleGPU(self):
        """Squares 0..9 item-by-item using GPU workers backed by threads."""
        items = (x for x in range(10))
        parallel_gen = ThreadsJITGenerator(
            len(self.GPUs),
            use_gpus=self.GPUs,
            worker_type="gpu",
            jobs_per_gpu=2
        )
        expected = [x ** 2 for x in range(10)]
        self.assertListEqual(
            expected,
            sorted(parallel_gen(
                items,
                self.gpu_func,
            ))
        )

    @skipIf(not torch.cuda.is_available(), "CUDA not available. Skipping...")
    def testBatchedGPU(self):
        """Squares 0..9 in batches of two using GPU workers backed by threads."""
        batches = batched_generator(range(10), 2)
        parallel_gen = ThreadsJITGenerator(
            len(self.GPUs),
            use_gpus=self.GPUs,
            worker_type="gpu"
        )
        expected = [[0, 1], [4, 9], [16, 25], [36, 49], [64, 81]]
        self.assertListEqual(
            expected,
            sorted(parallel_gen(
                batches,
                self.gpu_func_batched
            ))
        )