"""Source code for drugex.training.generators.sequence_transformer."""
import tempfile
import torch
import torch.nn as nn
from torch import optim
from torch.nn.init import kaiming_normal_
from tqdm.auto import tqdm
from drugex import DEFAULT_DEVICE, DEFAULT_GPUS
from drugex.data.fragments import SequenceFragmentEncoder, FragmentCorpusEncoder
from drugex.data.datasets import SmilesFragDataSet
from drugex.molecules.converters.dummy_molecules import dummyMolsFromFragments
from drugex.training.generators.utils import PositionalEmbedding, PositionwiseFeedForward, SublayerConnection, pad_mask, tri_mask
from drugex.training.generators.interfaces import FragGenerator
from drugex.utils import ScheduledOptim
class Block(nn.Module):
    """
    One decoder block: multi-head self-attention followed by a
    position-wise feed-forward network, each wrapped in a residual
    sublayer connection.
    """

    def __init__(self, d_model, n_head, d_inner):
        super(Block, self).__init__()
        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.pffn = PositionwiseFeedForward(d_model, d_inner)
        # Two residual connectors: one around attention, one around the FFN.
        self.connector = nn.ModuleList([SublayerConnection(d_model) for _ in range(2)])

    def forward(self, x, key_mask=None, atn_mask=None):
        # Self-attention sublayer. MultiheadAttention returns (output, weights);
        # only the output is propagated. `key_mask` is passed positionally as
        # key_padding_mask, `atn_mask` as attn_mask.
        def self_attention(hidden):
            return self.attn(hidden, hidden, hidden, key_mask, attn_mask=atn_mask)[0]

        x = self.connector[0](x, self_attention)
        x = self.connector[1](x, self.pffn)
        return x
class GPT2Layer(nn.Module):
    """
    GPT-2 style decoder-only transformer: token + positional embeddings,
    a stack of self-attention blocks and a final projection back onto the
    vocabulary.
    """

    def __init__(self, voc, d_emb=512, d_model=512, n_head=12, d_inner=1024, n_layer=12, pad_idx=0):
        super(GPT2Layer, self).__init__()
        self.voc = voc
        self.d_emb = d_emb
        self.d_model = d_model
        self.n_head = n_head
        self.n_layer = n_layer
        self.pad_idx = pad_idx

        # Positions span twice voc.max_len (fragment part + molecule part
        # concatenated into one sequence — presumably; confirm against caller).
        self.token_emb = nn.Embedding(voc.size, self.d_emb, padding_idx=pad_idx)
        self.posit_emb = PositionalEmbedding(self.d_emb, max_len=2 * voc.max_len)

        self.blocks = nn.ModuleList(
            [Block(self.d_emb, self.n_head, d_inner=d_inner) for _ in range(self.n_layer)]
        )
        # NOTE(review): registered but never applied in forward() — kept for
        # state-dict compatibility.
        self.layer_norm = nn.LayerNorm(self.d_emb)
        self.word_prj = nn.Linear(self.d_emb, self.voc.size)
        kaiming_normal_(self.word_prj.weight, nonlinearity="linear")

    def forward(self, input: torch.Tensor, key_mask=None, atn_mask=None):
        # Sum token and positional embeddings, run the block stack, then
        # project hidden states onto vocabulary logits.
        hidden = self.token_emb(input) + self.posit_emb(input)
        for block in self.blocks:
            hidden = block(hidden, key_mask=key_mask, atn_mask=atn_mask)
        return self.word_prj(hidden)
[docs]class SequenceTransformer(FragGenerator):
"""
Sequence Transformer for molecule generation from fragments
"""
def __init__(self, voc_trg, d_emb=512, d_model=512, n_head=8, d_inner=1024, n_layer=12, pad_idx=0, device=DEFAULT_DEVICE, use_gpus=DEFAULT_GPUS):
super(SequenceTransformer, self).__init__(device=device, use_gpus=use_gpus)
self.mol_type = 'smiles'
self.voc_trg = voc_trg
self.pad_idx = pad_idx
self.gpt2 = GPT2Layer(self.voc_trg, d_emb=d_emb, d_model=d_model,
n_head=n_head, d_inner=d_inner, n_layer=n_layer,
pad_idx=pad_idx)
self.init_states()
self.optim = ScheduledOptim(
optim.Adam(self.parameters(), betas=(0.9, 0.98), eps=1e-9), 0.5, d_model)
# self.optim = optim.Adam(self.parameters(), lr=1e-4)
self.model_name = 'SequenceTransformer'
# def init_states(self):
# """
# Initialize model parameters
# Notes:
# -----
# Xavier initialization for all parameters except for the embedding layer
# """
# for p in self.parameters():
# if p.dim() > 1:
# nn.init.xavier_uniform_(p)
# self.attachToGPUs(self.gpus)
# def attachToGPUs(self, gpus):
# """
# Attach model to GPUs
# Parameters:
# ----------
# gpus: `tuple`
# A tuple of GPU ids to use
# Returns:
# -------
# None
# """
# self.gpus = gpus
# self.to(self.device)
    def forward(self, src, trg=None):
        """
        Forward pass: teacher-forced scoring when `trg` is given, otherwise
        autoregressive sampling conditioned on `src`.

        Parameters:
        ----------
        src: `torch.Tensor`
            Encoded input-fragment token indices, shape [batch_size, src_len]
            (presumably src_len == voc_trg.max_len — TODO confirm against the
            data loader).
        trg: `torch.Tensor`
            Encoded target token indices, shape [batch_size, trg_len], or None.

        Returns:
        -------
        `torch.Tensor`
            If `trg` is given: per-token log-likelihood of each target token,
            shape [batch_size, trg_len].
            Otherwise: sampled token indices of the generated sequences,
            shape [batch_size, voc_trg.max_len].
        """
        if trg is not None:
            # Teacher forcing: fragments and target are concatenated into one
            # sequence and scored in a single pass.
            input = torch.cat([src, trg], dim=1)
            key_mask = pad_mask(input, self.pad_idx)
            atn_mask = tri_mask(input)  # causal (triangular) attention mask
            # Logits at positions src_len-1 .. L-2 predict the tokens of trg.
            start, end = src.size(1) - 1, -1
            input = input.transpose(0, 1)  # -> [seq_len, batch] layout for attention
            dec = self.gpt2(input, key_mask=key_mask, atn_mask=atn_mask)[start:end, :, :]
            dec = dec.transpose(0, 1).log_softmax(dim=-1)
            # Pick the log-probability assigned to each actual target token.
            out = dec.gather(2, trg.unsqueeze(2)).squeeze(2)
        else:
            # Autoregressive sampling: buffer holds [fragments | generated part].
            seq_len = self.voc_trg.max_len + self.voc_trg.max_len
            out = torch.zeros(len(src), seq_len).long().to(src.device)
            out[:, :src.size(1)] = src
            is_end = torch.zeros(len(src)).bool().to(src.device)
            for step in range(self.voc_trg.max_len):  # decode up to max length
                input = out[:, :src.size(1)+step]
                key_mask = pad_mask(input, self.pad_idx)
                atn_mask = tri_mask(input)
                dec = self.gpt2(input.transpose(0, 1), key_mask=key_mask, atn_mask=atn_mask)
                # Sample the next token from the distribution at the last position.
                x = dec.softmax(dim=-1)[-1, :, :].multinomial(1).view(-1)
                # Finished sequences keep emitting the '_' filler token.
                x[is_end] = self.voc_trg.tk2ix['_']
                is_end |= x == self.voc_trg.tk2ix['EOS']
                out[:, src.size(1)+step] = x
                if is_end.all(): break
            # Return only the generated half (fragment prefix stripped).
            out = out[:, self.voc_trg.max_len:].detach()
        return out
[docs] def trainNet(self, loader, epoch, epochs):
"""
Train the model for one epoch
Parameters:
----------
loader: `torch.utils.data.DataLoader`
A dataloader object to iterate over the training data
epoch: `int`
Current epoch number
epochs: `int`
Total number of epochs
Returns:
-------
loss: `float`
The loss value for the current epoch
"""
net = nn.DataParallel(self, device_ids=self.gpus)
total_steps = len(loader)
current_step = 0
for src, trg in tqdm(loader, desc='Iterating over training batches', leave=False):
src, trg = src.to(self.device), trg.to(self.device)
self.optim.zero_grad()
loss = net(src, trg)
loss = -loss.mean()
loss.backward()
self.optim.step()
current_step += 1
self.monitor.saveProgress(current_step, epoch, total_steps, epochs, loss.item())
return loss.item()
    def validateNet(self, loader, evaluator=None, no_multifrag_smiles=True, n_samples=None):
        """
        Validate the model

        Parameters:
        ----------
        loader: `torch.utils.data.DataLoader`
            A dataloader object to iterate over the validation data
        evaluator: `Evaluator`
            An evaluator object to evaluate the generated SMILES
        no_multifrag_smiles: `bool`
            If `True`, only single-fragment SMILES are considered valid
        n_samples: `int`
            Accepted for interface compatibility but not used by this method.

        Returns:
        -------
        valid_metrics: `dict`
            A dictionary containing the validation metrics
        scores: `pandas.DataFrame`
            DataFrame containing Smiles, frags and the scores for each SMILES

        Notes:
        -----
        The validation metrics are:
            - valid_ratio: ratio of valid SMILES
            - accurate_ratio: ratio of SMILES that are valid and have the desired fragments
            - loss_valid: total negative log-likelihood over the validation set
        """
        valid_metrics = {}
        net = nn.DataParallel(self, device_ids=self.gpus)
        # Sample molecules from the validation fragments and score them.
        pbar = tqdm(loader, desc='Iterating over validation batches', leave=False)
        smiles, frags = self.sample(pbar)
        scores = self.evaluate(smiles, frags, evaluator=evaluator, no_multifrag_smiles=no_multifrag_smiles)
        scores['SMILES'] = smiles
        scores['Frags'] = frags
        valid_metrics['valid_ratio'] = scores.Valid.mean()
        valid_metrics['accurate_ratio'] = scores.Accurate.mean()
        # Sum of per-sequence negative mean log-likelihoods over all batches.
        # NOTE(review): src/trg are not moved to self.device here, unlike in
        # trainNet/sample — verify this is intentional for multi-GPU runs.
        with torch.no_grad():
            valid_metrics['loss_valid'] = sum( [ sum([-l.mean().item() for l in net(src, trg)]) for src, trg in loader ] )
        return valid_metrics, scores
[docs] def sample(self, loader):
"""
Sample SMILES from the model
Parameters:
----------
loader: `torch.utils.data.DataLoader`
A dataloader object to iterate over the input fragments
Returns:
-------
smiles: `list`
A list of sampled SMILES
frags: `list`
A list of input fragments
"""
net = nn.DataParallel(self, device_ids=self.gpus)
frags, smiles = [], []
with torch.no_grad():
for src, _ in loader:
trg = net(src.to(self.device))
smiles += [self.voc_trg.decode(s, is_tk=False) for s in trg]
frags += [self.voc_trg.decode(s, is_tk=False) for s in src]
return smiles, frags
[docs] def loaderFromFrags(self, frags, batch_size=32, n_proc=1):
"""
Encode the input fragments and create a dataloader object
Parameters:
----------
frags: `list`
A list of input fragments (in SMILES format)
batch_size: `int`
Batch size for the dataloader
n_proc: `int`
Number of processes to use for encoding the fragments
Returns:
-------
loader: `torch.utils.data.DataLoader`
A dataloader object to iterate over the input fragments
"""
# Encode the input fragments
encoder = FragmentCorpusEncoder(
fragmenter=dummyMolsFromFragments(),
encoder=SequenceFragmentEncoder(
self.voc_trg
),
n_proc=n_proc
)
out_data = SmilesFragDataSet(tempfile.NamedTemporaryFile().name)
encoder.apply(frags, encodingCollectors=[out_data])
loader = out_data.asDataLoader(batch_size, n_samples=batch_size)
return loader
[docs] def decodeLoaders(self, src, trg):
new_smiles = [self.voc_trg.decode(s, is_tk=False) for s in trg]
new_frags = [self.voc_trg.decode(s, is_tk=False) for s in src]
return new_frags, new_smiles