"""
vocabulary
Created by: Martin Sicho
On: 26.04.22, 13:16
"""
import re
import numpy as np
import pandas as pd
import torch
from rdkit import Chem
from drugex.data.corpus.interfaces import SequenceVocabulary, Vocabulary
from drugex.logs import logger
from drugex.molecules.converters.standardizers import CleanSMILES
class VocSmiles(SequenceVocabulary):
    """The class for handling encoding/decoding from SMILES to an array of
    indices for the main SMILES-based models (`GPT2Model` and `RNN`)."""

    # Default token set: single-character atom/bond/ring tokens plus bracket
    # atoms; 'L' and 'R' are single-char placeholders for 'Cl' and 'Br'.
    defaultWords = ('#','%','(',')','-','0','1','2','3','4','5','6','7','8','9','=','B','C','F','I','L','N','O','P','R','S','[Ag-3]','[As+]','[As]','[B-]','[BH-]','[BH2-]','[BH3-]','[B]','[C+]','[C-]','[CH-]','[CH2]','[CH]','[I+]','[IH2]','[N+]','[N-]','[NH+]','[NH-]','[NH2+]','[N]','[O+]','[O-]','[OH+]','[O]','[P+]','[PH]','[S+]','[S-]','[SH+]','[SH2]','[SH]','[Se+]','[SeH]','[Se]','[SiH2]','[SiH]','[Si]','[Te]','[b-]','[c+]','[c-]','[cH-]','[n+]','[n-]','[nH+]','[nH]','[o+]','[s+]','[se+]','[se]','[te+]',"[te]",'b','c','n','o','p','s'
    )

    def __init__(self, encode_frags, words=defaultWords, max_len=100, min_len=10):
        """
        Args:
            encode_frags (bool): whether fragments are encoded alongside molecules
            words: the vocabulary tokens
            max_len (int): maximum allowed sequence length
            min_len (int): minimum allowed sequence length
        """
        super().__init__(encode_frags, words, min_len=min_len, max_len=max_len)

    def encode(self, tokens, frags=None):
        """
        Takes a batch of token sequences (tokens like '[NH]') and encodes it
        into an array of vocabulary indices.

        Args:
            tokens: a list of SMILES sequences, each represented as a series of tokens
            frags: unused; kept for interface compatibility

        Returns:
            output (torch.LongTensor): a (len(tokens), max_len) long tensor
                containing the indices of the given tokens, zero-padded on the right.
        """
        output = torch.zeros(len(tokens), self.max_len).long()
        for i, seq in enumerate(tokens):
            for j, char in enumerate(seq):
                output[i, j] = self.tk2ix[char]
        return output

    def decode(self, tensor, is_tk=True, is_smiles=True):
        """Takes an array of indices (or of tokens) and returns the corresponding SMILES.

        Args:
            tensor (torch.LongTensor): a long tensor containing the indices of the tokens
            is_tk (bool): if True, `tensor` already holds token strings instead of indices
            is_smiles (bool): if True, expand the 'L'/'R' placeholders into 'Cl'/'Br'

        Returns:
            seqs (str): a decoded SMILES sequence.
        """
        tokens = []
        for token in tensor:
            if not is_tk:
                token = self.ix2tk[int(token)]
            if token == 'EOS':
                break  # everything after the end-of-sequence token is padding
            if token in self.control:
                continue
            tokens.append(token)
        seqs = "".join(tokens)
        if is_smiles:
            seqs = self.parseDecoded(seqs)
        else:
            seqs = seqs.replace('|', '')
        return seqs

    def parseDecoded(self, smiles):
        """Expand the single-character placeholders back into 'Cl' and 'Br'."""
        return smiles.replace('L', 'Cl').replace('R', 'Br')

    def splitSequence(self, smile):
        """Takes a SMILES string and returns a list of characters/tokens.

        Args:
            smile (str): a SMILES sequence.

        Returns:
            tokens (List): tokens extracted from the SMILES sequence, terminated by 'EOS'.
        """
        # bracket atoms (e.g. '[NH+]') are kept as single tokens
        regex = r'(\[[^\[\]]{1,6}\])'
        smile = smile.replace('Cl', 'L').replace('Br', 'R')
        tokens = []
        for word in re.split(regex, smile):
            if word == '' or word is None:
                continue
            if word.startswith('['):
                tokens.append(word)
            else:
                # outside brackets every character is its own token
                tokens.extend(word)
        return tokens + ['EOS']

    @staticmethod
    def fromFile(path, encode_frags, min_len=10, max_len=100):
        """Takes a file containing newline-separated tokens to initialize the vocabulary."""
        with open(path, 'r') as f:
            words = f.read().split()
        return VocSmiles(encode_frags, words, max_len=max_len, min_len=min_len)

    def calc_voc_fp(self, smiles, prefix=None):
        """Encode a list of SMILES into a (len(smiles), max_len) index matrix,
        skipping (leaving zeroed) entries that are too long, carbon-free, or
        contain the excluded '[Na]'/'[Zn]' tokens.

        Args:
            smiles: iterable of SMILES strings
            prefix: optional token prepended to every token sequence

        Returns:
            fps (np.ndarray): int64 matrix of encoded sequences.
        """
        # np.long was removed from NumPy (>=1.24); int64 keeps the original dtype
        fps = np.zeros((len(smiles), self.max_len), dtype=np.int64)
        for i, smile in enumerate(smiles):
            smile = CleanSMILES()(smile)
            token = self.splitSequence(smile)
            if prefix is not None:
                token = [prefix] + token
            if len(token) > self.max_len:
                continue
            if {'C', 'c'}.isdisjoint(token):
                continue
            if not {'[Na]', '[Zn]'}.isdisjoint(token):
                continue
            # encode() expects a batch of sequences, so wrap the single token
            # list; the original encode(token) produced a (len(token), max_len)
            # tensor that cannot be assigned into a single row
            fps[i, :] = self.encode([token])
        return fps
class VocNonGPT(VocSmiles):
    """
    Modified version of `VocSmiles` adjusted for the legacy sequence models
    (`Seq2Seq` and `EncDec`).
    """

    def __init__(self, words, src_len=1000, trg_len=100, max_len=100, min_len=10):
        super(VocNonGPT, self).__init__(False, words, max_len=max_len, min_len=min_len)
        self.src_len = src_len  # maximum length of encoded source sequences
        self.trg_len = trg_len  # maximum length of encoded target (SMILES) sequences

    def encode(self, input, is_smiles=True):
        """Takes a batch of token sequences (tokens like '[NH]') and encodes it
        into an array of vocabulary indices; source tokens are looked up with a
        '|' prefix."""
        seq_len = self.trg_len if is_smiles else self.src_len
        output = torch.zeros(len(input), seq_len).long()
        for row, seq in enumerate(input):
            for col, tok in enumerate(seq):
                key = tok if is_smiles else '|' + tok
                output[row, col] = self.tk2ix[key]
        return output

    def decode(self, matrix, is_smiles=True, is_tk=False):
        """
        Takes an array of indices and returns the corresponding SMILES.
        """
        decoded = super(VocNonGPT, self).decode(matrix, is_tk)
        joined = "".join(decoded)
        if not is_smiles:
            return joined.replace('|', '')
        return self.parseDecoded(joined)

    @staticmethod
    def fromFile(path, src_len=1000, trg_len=100, max_len=100, min_len=10):
        """Takes a file containing newline-separated tokens to initialize the vocabulary."""
        with open(path, 'r') as handle:
            words = handle.read().split()
        return VocNonGPT(words, src_len=src_len, trg_len=trg_len,
                         max_len=max_len, min_len=min_len)
class VocGraph(Vocabulary):
    """Vocabulary for the graph-based models: each word combines a valence,
    an element symbol and an optional formal charge (e.g. '4C', '3N', '3C+')."""

    defaultWords=('2O','3O+','1O-','4C','3C+','3C-','3N','4N+','2N-','1Cl','2S','6S','4S','3S+','5S+','1S-','1F','1I','5I','2I+','1Br','5P','3P','4P+','2Se','6Se','4Se','3Se+','4Si','3B','4B-','5As','3As','4As+','2Te','4Te','3Te+',)

    def __init__(self, words=defaultWords, max_len=80, n_frags=4):
        """
        Args:
            words: vocabulary words like '4C' (valence-4 carbon) or '3C+'
            max_len (int): maximum number of rows in the graph encoding
            n_frags (int): number of fragment-connection slots
        """
        super().__init__(words)
        self.control = ('EOS', 'GO')
        words = [x for x in words if x not in self.control]
        # de-duplicate while preserving the original order
        words_unique = []
        for word in words:
            if word not in words_unique:
                words_unique.append(word)
        words = words_unique
        self.n_frags = n_frags
        self.max_len = max_len
        self.tk2ix = {'EOS': 0, 'GO': 1}
        self.ix2nr = {0: 0, 1: 0}  # vocabulary index -> atomic number
        self.ix2ch = {0: 0, 1: 0}  # vocabulary index -> formal charge
        self.E = {0: '', 1: '+', -1: '-'}  # formal charge -> token suffix
        # init words
        self.words = []
        self.wordsParsed = [self.parseWord(word) for word in words]
        self.words = list(self.control) + list(words)
        if '*' not in words:
            # '*' marks an already-emitted atom in the adjacency rows
            self.words.append('*')
            self.wordsParsed.append(('*', 0, 0, 0, '*'))
        self.size = len(self.words)
        self.masks = torch.zeros(len(self.wordsParsed) + len(self.control)).long()
        for i, item in enumerate(self.wordsParsed):
            self.masks[i + len(self.control)] = item[1]  # valence drives the mask
            ix = i + len(self.control)
            self.tk2ix[item[4]] = ix
            self.ix2nr[ix] = item[3]
            self.ix2ch[ix] = item[2]
        assert len(set(self.words)) == len(self.words)

    @staticmethod
    def parseWord(word):
        """Parse a word like '3N+' into
        (element+charge, valence, charge, atomic number, word)."""
        if word == '*':
            return '*', 0, 0, 0, '*'
        valence = re.search(r'[0-9]', word).group(0)
        charge = re.search(r'[+-]', word)
        charge_num = 0
        if charge:
            charge = charge.group(0)
            charge_num = 1 if charge == '+' else -1
        else:
            charge = ''
        element = re.search(r'[a-zA-Z]+', word).group(0)
        return element + charge, int(valence), charge_num, Chem.Atom(element).GetAtomicNum(), word

    @staticmethod
    def fromFile(path, word_col='Word', max_len=80, n_frags=4):
        """Initialize the vocabulary from a tab-separated table of words.

        Fix: forward `max_len` and `n_frags` to `fromDataFrame` instead of the
        previously hard-coded 80 and 4, which silently ignored the arguments.
        """
        df = pd.read_table(path)
        return VocGraph.fromDataFrame(df, word_col, max_len=max_len, n_frags=n_frags)

    @staticmethod
    def fromDataFrame(df, word_col='Word', max_len=80, n_frags=4):
        """Initialize the vocabulary from a data frame column of words."""
        return VocGraph(df[word_col].tolist(), max_len=max_len, n_frags=n_frags)

    def toFile(self, path):
        """Persist the parsed vocabulary as a tab-separated table."""
        self.toDataFrame().to_csv(path, index=False, sep='\t')

    def toDataFrame(self):
        """Return the parsed vocabulary as a data frame."""
        return pd.DataFrame(self.wordsParsed, columns=['Ele', 'Val', 'Ch', 'Nr', 'Word'])

    def get_atom_tk(self, atom):
        """Return the vocabulary index of an atom, built from its total
        valence, symbol and formal charge (e.g. '4C')."""
        sb = atom.GetSymbol() + self.E[atom.GetFormalCharge()]
        val = atom.GetExplicitValence() + atom.GetImplicitValence()
        tk = str(val) + sb
        return self.tk2ix[tk]

    def encode(self, smiles, subs=None):
        """Encode molecules together with their fragments into graph matrices.

        Args:
            smiles: SMILES of the full molecules
            subs: SMILES of the fragment(s) contained in each molecule (required)

        Returns:
            np.ndarray of shape (n_mols, max_len, 5); each row is
            (token, current atom, previous atom, bond type, fragment id).

        Raises:
            RuntimeError: if `subs` is not given.
        """
        if not subs:
            raise RuntimeError(f'Fragments must be specified, got {subs} instead')
        # np.compat.long was removed in NumPy 2.0; int64 keeps the same dtype
        output = np.zeros([len(smiles), self.max_len - self.n_frags - 1, 5], dtype=np.int64)
        connect = np.zeros([len(smiles), self.n_frags + 1, 5], dtype=np.int64)
        for i, s in enumerate(smiles):
            mol = Chem.MolFromSmiles(s)
            sub = Chem.MolFromSmiles(subs[i])
            # Chem.Kekulize(sub)
            sub_idxs = mol.GetSubstructMatches(sub)
            for sub_idx in sub_idxs:
                # molecule bonds corresponding to the bonds inside the fragment
                sub_bond = [mol.GetBondBetweenAtoms(
                    sub_idx[b.GetBeginAtomIdx()],
                    sub_idx[b.GetEndAtomIdx()]).GetIdx() for b in sub.GetBonds()]
                sub_atom = [mol.GetAtomWithIdx(ix) for ix in sub_idx]
                # bonds that attach the fragment to the rest of the molecule
                split_bond = {b.GetIdx() for a in sub_atom for b in a.GetBonds() if b.GetIdx() not in sub_bond}
                single = sum([int(mol.GetBondWithIdx(b).GetBondType()) for b in split_bond])
                # accept the first match whose attachment bonds are all single
                # (BondType.SINGLE == 1, so the sum equals the count)
                if single == len(split_bond):
                    break
            frags = Chem.FragmentOnBonds(mol, list(split_bond))
            Chem.MolToSmiles(frags)  # side effect: sets _smilesAtomOutputOrder
            # NOTE(review): eval of an RDKit-generated property string (a list
            # literal); ast.literal_eval would be a stricter alternative
            rank = eval(frags.GetProp('_smilesAtomOutputOrder'))
            # fragment atoms first, then the rest in SMILES output order
            mol_idx = list(sub_idx) + [idx for idx in rank if idx not in sub_idx and idx < mol.GetNumAtoms()]
            # 1-based fragment id for every fragment atom
            frg_idx = [f_ix + 1 for f_ix, f in enumerate(Chem.GetMolFrags(sub)) for _ in f]
            Chem.Kekulize(mol)
            # m: fragment rows, n: growing rows, c: connection rows
            m, n, c = [(self.tk2ix['GO'], 0, 0, 0, 1)], [], [(self.tk2ix['GO'], 0, 0, 0, 0)]
            mol2sub = {ix: pos for pos, ix in enumerate(mol_idx)}
            for j, idx in enumerate(mol_idx):
                atom = mol.GetAtomWithIdx(idx)
                bonds = sorted(atom.GetBonds(), key=lambda x: mol2sub[x.GetOtherAtomIdx(idx)])
                # only bonds to atoms that were already emitted
                bonds = [b for b in bonds if j > mol2sub[b.GetOtherAtomIdx(idx)]]
                n_split = sum([1 if b.GetIdx() in split_bond else 0 for b in bonds])
                tk = self.get_atom_tk(atom)
                for k, bond in enumerate(bonds):
                    ix2 = mol2sub[bond.GetOtherAtomIdx(idx)]
                    is_split = bond.GetIdx() in split_bond
                    if idx in sub_idx:
                        is_connect = is_split
                    elif len(bonds) == 1:
                        is_connect = False
                    elif n_split == len(bonds):
                        is_connect = is_split and k != 0
                    else:
                        is_connect = False
                    if bond.GetIdx() in sub_bond:
                        bucket, f = m, frg_idx[j]
                    elif is_connect:
                        bucket, f = c, 0
                    else:
                        bucket, f = n, 0
                    # the atom token is emitted with its first non-connection
                    # bond; afterwards '*' marks it as already emitted
                    if bond.GetIdx() in sub_bond or not is_connect:
                        tk2 = tk
                        tk = self.tk2ix['*']
                    else:
                        tk2 = self.tk2ix['*']
                    bucket.append((tk2, j, ix2, int(bond.GetBondType()), f))
                if tk != self.tk2ix['*']:
                    # atom token not yet emitted (no suitable bond): self-row
                    bucket, f = (m, frg_idx[j]) if idx in sub_idx else (n, f)
                    bucket.append((tk, j, j, 0, f))
            output[i, :len(m + n), :] = m + n
            if len(c) > 0:
                connect[i, :len(c)] = c
        return np.concatenate([output, connect], axis=1)

    def decode(self, matrix):
        """Decode graph matrices back into fragment and molecule SMILES.

        Args:
            matrix: iterable of 2D arrays whose rows are
                (token, current atom, previous atom, bond type, fragment id).

        Returns:
            (frags, smiles): lists with the fragment SMILES and the full
            molecule SMILES for each entry; decoding is best-effort, errors
            are logged and whatever was built so far is still emitted.
        """
        frags, smiles = [], []
        for adj in matrix:
            emol = Chem.RWMol()  # the full molecule
            esub = Chem.RWMol()  # only the atoms/bonds flagged as fragment
            try:
                for atom, curr, prev, bond, frag in adj:
                    atom, curr, prev, bond, frag = int(atom), int(curr), int(prev), int(bond), int(frag)
                    if atom == self.tk2ix['EOS']:
                        continue
                    if atom == self.tk2ix['GO']:
                        continue
                    if atom != self.tk2ix['*']:
                        a = Chem.Atom(self.ix2nr[atom])
                        a.SetFormalCharge(self.ix2ch[atom])
                        emol.AddAtom(a)
                        if frag != 0:
                            esub.AddAtom(a)
                    if bond != 0:
                        b = Chem.BondType(bond)
                        emol.AddBond(curr, prev, b)
                        if frag != 0:
                            esub.AddBond(curr, prev, b)
                Chem.SanitizeMol(emol)
                Chem.SanitizeMol(esub)
            except Exception as e:
                logger.error(f'Error while decoding: {adj}')
                logger.error(e)
            frags.append(Chem.MolToSmiles(esub))
            smiles.append(Chem.MolToSmiles(emol))
        return frags, smiles