#!/usr/bin/env python
import random
import torch
from torch.optim import Adam
from drugex import utils, DEFAULT_DEVICE, DEFAULT_GPUS
from torch.utils.data import DataLoader, TensorDataset
from tqdm.auto import tqdm
from drugex.logs import logger
from drugex.training.explorers.interfaces import FragExplorer
from drugex.training.generators.utils import pad_mask, tri_mask
[docs]class FragSequenceExplorer(FragExplorer):
"""
`Explorer` to optimize a sequence-based fragment-using agent with the given `Environment`.
"""
def __init__(self, agent, env=None, crover=None, mutate=None, batch_size=128, epsilon=0.1, beta=0.0, n_samples=-1, optim=None, device=DEFAULT_DEVICE, use_gpus=DEFAULT_GPUS, no_multifrag_smiles=True):
super(FragSequenceExplorer, self).__init__(agent, env, mutate, crover, no_multifrag_smiles, batch_size, epsilon, beta, n_samples, device=device, use_gpus=use_gpus)
"""
Parameters
----------
agent : models.SequenceTransformer
The agent network which is optimised to generates the desired molecules.
env : drugex.trainining.interfaces.Environment
The environment which provides the reward and judge if the genrated molecule is valid and desired.
crover : models.SequenceTransformer
The iteratively updated network which increases the exploitation of the chemical space.(Not used in this class)
mutate : models.SequenceTransformer
The pre-trained network which increases the exploration of the chemical space.
batch_size : int
The batch size for the training.
epsilon : float
The probability of using the `mutate` network.
beta : float
The baseline for the reward.
n_samples : int
The number of molecules to be generated in each epoch.
optim : torch.optim
The optimizer to be used for the training.
device : torch.device
The device to be used for the training.
use_gpus : tuple
The GPUs to be used for the training.
no_multifrag_smiles : bool
If True, the molecules with multiple fragments are not considered as valid molecules.
"""
self.optim = utils.ScheduledOptim(
Adam(self.parameters(), betas=(0.9, 0.98), eps=1e-9), 1.0, 512) if not optim else optim
[docs] def forward(self, src):
"""
Forward pass of the agent.
Parameters
----------
src : torch.Tensor
TODO: check the shape of the input tensor.
Returns
-------
torch.Tensor
TODO: check the shape of the input tensor.
"""
seq_len = self.agent.voc_trg.max_len + self.agent.voc_trg.max_len
out = torch.zeros(len(src), seq_len).long().to(self.device)
out[:, :src.size(1)] = src
is_end = torch.zeros(len(src)).bool().to(self.device)
for step in range(self.agent.voc_trg.max_len): # decode up to max length
sub = out[:, :src.size(1) + step]
key_mask = pad_mask(sub, self.agent.pad_idx)
atn_mask = tri_mask(sub)
rand = torch.rand(1)
if self.epsilon < rand <= 0.5 and self.crover is not None:
dec = self.crover.gpt2(sub.transpose(0, 1), key_mask=key_mask, atn_mask=atn_mask)
elif rand < self.epsilon and self.mutate is not None:
dec = self.mutate.gpt2(sub.transpose(0, 1), key_mask=key_mask, atn_mask=atn_mask)
else:
dec = self.agent.gpt2(sub.transpose(0, 1), key_mask=key_mask, atn_mask=atn_mask)
proba = dec[-1,:, :].softmax(dim=-1)
# sampling based on output probability distribution
x = torch.multinomial(proba, 1).view(-1)
x[is_end] = self.agent.voc_trg.tk2ix['_']
is_end |= x == self.agent.voc_trg.tk2ix['EOS']
out[:, src.size(1) + step] = x
if is_end.all(): break
return out[:, self.agent.voc_trg.max_len:].detach()
[docs] def getBatchOutputs(self, net, src):
"""
Outputs (frags, smiles) and loss of the agent for a batch of fragments-molecule pairs.
Parameters
----------
net : torch.nn.Module
Agent
src : torch.Tensor
Fragments-molecule pairs
Returns
-------
frags : list
List of fragments
smiles : list
List of SMILES
loss : torch.Tensor
Loss of the agent
"""
src, trg = src[0].to(self.device), src[1].to(self.device)
frags = [self.agent.voc_trg.decode(s, is_tk=False) for s in src]
smiles = [self.agent.voc_trg.decode(s, is_tk=False) for s in trg]
loss = net(src, trg)
return frags, smiles, loss
[docs] def sampleEncodedPairsToLoader(self, net, loader):
"""
Sample new fragments-molecule pairs from a data loader.
Parameters
----------
net : torch.nn.Module
Agent
loader : torch.utils.data.DataLoader
Data loader for original fragments-molecule pairs
Returns
-------
torch.utils.data.DataLoader
Data loader for sampled fragments-molecule pairs
"""
srcs, trgs = [], []
for _, src in tqdm(loader, desc='Iterating over training batches', leave=False):
with torch.no_grad():
trg = net(src.to(self.device))
trgs.append(trg.detach().cpu())
srcs.append(src.detach().cpu())
trgs = torch.cat(trgs, dim=0)
srcs = torch.cat(srcs, dim=0)
return DataLoader(TensorDataset(srcs, trgs), batch_size=self.batchSize, shuffle=True, drop_last=False)