Source code for qsprpred.extra.gpu.utils.parallel

from qsprpred.utils.parallel import MultiprocessingJITGenerator


class TorchJITGenerator(MultiprocessingJITGenerator):
    """`MultiprocessingJITGenerator` backed by a `torch.multiprocessing` pool.

    Instead of the standard `multiprocessing.Pool`, the worker pool is taken
    from `torch.multiprocessing` and the `spawn` start method is enforced.
    This is needed when the parallel processing is done with PyTorch tensors
    or models, which require `torch.multiprocessing` together with the
    `spawn` start method.
    """

    def getPool(self):
        """Create the worker pool used for parallel processing.

        The `spawn` start method is set with ``force=True``, so any start
        method that was selected earlier is overridden.

        Returns:
            A `torch.multiprocessing` pool created with `self.nWorkers`
            worker processes.
        """
        # Imported lazily so that torch is only loaded when a pool is
        # actually requested.
        import torch.multiprocessing as torch_mp

        torch_mp.set_start_method("spawn", force=True)
        return torch_mp.Pool(self.nWorkers)