Source code for drugex.data.datasets

"""
defaultdatasets

Created by: Martin Sicho
On: 25.06.22, 19:42
"""

from itertools import chain
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader, TensorDataset, Dataset

from drugex.data.corpus.vocabulary import VocSmiles, VocGraph
from drugex.data.interfaces import DataSet, DataToLoader


class SmilesDataSet(DataSet):
    """
    `DataSet` that holds the encoded SMILES representations of molecules for the single-network sequence-based DrugEx model (`RNN`).
    """

    columns = ('Smiles', 'Token')  # column names to use for the data frame

    def __init__(self, path, voc=None, rewrite=False):
        """
        Initialize the data set and attach a vocabulary to it.

        Args:
            path: forwarded unchanged to `DataSet.__init__`.
            voc: vocabulary to attach; any falsy value is replaced by `VocSmiles(False)`.
            rewrite: forwarded unchanged to `DataSet.__init__`.
        """
        super().__init__(path, rewrite=rewrite)
        self.setVoc(voc if voc else VocSmiles(False))

    @staticmethod
    def dataToLoader(data, batch_size, vocabulary):
        """
        Wrap encoded data in a shuffling `DataLoader`.

        Args:
            data: numpy array with the encoded sequences (it is passed to `torch.from_numpy`).
            batch_size: batch size of the created loader.
            vocabulary: object whose `max_len` gives the width of one encoded sequence.

        Returns:
            `DataLoader` over a long tensor viewed as `(len(data), vocabulary.max_len)`;
            batches are shuffled and the last incomplete batch is kept.
        """
        dataset = torch.from_numpy(data).long().view(len(data), vocabulary.max_len)
        loader = DataLoader(dataset, batch_size=batch_size, drop_last=False, shuffle=True)
        return loader

    def __call__(self, result):
        """
        Collect results from `SequenceCorpus`.

        Args:
            result: A list of items generated from `SequenceCorpus` to be added to the current `DataSet`.
                `result[0]` is the data sent to the file, `result[1]` must provide `getVoc()`.

        Returns:
            `None`
        """
        self.updateVoc(result[1].getVoc())
        self.sendDataToFile(result[0], columns=self.getColumns())

    def getColumns(self):
        """
        Returns:
            `list` of column names `C0 ... C<max_len - 1>`, one per position of the current vocabulary's `max_len`.
        """
        return ['C%d' % d for d in range(self.getVoc().max_len)]

    def readVocs(self, paths, voc_class, *args, **kwargs):
        """
        Delegate reading of vocabularies to the parent `DataSet`.

        Args:
            paths: forwarded to the parent implementation.
            voc_class: forwarded to the parent implementation.
            *args: extra positional arguments forwarded to the parent implementation.
            **kwargs: extra keyword arguments forwarded to the parent implementation.

        Returns:
            `None`
        """
        # `voc_class` is forwarded positionally on purpose: star-args are bound before
        # explicit keywords, so `voc_class=voc_class, *args` let the first extra positional
        # argument take the `voc_class` slot and raised a "multiple values" TypeError.
        # NOTE(review): assumes the parent's parameter order matches this override -- confirm.
        super().readVocs(paths, voc_class, *args, **kwargs)
class SmilesFragDataSet(DataSet):
    """
    `DataSet` that holds the encoded SMILES representations of fragment-molecule pairs for the sequence-based encoder-decoder type of DrugEx models.
    """

    columns = ('Input', 'Output')

    class TargetCreator(DataToLoader):
        """
        Old creator for test data that currently is no longer being used. Saved here for future reference.
        """

        class TgtData(Dataset):
            """
            Torch `Dataset` over encoded sequences that also keeps an index of identifiers for them.
            """

            def __init__(self, seqs, ix, max_len=100):
                """
                Args:
                    seqs: indexable collection of encoded sequences.
                    ix: identifiers of the sequences; stored as a numpy array in `self.index`.
                    max_len: width of the batch tensor allocated in `collate_fn`.
                """
                self.max_len = max_len
                self.index = np.array(ix)
                # reverse lookup: identifier -> position in `self.index`
                self.map = {idx: i for i, idx in enumerate(self.index)}
                self.seq = seqs

            def __getitem__(self, i):
                """
                Returns:
                    `tuple` of the requested position `i` and the sequence stored at it.
                """
                seq = self.seq[i]
                return i, seq

            def __len__(self):
                return len(self.seq)

            def collate_fn(self, arr):
                """
                Merge `(position, sequence)` samples into one batch.

                Args:
                    arr: `list` of items as returned by `__getitem__`.

                Returns:
                    `tuple` of a numpy integer array with the positions and a long tensor
                    of shape `(len(arr), self.max_len)` with the sequences.
                """
                collated_ix = np.zeros(len(arr), dtype=int)
                collated_seq = torch.zeros(len(arr), self.max_len).long()
                for i, (ix, tgt) in enumerate(arr):
                    collated_ix[i] = ix
                    collated_seq[i, :] = tgt
                return collated_ix, collated_seq

        def __call__(self, data, batch_size, vocabulary):
            """
            Build a `DataLoader` over the unique sequences found in the first column of `data`.

            Args:
                data: two-dimensional array; only column 0 is used and its items are
                    strings of tokens separated by single spaces.
                batch_size: batch size of the created loader.
                vocabulary: object providing `encode` and `decode`.

            Returns:
                `DataLoader` using `TgtData.collate_fn` to assemble batches.
            """
            dataset = data[:, 0]
            dataset = pd.Series(dataset).drop_duplicates()
            dataset = [seq.split(' ') for seq in dataset]
            dataset = vocabulary.encode(dataset)
            # NOTE(review): `TgtData` is created with its default `max_len` (100),
            # `vocabulary.max_len` is not passed on -- confirm this is intended.
            dataset = self.TgtData(dataset, ix=[vocabulary.decode(seq, is_tk=False) for seq in dataset])
            dataset = DataLoader(dataset, batch_size=batch_size, collate_fn=dataset.collate_fn)
            return dataset

    def __init__(self, path, voc=None, rewrite=False, save_voc=True, voc_file=None):
        """
        Initialize the data set and attach a vocabulary to it.

        Args:
            path: forwarded unchanged to `DataSet.__init__`.
            voc: vocabulary to attach; any falsy value is replaced by `VocSmiles(True)`.
            rewrite: forwarded unchanged to `DataSet.__init__`.
            save_voc: forwarded unchanged to `DataSet.__init__`.
            voc_file: forwarded unchanged to `DataSet.__init__`.
        """
        super().__init__(path, rewrite=rewrite, save_voc=save_voc, voc_file=voc_file)
        self.voc = voc if voc else VocSmiles(True)

    def __call__(self, result):
        """
        Collect encoded data from the given results. Designated to be used as the collector of encodings for `FragmentCorpusEncoder`.

        Args:
            result: `tuple` with two items -- a `list` of `tuple`s as supplied by: `FragmentPairsEncodedSupplier` and the `FragmentPairsEncodedSupplier` itself

        Returns:
            `None`
        """
        self.updateVoc(result[1].encoder.getVoc())
        # every item is flattened into a single row before it is written
        self.sendDataToFile([list(chain.from_iterable(x)) for x in result[0]], columns=self.getColumns())

    def createLoaders(self, data, batch_size, splitter=None, converter=None):
        """
        Optionally split the data and convert every split.

        Args:
            data: data to create the loaders from.
            batch_size: batch size handed to `converter`.
            splitter: optional callable; when given, `splitter(data)` supplies the splits,
                otherwise `data` is the only split.
            converter: optional callable invoked as `converter(split, batch_size, vocabulary)`;
                when omitted the splits are returned unconverted.

        Returns:
            `list` with one item per split.
        """
        splits = []
        if splitter:
            splits = splitter(data)
        else:
            splits.append(data)
        return [converter(split, batch_size, self.getVoc()) if converter else split for split in splits]

    @staticmethod
    def dataToLoader(data, batch_size, vocabulary):
        """
        Wrap encoded pairs in a shuffling `DataLoader`.

        Args:
            data: two-dimensional numpy array; the first `vocabulary.max_len` columns and
                the remaining columns become the two tensors of the data set.
            batch_size: batch size of the created loader.
            vocabulary: object whose `max_len` gives the width of one encoded sequence.

        Returns:
            `DataLoader` over a `TensorDataset` of two long tensors, each viewed as
            `(len(data), vocabulary.max_len)`; batches are shuffled and the last
            incomplete batch is kept.
        """
        # Split into molecule and fragment embedding
        dataset = TensorDataset(torch.from_numpy(data[:, :vocabulary.max_len]).long().view(len(data), vocabulary.max_len),
                                torch.from_numpy(data[:, vocabulary.max_len:]).long().view(len(data), vocabulary.max_len))
        loader = DataLoader(dataset, batch_size=batch_size, drop_last=False, shuffle=True)
        return loader

    def getColumns(self):
        """
        Returns:
            `list` of column names `C0 ... C<2 * max_len - 1>`, i.e. twice the current vocabulary's `max_len`.
        """
        return ['C%d' % d for d in range(self.getVoc().max_len * 2)]

    def readVocs(self, paths, voc_class, *args, **kwargs):
        """
        Delegate reading of vocabularies to the parent `DataSet`.

        Args:
            paths: forwarded to the parent implementation.
            voc_class: forwarded to the parent implementation.
            *args: extra positional arguments forwarded to the parent implementation.
            **kwargs: extra keyword arguments forwarded to the parent implementation.

        Returns:
            `None`
        """
        # `voc_class` is forwarded positionally on purpose: star-args are bound before
        # explicit keywords, so `voc_class=voc_class, *args` let the first extra positional
        # argument take the `voc_class` slot and raised a "multiple values" TypeError.
        # NOTE(review): assumes the parent's parameter order matches this override -- confirm.
        super().readVocs(paths, voc_class, *args, **kwargs)
class GraphFragDataSet(DataSet):
    """
    `DataSet` to manage the fragment-molecule pair encodings for the graph-based model (`GraphModel`).
    """

    def __init__(self, path, voc=None, rewrite=False, save_voc=True, voc_file=None):
        """
        Set up the data set and its vocabulary.

        Args:
            path: handed over to `DataSet.__init__` as is.
            voc: vocabulary to use; a fresh `VocGraph()` is created for any falsy value.
            rewrite: handed over to `DataSet.__init__` as is.
            save_voc: handed over to `DataSet.__init__` as is.
            voc_file: handed over to `DataSet.__init__` as is.
        """
        super().__init__(path, rewrite=rewrite, save_voc=save_voc, voc_file=voc_file)
        if not voc:
            voc = VocGraph()
        self.voc = voc

    def __call__(self, result):
        """
        Store the encodings delivered by `FragmentCorpusEncoder`.

        Args:
            result: two-item `tuple` -- the `list` of `tuple`s produced by `FragmentPairsEncodedSupplier`, followed by that `FragmentPairsEncodedSupplier` instance

        Returns:
            `None`
        """
        self.updateVoc(result[1].encoder.getVoc())
        # only the first element of every supplied item is written out
        rows = []
        for item in result[0]:
            rows.append(item[0])
        self.sendDataToFile(rows, columns=self.getColumns())

    def getColumns(self):
        """
        Returns:
            `list` of column names `C0 ... C<5 * max_len - 1>`, i.e. five columns per position of the current vocabulary's `max_len`.
        """
        n_columns = self.getVoc().max_len * 5
        return ['C%d' % idx for idx in range(n_columns)]

    @staticmethod
    def dataToLoader(data, batch_size, vocabulary):
        """
        Turn encoded data into a shuffling `DataLoader`.

        Args:
            data: numpy array with the encodings (it is passed to `torch.from_numpy`).
            batch_size: batch size of the created loader.
            vocabulary: object whose `max_len` fixes the second dimension of the tensor.

        Returns:
            `DataLoader` over a long tensor viewed as `(len(data), vocabulary.max_len, -1)`;
            batches are shuffled and the last incomplete batch is kept.
        """
        encoded = torch.from_numpy(data).long()
        encoded = encoded.view(len(data), vocabulary.max_len, -1)
        return DataLoader(encoded, batch_size=batch_size, drop_last=False, shuffle=True)