import pandas as pd
from drugex.data.corpus.vocabulary import VocSmiles, VocGraph
from drugex.logs import logger
from drugex.data.interfaces import DataSplitter, FragmentPairEncoder
from drugex.molecules.converters.interfaces import ConversionException
from drugex.molecules.interfaces import MolSupplier
from drugex.parallel.collectors import ListExtend
from drugex.parallel.evaluator import ParallelSupplierEvaluator
from drugex.parallel.interfaces import ParallelProcessor
class SequenceFragmentEncoder(FragmentPairEncoder):
    """
    Encode fragment-molecule pairs for the sequence-based models.
    """

    def __init__(self, vocabulary=None, update_voc=True, throw=False):
        """
        Args:
            vocabulary: vocabulary used to tokenize and encode sequences. If `None`, a new `VocSmiles(True)`
                is created for this instance. A default instance is not shared, because with
                `update_voc=True` the vocabulary is mutated by every call to `encodeMol`/`encodeFrag`.
            update_voc: if `True`, sequences are tokenized with `vocabulary.addWordsFromSeq`, which can
                add new words to the vocabulary.
            throw: only consulted when `update_voc` is `False`. If `True`, sequences are tokenized with
                `vocabulary.removeIfNew` (presumably rejecting sequences that contain unknown tokens;
                confirm against the vocabulary implementation).
        """
        # NOTE(review): with update_voc=False and throw=False no tokenization happens at all,
        # so every molecule and fragment fails to encode. Confirm this is intended.
        self.vocabulary = VocSmiles(True) if vocabulary is None else vocabulary
        self.updateVoc = update_voc
        self.throw = throw

    def _tokenize(self, sequence, **kwargs):
        """Tokenize `sequence` according to the `update_voc`/`throw` settings; `None` if neither is set."""
        if self.updateVoc:
            return self.vocabulary.addWordsFromSeq(sequence, **kwargs)
        if self.throw:
            return self.vocabulary.removeIfNew(sequence, **kwargs)
        return None

    def _encodeTokens(self, tokens):
        """Encode all tokens except the last (end) token and return the codes as a flat `list`."""
        output = self.vocabulary.encode([tokens[:-1]])
        return output[0].reshape(-1).tolist()

    def encodeMol(self, sequence):
        """
        Encode a molecule sequence.

        Args:
            sequence: sequential representation of the molecule (i.e. SMILES)

        Returns:
            a `tuple` of (tokens, codes). `tokens` is whatever the tokenizer returned (may be `None` or
            empty), and `codes` is the encoded token sequence as a `list`, or `None` if tokenization
            produced nothing.
        """
        tokens = self._tokenize(sequence)
        if not tokens:
            return tokens, None
        return tokens, self._encodeTokens(tokens)

    def encodeFrag(self, mol, mol_tokens, frag):
        """
        Encode a fragment.

        Args:
            mol: the parent molecule sequence (not used by this encoder)
            mol_tokens: the tokens of the parent molecule, as returned by `encodeMol` (not used by this encoder)
            frag: sequence of the fragment to encode

        Returns:
            the encoded fragment as a `list` of codes, or `None` if the fragment could not be tokenized
        """
        # Fragments are tokenized without the vocabulary's usual constraints.
        tokens = self._tokenize(frag, ignoreConstraints=True)
        if not tokens:
            return None
        return self._encodeTokens(tokens)

    def getVoc(self):
        """Return the vocabulary used by this encoder."""
        return self.vocabulary
class GraphFragmentEncoder(FragmentPairEncoder):
    """
    Encode molecules and fragments for the graph-based transformer (`GraphModel`).
    """

    def __init__(self, vocabulary=None):
        """
        Initialize this instance with the vocabulary to use.

        Args:
            vocabulary: used to perform the encoding. If `None`, a new `VocGraph()` is created for this
                instance instead of sharing one default instance between all encoders.
        """
        self.vocabulary = VocGraph() if vocabulary is None else vocabulary

    def encodeMol(self, smiles):
        """
        Molecules are encoded together with fragments, so the SMILES is passed back unchanged as both
        the tokens and the result of encoding.

        Args:
            smiles: SMILES of the molecule

        Returns:
            The input smiles as both the tokens and the encoded result.
        """
        return smiles, smiles

    def encodeFrag(self, mol, mol_tokens, frag):
        """
        Encode molecules and fragments at once.

        Args:
            mol: parent molecule SMILES (from `encodeMol`)
            mol_tokens: molecule SMILES (from `encodeMol`; not used here)
            frag: SMILES of the fragment in the parent molecule

        Returns:
            One line of the graph-encoded data as a `list`. Returns `None` if the fragment is identical
            to the molecule, if encoding fails, or if decoding the result does not reproduce `mol`.
            Failures are logged as warnings.
        """
        if mol == frag:
            return None
        try:
            output = self.vocabulary.encode([mol], [frag])
            _, decoded_mols = self.vocabulary.decode(output)
            # Round-trip sanity check. This is an explicit check rather than `assert` so that it still
            # runs under `python -O`.
            if mol != decoded_mols[0]:
                raise ValueError(f'decoded molecule {decoded_mols[0]} does not match the encoded one')
            return output[0].reshape(-1).tolist()
        except Exception as exp:
            # Best effort: a failing pair is reported and skipped rather than aborting the whole run.
            logger.warning(f'The following exception occured while encoding fragment {frag} for molecule {mol}: {exp}')
            return None

    def getVoc(self):
        """Return the vocabulary used by this encoder."""
        return self.vocabulary
class FragmentPairsEncodedSupplier(MolSupplier):
    """
    Transforms fragment-molecule pairs to the encoded representation used by the fragment-based DrugEx models.
    """

    class FragmentEncodingException(ConversionException):
        """
        Raise this when a fragment failed to encode.
        """
        pass

    class MoleculeEncodingException(ConversionException):
        """
        Raise this when the parent molecule of the fragment failed to be encoded.
        """
        pass

    def __init__(self, pairs, encoder):
        """
        Initialize from an iterable of fragment-molecule pairs.

        Args:
            pairs: iterable of (fragment, molecule) `tuple`s, each denoting one fragment-molecule pair
            encoder: a `FragmentPairEncoder` handling encoding of molecules and fragments
        """
        self.encoder = encoder
        self.pairs = iter(pairs)

    def next(self):
        """
        Get the next pair and encode it with the encoder.

        Returns:
            `tuple`: (encoded fragment, encoded molecule), as returned by the encoder's `encodeFrag`
            and `encodeMol` respectively

        Raises:
            StopIteration: when all pairs have been consumed
            MoleculeEncodingException: if the encoder returns no tokens for the molecule
            FragmentEncodingException: if the encoder returns no encoding for the fragment
        """
        pair = next(self.pairs)
        fragment, molecule = pair[0], pair[1]

        tokens, encoded_mol = self.encoder.encodeMol(molecule)
        if not tokens:
            raise self.MoleculeEncodingException(f'Failed to encode molecule: {molecule}')

        encoded_frag = self.encoder.encodeFrag(molecule, tokens, fragment)
        if not encoded_frag:
            raise self.FragmentEncodingException(f'Failed to encode fragment {fragment} from molecule: {molecule}')

        return encoded_frag, encoded_mol
class FragmentCorpusEncoder(ParallelProcessor):
    """
    Fragments and encodes fragment-molecule pairs in parallel. Each encoded pair is used as input to the fragment-based DrugEx models.
    """

    class FragmentPairsCollector(ListExtend):
        """
        A simple `ResultCollector` that extends an internal `list` with the data part of each result.
        It can optionally forward every result to another collector.
        """

        def __init__(self, other=None):
            """
            Args:
                other: an optional collector that is called with the same result after extending
            """
            super().__init__()
            self.other = other

        def __call__(self, result):
            # `result` is a (data, ...) tuple; only the data part is stored here.
            self.items.extend(result[0])
            # Compare with None explicitly: an empty collector that defines `__len__` is falsy
            # but must still receive results.
            if self.other is not None:
                self.other(result)

    def __init__(self, fragmenter, encoder, pairs_splitter=None, n_proc=None, chunk_size=None):
        """
        Args:
            fragmenter (MolConverter): a `MolConverter` that returns a `list` of (fragment, molecule) `tuple`s for a given molecule supplied as its SMILES string. See the reference implementation in `Fragmenter`.
            encoder: a `FragmentPairEncoder` that handles how molecules and fragments are encoded in the final result
            pairs_splitter: optional splitter (e.g. a `DataSplitter`) that is called with the generated fragment-molecule pairs and returns a sequence of splits (i.e. test and train)
            n_proc: number of processes to use for parallel operations
            chunk_size: maximum size of data chunks processed by a single process (can save memory)
        """
        super().__init__(n_proc, chunk_size)
        self.fragmenter = fragmenter
        self.encoder = encoder
        self.pairsSplitter = pairs_splitter

    def getFragmentPairs(self, mols, collector):
        """
        Apply the given "fragmenter" in parallel.

        Args:
            mols: Molecules represented as SMILES strings.
            collector: The `ResultCollector` to apply to fetch the result per process.

        Returns:
            `None`
        """
        evaluator = ParallelSupplierEvaluator(
            FragmentPairsSupplier,
            kwargs={
                "fragmenter": self.fragmenter
            },
            chunk_size=self.chunkSize,
            chunks=self.chunks,
            n_proc=self.nProc
        )
        evaluator.apply(mols, collector, desc_string="Creating fragment-molecule pairs")

    def splitFragmentPairs(self, pairs):
        """
        Use the "pairs_splitter" to get splits of the calculated fragment-molecule pairs from `FragmentCorpusEncoder.getFragmentPairs()`.

        Args:
            pairs: pairs generated by the "fragmenter"

        Returns:
            splits from the specified "pairs_splitter", or `[pairs]` (a single split) if no splitter is set
        """
        return self.pairsSplitter(pairs) if self.pairsSplitter else [pairs]

    def encodeFragments(self, pairs, collector):
        """
        Encodes fragment-pairs obtained from `FragmentCorpusEncoder.getFragmentPairs()` with the specified `FragmentPairEncoder` initialized in "encoder".

        Args:
            pairs: `list` of (fragment, molecule) `tuple`s to encode
            collector: The `ResultCollector` to apply to fetch encoding data from each process.

        Returns:
            `None`
        """
        evaluator = ParallelSupplierEvaluator(
            FragmentPairsEncodedSupplier,
            kwargs={
                'encoder': self.encoder,
            },
            chunk_size=self.chunkSize,
            chunks=self.chunks,
            n_proc=self.nProc
        )
        evaluator.apply(pairs, collector, desc_string="Encoding fragment-molecule pairs.")

    def apply(self, mols, fragmentPairsCollector=None, encodingCollectors=None):
        """
        Apply fragmentation and encoding to the given molecules represented as SMILES strings. Collectors can be used to fetch fragment-molecule pairs and the final encoding with vocabulary.

        Args:
            mols: `list` of molecules as SMILES strings
            fragmentPairsCollector: an instance of `ResultCollector` to collect results of the fragmentation (the generated fragment-molecule `tuple`s from the given "fragmenter").
            encodingCollectors: a `list` of `ResultCollector` instances matching in length the number of splits given by the "pairs_splitter". Each `ResultCollector` receives a (data, `FragmentPairsEncodedSupplier`) `tuple` of the currently finished process.

        Returns:
            `None`

        Raises:
            RuntimeError: if `encodingCollectors` is non-empty and its length differs from the number of splits
        """
        pairs_collector = self.FragmentPairsCollector(fragmentPairsCollector)
        self.getFragmentPairs(mols, pairs_collector)
        splits = self.splitFragmentPairs(pairs_collector.getList())
        if encodingCollectors and len(encodingCollectors) != len(splits):
            raise RuntimeError(f'The number of encoding collectors must match the number of splits: {len(encodingCollectors)} != {len(splits)}')
        for split_idx, split in enumerate(splits):
            collector = encodingCollectors[split_idx] if encodingCollectors else None
            self.encodeFragments(split, collector)