"""
interfaces
Created by: Martin Sicho
On: 26.04.22, 13:12
"""
from abc import ABC, abstractmethod
from drugex.logs import logger
from drugex.molecules.interfaces import MolSupplier
[docs]class Vocabulary(ABC):
"""
Definition of the vocabulary interface. All vocabularies contain "words" that are used for encoding and decoding molecules.
"""
def __init__(self, words):
self.words = words
def __add__(self, other):
return type(self)(other.words + self.words)
[docs] @abstractmethod
def encode(self, tokens, frags=None):
pass
[docs] @abstractmethod
def decode(self, representation):
pass
[docs] @staticmethod
@abstractmethod
def fromFile(path):
pass
[docs] @abstractmethod
def toFile(self, path):
pass
[docs]class SequenceVocabulary(Vocabulary, ABC):
"""
Generic vocabulary for sequence-based models.
"""
def __init__(self, encode_frags, words, max_len=100, min_len=10):
"""
Args:
encode_frags: boolean indicating if used to also encode fragments
words: iterable of words in this vocabulary
max_len: the maximum number of tokens contained in one SMILES
"""
super().__init__(words)
if encode_frags: # Allow fragments for fragment-based models
self.control = ('_', 'GO', 'EOS') # '_' used during model fitting
self.special = list(self.control) + ['.']
else:
self.control = ('GO', 'EOS')
self.special = list(self.control)
self.wordSet = set()
if words:
self.wordSet = set(x for x in words if x not in self.special)
self.updateIndex()
self.max_len = max_len
self.min_len = min_len
[docs] @abstractmethod
def splitSequence(self, seq):
pass
[docs] def toFile(self, path):
log = open(path, 'w')
log.write('\n'.join([x for x in self.words if x not in self.special]))
log.close()
[docs] def addWordsFromSeq(self, seq, ignoreConstraints=False):
token = self.splitSequence(seq)
if ignoreConstraints or (self.min_len < len(token) <= self.max_len):
diff = set(token) - self.wordSet
if len(diff) > 0:
self.wordSet.update(diff)
self.updateIndex()
return token
else:
logger.warning(f"Molecule does not meet min/max words requirements (min: {self.min_len}, max: {self.max_len}). Words found: {set(token)} (occurrence count: {len(token)}). It will be ignored.")
return None
[docs] def removeIfNew(self, seq, ignoreConstraints=False):
token = self.splitSequence(seq)
if ignoreConstraints or (self.min_len < len(token) <= self.max_len):
diff = set(token) - self.wordSet - set(self.special)
if len(diff) > 0:
logger.warning(f"Tokens: {set(diff)} do not occur in voc. Molecule: {seq} will be ignored.")
return None
else:
return token
else:
logger.warning(f"Molecule does not meet min/max words requirements (min: {self.min_len}, max: {self.max_len}). Words found: {set(token)} (occurrence count: {len(token)}). It will be ignored.")
return None
[docs] def updateIndex(self):
self.words = self.special + [x for x in sorted(self.wordSet) if x not in self.special]
self.size = len(self.words)
self.tk2ix = dict(zip(self.words, range(len(self.words))))
self.ix2tk = {v: k for k, v in self.tk2ix.items()}
[docs]class Corpus(MolSupplier, ABC):
"""
A `MolSupplier` that generates encoded molecule data from the given input.
"""
def __init__(self, molecules):
"""
Args:
molecules: an `iterable`, `MolSupplier` or a `list`-like data structure to supply molecules
"""
super().__init__()
self.molecules = molecules if hasattr(molecules, "__next__") else iter(molecules)
[docs] def next(self):
return next(self.molecules)
[docs] def convert(self, representation):
try:
ret = self.processMolecule(representation)
except Exception as exp:
logger.warning(f'Exception occurred when generating corpus data for molecule: {representation}. Cause:')
logger.exception(exp)
return next(self)
return ret
[docs] @abstractmethod
def processMolecule(self, molecule):
"""
Process one molecule.
Args:
molecule: a molecule instance (representation depend on the implementation).
Returns:
encoded data of the molecule (i.e. data associated with one input sample to the desired DrugEx model)
"""
pass
[docs] @abstractmethod
def getVoc(self):
"""
Corpus should keep track of the 'Vocabulary' used to encode molecules. This method should return its current state.
Returns:
currently used `Vocabulary`
"""
pass