# Source code for drugex.training.generators.graph_transformer

import tempfile
import torch
import torch.nn as nn
from tqdm.auto import tqdm

from drugex import DEFAULT_GPUS, DEFAULT_DEVICE
from drugex.data.fragments import GraphFragmentEncoder, FragmentCorpusEncoder
from drugex.data.datasets import GraphFragDataSet
from drugex.molecules.converters.dummy_molecules import dummyMolsFromFragments
from drugex.training.generators.utils import PositionwiseFeedForward, SublayerConnection, PositionalEncoding, tri_mask
from drugex.training.generators.interfaces import FragGenerator
from drugex.utils import ScheduledOptim
from torch import optim


class Block(nn.Module):
    """
    One transformer layer: multi-head self-attention followed by a
    position-wise feed-forward network.

    Each of the two sub-layers is applied through its own
    `SublayerConnection` wrapper (imported from
    `drugex.training.generators.utils`).

    Parameters
    ----------
    d_model : int
        Width of the input/output features (embedding dimension of the attention).
    n_head : int
        Number of attention heads.
    d_inner : int
        Second argument passed to `PositionwiseFeedForward`
        (presumably its hidden width -- confirm in `utils`).
    """

    def __init__(self, d_model, n_head, d_inner):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.pffn = PositionwiseFeedForward(d_model, d_inner)
        # connector[0] wraps the attention sub-layer, connector[1] the feed-forward one.
        self.connector = nn.ModuleList([SublayerConnection(d_model) for _ in range(2)])

    def forward(self, x, key_mask=None, attn_mask=None):
        """
        Apply self-attention and then the feed-forward sub-layer.

        Parameters
        ----------
        x : torch.Tensor
            Input features. `nn.MultiheadAttention` is built without
            `batch_first`, so the sequence-first layout
            (seq_len, batch, d_model) is expected.
        key_mask : torch.Tensor, optional
            Forwarded to the attention as `key_padding_mask`.
        attn_mask : torch.Tensor, optional
            Forwarded to the attention as `attn_mask`.

        Returns
        -------
        torch.Tensor
            Output of the second `SublayerConnection`.
        """

        def self_attention(inp):
            # Query, key and value are all the same tensor; [0] drops the
            # attention weights and keeps only the attended output.
            return self.attn(
                inp, inp, inp, key_padding_mask=key_mask, attn_mask=attn_mask
            )[0]

        x = self.connector[0](x, self_attention)
        x = self.connector[1](x, self.pffn)
        return x
class AtomLayer(nn.Module):
    """
    Stack of `n_layer` transformer `Block` modules applied one after another.

    Parameters
    ----------
    d_model : int
        Feature width handed to every `Block`.
    n_head : int
        Number of attention heads handed to every `Block`.
    d_inner : int
        `d_inner` argument handed to every `Block`.
    n_layer : int
        Number of stacked blocks. With `n_layer=0` the layer is the identity.
    """

    def __init__(self, d_model=512, n_head=8, d_inner=1024, n_layer=12):
        super().__init__()
        self.n_layer = n_layer
        self.d_model = d_model
        self.n_head = n_head
        self.blocks = nn.ModuleList(
            [
                Block(self.d_model, self.n_head, d_inner=d_inner)
                for _ in range(self.n_layer)
            ]
        )

    def forward(self, x: torch.Tensor, key_mask=None, attn_mask=None):
        """
        Run `x` through every block in order.

        Parameters
        ----------
        x : torch.Tensor
            Input features, passed unchanged to the first block.
        key_mask : torch.Tensor, optional
            Passed as `key_mask` to every block.
        attn_mask : torch.Tensor, optional
            Passed as `attn_mask` to every block.

        Returns
        -------
        torch.Tensor
            Output of the last block (or `x` itself when there are no blocks).
        """
        for block in self.blocks:
            x = block(x, key_mask=key_mask, attn_mask=attn_mask)
        return x
class GraphTransformer(FragGenerator):
    """
    Graph Transformer for molecule generation from fragments
    """

    def __init__(self, voc_trg, d_emb=512, d_model=512, n_head=8, d_inner=1024, n_layer=12,
                 pad_idx=0, device=DEFAULT_DEVICE, use_gpus=DEFAULT_GPUS):
        """
        Build the embeddings, the attention stack, the GRU cell and the output
        projections, initialise the weights and create the optimizer.

        Parameters
        ----------
        voc_trg
            Vocabulary object. This class reads its `max_len`, `n_frags`,
            `size`, `tk2ix` and `masks` attributes and calls its `decode` method.
        d_emb : int
            Width of all embeddings. The embeddings are fed directly into layers
            built with `d_model`, so in practice it has to equal `d_model`.
        d_model : int
            Width of the attention stack and of the GRU cell.
        n_head : int
            Number of attention heads.
        d_inner : int
            `d_inner` argument of the attention stack.
        n_layer : int
            Number of attention blocks.
        pad_idx : int
            Padding index used for the word and atom embeddings.
        device
            Forwarded to `FragGenerator.__init__`.
        use_gpus
            Forwarded to `FragGenerator.__init__`.
        """
        super().__init__(device=device, use_gpus=use_gpus)
        self.mol_type = 'graph'
        self.voc_trg = voc_trg
        self.pad_idx = pad_idx
        self.d_model = d_model
        # Number of rows used by the "growing" part of the encoded graph.
        self.n_grows = voc_trg.max_len - voc_trg.n_frags - 1
        # Number of rows used by the "connecting" part (one more than voc_trg.n_frags).
        self.n_frags = voc_trg.n_frags + 1
        self.d_emb = d_emb
        # A "word" is an (atom token, bond type) pair encoded as atom * 4 + bond.
        self.emb_word = nn.Embedding(voc_trg.size * 4, self.d_emb, padding_idx=pad_idx)
        self.emb_atom = nn.Embedding(voc_trg.size, self.d_emb, padding_idx=pad_idx)
        self.emb_loci = nn.Embedding(self.n_grows, self.d_emb)
        # A "site" is a (curr, prev) index pair encoded as curr * n_grows + prev.
        self.emb_site = PositionalEncoding(self.d_emb, max_len=self.n_grows * self.n_grows)

        self.attn = AtomLayer(d_model=d_model, n_head=n_head, d_inner=d_inner, n_layer=n_layer)
        self.rnn = nn.GRUCell(self.d_model, self.d_model)

        self.prj_atom = nn.Linear(d_emb, self.voc_trg.size)
        self.prj_bond = nn.Linear(d_model, 4)
        self.prj_loci = nn.Linear(d_model, self.n_grows)
        self.init_states()

        self.optim = ScheduledOptim(
            optim.Adam(self.parameters(), betas=(0.9, 0.98), eps=1e-9), 0.1, d_model)
        self.model_name = 'GraphTransformer'

    def init_states(self):
        """
        Initialize model parameters

        Notes:
        -----
        Xavier-uniform initialization is applied to every parameter with more
        than one dimension. That includes the embedding weight matrices (and
        therefore also their `padding_idx` rows). One-dimensional parameters
        such as biases keep their default initialization. Afterwards the model
        is moved to its device via `attachToGPUs`.
        """
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
        self.attachToGPUs(self.gpus)

    def attachToGPUs(self, gpus):
        """
        Attach model to GPUs

        Parameters:
        ----------
        gpus: `tuple`
            A tuple of GPU ids to use; stored on `self.gpus` and later passed
            to `nn.DataParallel` as `device_ids`.

        Returns:
        -------
        None
        """
        self.gpus = gpus
        self.to(self.device)

    def _encodeContext(self, data):
        """
        Embed the encoded graph rows in `data` and run them through the attention stack.

        Parameters
        ----------
        data : torch.Tensor
            Integer tensor indexed as [batch, position, column]; columns 0-3
            are read (named atom, curr, prev, bond elsewhere in this class).

        Returns
        -------
        torch.Tensor
            Output of the attention stack in sequence-first layout
            (position, batch, feature).
        """
        triu = tri_mask(data[:, :, 0])
        emb = self.emb_word(data[:, :, 3] + data[:, :, 0] * 4)
        emb += self.emb_site(data[:, :, 1] * self.n_grows + data[:, :, 2])
        return self.attn(emb.transpose(0, 1), attn_mask=triu)

    def _forwardTrain(self, src):
        """
        Teacher-forced pass: log-likelihood of every next row given the previous rows.

        Returns the list `[out_atom, out_curr, out_prev, out_bond]`; every entry
        has shape (batch, seq_len - 1, 1) and holds the log-probability assigned
        to the target value of the respective column.
        """
        # Inputs are all rows but the last, targets are the rows shifted by one.
        src, trg = src[:, :-1, :], src[:, 1:, :]
        batch, sqlen, _ = src.shape

        # dec - atom environment
        dec = self._encodeContext(src)
        dec = dec.transpose(0, 1).reshape(batch * sqlen, -1)

        # 1) atom token, predicted from the attention output alone
        out_atom = self.prj_atom(dec).log_softmax(dim=-1).view(batch, sqlen, -1)
        out_atom = out_atom.gather(2, trg[:, :, 0].unsqueeze(2))

        # 2) bond type, after feeding the true atom into the GRU cell
        atom = self.emb_atom(trg[:, :, 0]).reshape(batch * sqlen, -1)
        dec = self.rnn(atom, dec)
        out_bond = self.prj_bond(dec).log_softmax(dim=-1).view(batch, sqlen, -1)
        out_bond = out_bond.gather(2, trg[:, :, 3].unsqueeze(2))

        # 3) previous site, after feeding the true (atom, bond) word
        word = self.emb_word(trg[:, :, 3] + trg[:, :, 0] * 4)
        word = word.reshape(batch * sqlen, -1)
        dec = self.rnn(word, dec)
        out_prev = self.prj_loci(dec).log_softmax(dim=-1).view(batch, sqlen, -1)
        out_prev = out_prev.gather(2, trg[:, :, 2].unsqueeze(2))

        # 4) current site, after feeding the true previous site
        curr = self.emb_loci(trg[:, :, 2]).reshape(batch * sqlen, -1)
        dec = self.rnn(curr, dec)
        out_curr = self.prj_loci(dec).log_softmax(dim=-1).view(batch, sqlen, -1)
        out_curr = out_curr.gather(2, trg[:, :, 1].unsqueeze(2))

        return [out_atom, out_curr, out_prev, out_bond]

    def _forwardSample(self, src):
        """
        Sample the missing rows of the encoded graphs in `src`.

        `src` is modified in place and also returned. Statement order matters
        here: every sampled value is written into `src` and read back before the
        bookkeeping tensors are updated.
        """
        n = len(src)
        dev = src.device
        # Per-sample flag: nothing more to generate in the current phase.
        is_end = torch.zeros(n, dtype=torch.bool, device=dev)
        # exists[b, i, j]: bond value written between positions i and j (0 = none).
        exists = torch.zeros(n, self.n_grows, self.n_grows, dtype=torch.long, device=dev)
        # vals_max[b, i]: value taken from voc_trg.masks for the token at slot i
        # (presumably the atom's maximum valence -- confirm in the vocabulary).
        vals_max = torch.zeros(n, self.n_grows, dtype=torch.long, device=dev)
        # frg_ids[b, i]: fragment id that position i currently belongs to.
        frg_ids = torch.zeros(n, self.n_grows, dtype=torch.long, device=dev)
        order = torch.arange(n, dtype=torch.long, device=dev)
        curr = torch.full((n,), -1, dtype=torch.long, device=dev)
        blank = torch.full((n,), self.voc_trg.tk2ix['*'], dtype=torch.long, device=dev)
        single = torch.ones(n, dtype=torch.long, device=dev)
        voc_mask = self.voc_trg.masks.to(dev)

        # The part of growing
        for step in range(1, self.n_grows):
            if is_end.all():
                src[:, step, :] = 0
                continue
            dec = self._encodeContext(src[:, :step, :])
            dec = dec[-1, :, :]

            # Only rows whose column 4 is 0 are sampled; the others keep the
            # values already present in `src`.
            grow = src[:, step, 4] == 0
            mask = voc_mask.repeat(n, 1) < 0
            if step <= 2:
                mask[:, -1] = True
            else:
                # `vals_rom` was assigned during the previous iteration.
                judge = (vals_rom == 0) | (exists[order, curr, :] != 0)
                judge[order, curr] = True
                judge = judge.all(dim=1) | (vals_rom[order, curr] == 0)
                mask[judge, -1] = True
            mask[:, 1] = True
            mask[is_end, 1:] = True
            out_atom = self.prj_atom(dec).softmax(dim=-1)
            atom = out_atom.masked_fill(mask, 0).multinomial(1).view(-1)
            src[grow, step, 0] = atom[grow]
            atom = src[:, step, 0]
            is_end |= (atom == 0) & grow
            num = (vals_max > 0).sum(dim=1)
            vals_max[order, num] = voc_mask[atom]
            # Remaining capacity per position: maximum minus bonds already placed.
            vals_rom = vals_max - exists.sum(dim=1)

            # Any token other than '*' advances the current-position counter.
            bud = atom != self.voc_trg.tk2ix['*']
            curr += bud
            curr[is_end] = 0
            src[:, step, 1] = curr
            exist = exists[order, curr, :] != 0

            # Bond-type mask: bond i is forbidden when no position can accept it.
            mask = torch.zeros(n, 4, dtype=torch.bool, device=dev)
            for i in range(1, 4):
                judge = (vals_rom < i) | exist
                judge[order, curr] = True
                mask[:, i] = judge.all(dim=1) | (vals_rom[order, curr] < i)
            mask[:, 0] = step != 1
            mask[is_end, 0] = False
            mask[is_end, 1:] = True

            atom_emb = self.emb_atom(atom)
            dec = self.rnn(atom_emb, dec)
            out_bond = self.prj_bond(dec).softmax(dim=-1)
            bond = out_bond.masked_fill(mask, 0).multinomial(1).view(-1)
            src[grow, step, 3] = bond[grow]
            bond = src[:, step, 3]

            # Previous-position mask: empty slots, already bonded positions and
            # positions without enough remaining capacity are forbidden.
            mask = (vals_max == 0) | exist | (vals_rom < bond.unsqueeze(-1))
            mask[order, curr] = True
            if step <= 2:
                mask[:, 0] = False
            mask[is_end, 0] = False
            mask[is_end, 1:] = True
            word_emb = self.emb_word(atom * 4 + bond)
            dec = self.rnn(word_emb, dec)
            prev_out = self.prj_loci(dec).softmax(dim=-1)
            prev = prev_out.masked_fill(mask, 0).multinomial(1).view(-1)
            src[grow, step, 2] = prev[grow]
            prev = src[:, step, 2]

            # Fragment bookkeeping: given rows take their fragment id from the
            # last column of `src`; new atoms inherit the id of the position
            # they attach to; finally the two fragments joined by the bond are
            # merged under one id.
            for i in range(n):
                if not grow[i]:
                    frg_ids[i, curr[i]] = src[i, step, -1]
                elif bud[i]:
                    frg_ids[i, curr[i]] = frg_ids[i, prev[i]]
                obj = frg_ids[i, curr[i]].clone()
                ix = frg_ids[i, :] == frg_ids[i, prev[i]]
                frg_ids[i, ix] = obj
            exists[order, curr, prev] = bond
            exists[order, prev, curr] = bond
            vals_rom = vals_max - exists.sum(dim=1)
            is_end |= (vals_rom == 0).all(dim=1)

        # The part of connecting
        src[:, -self.n_frags, 1:] = 0
        src[:, -self.n_frags, 0] = self.voc_trg.tk2ix['GO']
        is_end = torch.zeros(n, dtype=torch.bool, device=dev)
        for step in range(self.n_grows + 1, self.voc_trg.max_len):
            dec = self._encodeContext(src[:, :step, :])

            vals_rom = vals_max - exists.sum(dim=1)
            # Remaining capacity summed per fragment id.
            # NOTE(review): fragment ids 1..7 are hard-coded here -- confirm
            # this matches the maximum number of fragments the vocabulary allows.
            frgs_rom = torch.zeros(n, 8, dtype=torch.long, device=dev)
            for i in range(1, 8):
                ix = frg_ids != i
                rom = vals_rom.clone()
                rom[ix] = 0
                frgs_rom[:, i] = rom.sum(dim=1)
            is_end |= (vals_rom == 0).all(dim=1)
            # Connecting needs at least two fragments with capacity left.
            is_end |= (frgs_rom != 0).sum(dim=1) <= 1
            mask = (vals_rom < 1) | (vals_max == 0)
            mask[is_end, 0] = False
            atom_emb = self.emb_word(blank * 4 + single)
            dec = self.rnn(atom_emb, dec[-1, :, :])
            out_prev = self.prj_loci(dec).softmax(dim=-1)
            prev = out_prev.masked_fill(mask, 0).multinomial(1).view(-1)

            # The second end point must lie in a different fragment than `prev`.
            same = frg_ids == frg_ids[order, prev].view(-1, 1)
            exist = exists[order, prev] != 0
            mask = (vals_rom < 1) | exist | (vals_max == 0) | same
            mask[is_end, 0] = False
            prev_emb = self.emb_loci(prev)
            dec = self.rnn(prev_emb, dec)
            out_curr = self.prj_loci(dec).softmax(dim=-1)
            curr = out_curr.masked_fill(mask, 0).multinomial(1).view(-1)

            src[:, step, 3] = single
            src[:, step, 2] = prev
            src[:, step, 1] = curr
            src[:, step, 0] = blank
            src[is_end, step, :] = 0

            for i in range(n):
                obj = frg_ids[i, curr[i]].clone()
                ix = frg_ids[i, :] == frg_ids[i, prev[i]]
                frg_ids[i, ix] = obj
            # Read the values back from `src` so finished samples (zeroed rows)
            # are recorded consistently.
            exists[order, src[:, step, 1], src[:, step, 2]] = src[:, step, 3]
            exists[order, src[:, step, 2], src[:, step, 1]] = src[:, step, 3]
        return src

    def forward(self, src, is_train=False):
        """
        Forward pass

        Parameters:
        ----------
        src: `torch.Tensor`
            Input tensor of shape [batch_size, 80, 5] (transpose of the encoded graphs as drawn in the paper)
        is_train: `bool`
            Whether the model is in training mode

        Returns:
        -------
        out: `list` of `torch.Tensor` or `torch.Tensor`
            If `is_train` is `True`: the list `[out_atom, out_curr, out_prev, out_bond]`
            of log-likelihoods of the target atom token, current site, previous
            site and bond type; each tensor has shape [batch_size, seq_len - 1, 1].
            Otherwise: `src` itself, with the sampled rows written into it in place.
        """
        if is_train:
            # Return loss
            return self._forwardTrain(src)
        # Return encoded molecules
        return self._forwardSample(src)

    def trainNet(self, loader, epoch, epochs):
        """
        Train the network for one epoch

        Parameters
        ----------
        loader : torch.utils.data.DataLoader
            The data loader for the training set
        epoch : int
            The current epoch
        epochs : int
            The total number of epochs

        Returns
        -------
        loss : float
            The training loss of the last batch of the epoch

        Raises
        ------
        ValueError
            If `loader` yields no batches.
        """
        net = nn.DataParallel(self, device_ids=self.gpus)
        total_steps = len(loader)
        last_loss = None
        batches = tqdm(loader, desc='Iterating over training batches', leave=False)
        for current_step, src in enumerate(batches, start=1):
            src = src.to(self.device)
            self.optim.zero_grad()
            log_probs = net(src, is_train=True)
            # Negative mean log-likelihood, summed over the four outputs.
            loss = sum(-lp.mean() for lp in log_probs)
            loss.backward()
            self.optim.step()
            last_loss = loss.item()
            self.monitor.saveProgress(current_step, epoch, total_steps, epochs, loss=last_loss)
        if last_loss is None:
            raise ValueError('Cannot train on an empty data loader.')
        return last_loss

    def validateNet(self, loader, evaluator=None, no_multifrag_smiles=True, n_samples=None):
        """
        Validate the network

        Parameters
        ----------
        loader : torch.utils.data.DataLoader
            A dataloader object to iterate over the validation data
        evaluator : Evaluator
            An evaluator object to evaluate the generated SMILES
        no_multifrag_smiles : bool
            If `True`, only single-fragment SMILES are considered valid
        n_samples : int, optional
            Accepted for interface compatibility; not used by this implementation.

        Returns
        -------
        valid_metrics : dict
            Dictionary containing the validation metrics
        scores : pandas.DataFrame
            DataFrame containing Smiles, frags and the scores for each SMILES

        Notes
        -----
        The validation metrics are:
            - valid_ratio: the ratio of valid SMILES
            - accurate_ratio: the ratio of SMILES that are valid and have the desired fragments
            - loss_valid: the validation loss (summed over all batches)
        """
        valid_metrics = {}
        net = nn.DataParallel(self, device_ids=self.gpus)

        pbar = tqdm(loader, desc='Iterating over validation batches', leave=False)
        smiles, frags = self.sample(pbar)
        scores = self.evaluate(smiles, frags, evaluator=evaluator, no_multifrag_smiles=no_multifrag_smiles)
        scores['SMILES'] = smiles
        scores['Frags'] = frags
        valid_metrics['valid_ratio'] = scores.Valid.mean()
        valid_metrics['accurate_ratio'] = scores.Accurate.mean()

        with torch.no_grad():
            # Batches are moved to the model device, as in `trainNet` and `sample`.
            valid_metrics['loss_valid'] = sum(
                sum(-lp.float().mean().item() for lp in net(src.to(self.device), is_train=True))
                for src in loader
            )

        return valid_metrics, scores

    def sample(self, loader):
        """
        Sample SMILES from the network

        Parameters
        ----------
        loader : torch.utils.data.DataLoader
            The data loader for the input fragments

        Returns
        -------
        smiles : list
            List of SMILES
        frags : list
            List of fragments
        """
        net = nn.DataParallel(self, device_ids=self.gpus)
        frags, smiles = [], []
        with torch.no_grad():
            for src in loader:
                trg = net(src.to(self.device))
                batch_frags, batch_smiles = self.voc_trg.decode(trg)
                frags += batch_frags
                smiles += batch_smiles
        return smiles, frags

    def loaderFromFrags(self, frags, batch_size=32, n_proc=1):
        """
        Encode the input fragments and create a dataloader object

        Parameters:
        ----------
        frags: `list`
            A list of input fragments (in SMILES format)
        batch_size: `int`
            Batch size for the dataloader
        n_proc: `int`
            Number of processes to use for encoding the fragments

        Returns:
        -------
        loader: `torch.utils.data.DataLoader`
            A dataloader object to iterate over the input fragments
        """
        # Encode the input fragments
        encoder = FragmentCorpusEncoder(
            fragmenter=dummyMolsFromFragments(),
            encoder=GraphFragmentEncoder(
                self.voc_trg
            ),
            n_proc=n_proc
        )
        # NOTE(review): only the *name* of the temporary file is kept; the
        # NamedTemporaryFile object itself is discarded right away, so the
        # file is presumably removed again before GraphFragDataSet uses the
        # path. Confirm this is the intended way to obtain a scratch path.
        out_data = GraphFragDataSet(tempfile.NamedTemporaryFile().name)
        encoder.apply(frags, encodingCollectors=[out_data])
        loader = out_data.asDataLoader(batch_size, n_samples=batch_size)
        return loader

    def decodeLoaders(self, src, trg):
        """
        Decode generated graphs with the vocabulary.

        Parameters
        ----------
        src
            Unused.
        trg
            Encoded graphs passed to `self.voc_trg.decode`.

        Returns
        -------
        Whatever `self.voc_trg.decode(trg)` returns (unpacked as
        `(frags, smiles)` in `sample`).
        """
        return self.voc_trg.decode(trg)

    def iterLoader(self, loader):
        """
        Return `loader` unchanged.
        """
        return loader