"""Different splitters to create train and tests for evalutating QSPR model performance.
To add a new data splitter:
* Add a DataSplit subclass for your new splitter
"""
import platform
from abc import ABC, abstractmethod
from typing import Iterable
import numpy as np
import pandas as pd
from gbmtsplits import GloballyBalancedSplit
from sklearn.model_selection import ShuffleSplit
from ...data.chem.clustering import (
FPSimilarityMaxMinClusters,
MoleculeClusters,
RandomClusters,
ScaffoldClusters,
)
from ...data.chem.scaffolds import BemisMurckoRDKit, Scaffold
from ...data.tables.interfaces.data_set_dependent import DataSetDependent
from ...data.tables.interfaces.qspr_data_set import QSPRDataSet
from ...logs import logger
from ...utils.interfaces.randomized import Randomized
from ...utils.serialization import JSONSerializable
class DataSplit(JSONSerializable, ABC):
    """Base class for strategies that partition data into train and test sets."""

    @abstractmethod
    def split(
        self, X: np.ndarray | pd.DataFrame, y: np.ndarray | pd.DataFrame | pd.Series
    ) -> Iterable[tuple[list[int], list[int]]]:
        """Produce one or more train/test partitions of the given data.

        Implementations partition a feature matrix by returning a generator of
        train and test indices. The approach is compatible with the one taken in
        the `sklearn` package (see `sklearn.model_selection._BaseKFold`), so a
        splitter can serve both for cross-validation and for a single
        train/test split.

        Args:
            X (np.ndarray | pd.DataFrame): the input data matrix
            y (np.ndarray | pd.DataFrame | pd.Series): the target variable(s)

        Returns:
            a generator over the created subsets, each represented as a tuple of
            (train_indices, test_indices). The indices are integer row positions
            in the input data matrix X, not values of a pandas index.
        """

    def splitDataset(self, dataset: QSPRDataSet):
        """Apply this splitter to a data set.

        Args:
            dataset (QSPRDataSet): the data set to split

        Returns:
            whatever `dataset.split` returns when called with this splitter
        """
        outcome = dataset.split(self)
        return outcome
class RandomSplit(DataSplit, Randomized):
    """Splits dataset in random train and test subsets.

    Attributes:
        testFraction (float):
            fraction of total dataset to testset
        randomState (int | None):
            Random state to use for shuffling and other random operations.
    """

    def __init__(
        self,
        test_fraction=0.1,
        seed: int | None = None,
    ) -> None:
        """Initialize a RandomSplit object.

        Args:
            test_fraction (float): fraction of total dataset to testset
            seed (int | None): random state passed on to the shuffling
        """
        self.testFraction = test_fraction
        self.randomState = seed

    @property
    def randomState(self) -> int | None:
        """Seed used for shuffling, or `None` if no seed was supplied."""
        return self._seed

    @randomState.setter
    def randomState(self, seed: int | None):
        self._seed = seed

    def split(self, X, y):
        """Create a single shuffled train/test split.

        Args:
            X (np.ndarray | pd.DataFrame): the input data matrix
            y (np.ndarray | pd.DataFrame | pd.Series): the target variable(s)

        Returns:
            the generator returned by `ShuffleSplit.split` for one split with
            `testFraction` as the test size
        """
        if self.randomState is None:
            # the adjacent literals are concatenated, so each piece needs its
            # own trailing space to keep the sentences apart
            logger.info(
                "No random state supplied, "
                "and could not find random state on the dataset. "
                "Random seed will be set randomly."
            )
        return ShuffleSplit(
            1, test_size=self.testFraction, random_state=self.randomState
        ).split(X, y)
class BootstrapSplit(DataSplit, Randomized, DataSetDependent):
    """Splits dataset in random train and test subsets (bootstraps). Unlike
    cross-validation, bootstrapping allows for repeated samples in the test set.

    Attributes:
        nBootstraps (int):
            number of bootstraps to perform
        randomState (int | None):
            Random state to use for shuffling and other random operations.
    """

    def __init__(self, split: DataSplit, n_bootstraps=5, seed=None, dataset=None):
        """Initialize a BootstrapSplit object.

        Args:
            split (DataSplit): the splitter to use for the bootstraps
            n_bootstraps (int): number of bootstraps to perform
            seed (int): random seed to use for random operations
            dataset (QSPRDataSet): dataset for the underlying splitter if it is
                `DataSetDependent`
        """
        super().__init__(dataset)
        self._split = split
        # remember the seed of the wrapped splitter so it can be restored after
        # the bootstraps have temporarily overwritten it
        self._original_split_seed = (
            split.randomState if hasattr(split, "randomState") else None
        )
        self.nBootstraps = n_bootstraps
        self._current = 0
        self._seed = seed

    @property
    def randomState(self) -> int | None:
        """Base seed of the bootstraps, or `None` if no seed was supplied."""
        return self._seed

    @randomState.setter
    def randomState(self, seed: int | None):
        self._seed = seed

    def split(
        self, X: np.ndarray | pd.DataFrame, y: np.ndarray | pd.DataFrame | pd.Series
    ) -> Iterable[tuple[list[int], list[int]]]:
        """Split the given data into `nBootstraps` training and test sets.

        Args:
            X (np.ndarray | pd.DataFrame): the input data matrix
            y (np.ndarray | pd.DataFrame | pd.Series): the target variable(s)

        Returns:
            a generator over the tuples produced by the underlying splitter,
            which is run `nBootstraps` times
        """
        seedable = hasattr(self._split, "randomState")
        try:
            for _ in range(self.nBootstraps):
                if seedable and self.randomState is not None:
                    # iterate the random state of the underlying splitter, to
                    # get different splits for each bootstrap
                    self._split.randomState = self.randomState + self._current
                yield from self._split.split(X, y)
                self._current += 1
        finally:
            # Restore the wrapped splitter and the counter even when the
            # consumer stops iterating early or the underlying splitter raises;
            # otherwise the splitter would keep a bootstrap-specific seed.
            if seedable:
                self._split.randomState = self._original_split_seed
            self._current = 0

    def setDataSet(self, dataset):
        """Set the dataset for this splitter and for the underlying splitter.

        The underlying splitter is only updated if it has a `setDataSet` method.
        """
        super().setDataSet(dataset)
        if hasattr(self._split, "setDataSet"):
            self._split.setDataSet(dataset)
class ManualSplit(DataSplit, DataSetDependent):
    """Splits dataset in train and test subsets based on a column in the dataframe.

    Attributes:
        splitProps (list): name(s) of the column(s) in the dataset that contain
            the split
        trainVal (str): value in a split column that will be used for training
        testVal (str): value in a split column that will be used for testing

    Raises:
        ValueError: if a split column holds values other than trainval and testval
    """

    def __init__(
        self,
        splitprop: str | list,
        trainval: str,
        testval: str,
        data_set: QSPRDataSet | None = None
    ) -> None:
        """Initialize the ManualSplit object with the split column(s), trainval and
        testval attributes.

        One or more columns can be provided in splitprop to generate multiple splits,
        e.g. like cross-validation.

        Args:
            splitprop (str | list): name(s) of the column(s) in the dataset that
                contain(s) the split
            trainval (str): value in a splitprop that will be used for training
            testval (str): value in splitprop that will be used for testing
            data_set (QSPRDataSet): dataset that this splitter will be acting on
        """
        super().__init__(data_set)
        # always store a list so that split() can loop over the columns
        self.splitProps = splitprop if isinstance(splitprop, list) else [splitprop]
        self.trainVal = trainval
        self.testVal = testval

    def split(self, X, y):
        """
        Split the given data into one or multiple train/test subsets based on the
        predefined splitprop(s).

        Args:
            X (np.ndarray | pd.DataFrame): the input data matrix
            y (pd.DataFrame | pd.Series): the target variable(s); its pandas
                index is matched against the index of the attached dataset

        Returns:
            an iterator over the generated subsets represented as a tuple of
            (train_indices, test_indices) where the indices are the row positions
            in `y`

        Raises:
            ValueError: if a split column holds values other than `trainVal`
                and `testVal`
        """
        assert self.hasDataSet, (
            "No dataset attached to this splitter, set dataset with setDataSet()"
        )
        df = self.dataSet.getDF()
        allowed_values = {self.trainVal, self.testVal}
        splits = []
        for splitprop in self.splitProps:
            assert splitprop in df.columns, f"Column {splitprop} not found in dataset"
            splitcol = df[splitprop]
            # check if only trainval and testval are present in splitcol
            if not set(splitcol.unique()).issubset(allowed_values):
                raise ValueError(
                    f"There are more values in splitprop {splitprop} than trainval and testval"
                )
            # Check if all samples are assigned to either train or test. The
            # comparison has to be made on the index of y: comparing the target
            # values (or iterating the column labels of a data frame) says
            # nothing about which samples are covered by the split column.
            assert y.index.isin(
                splitcol.index
            ).all(), "Not all samples are assigned to either train or test"
            # translate the dataset index labels into row positions of y
            all_train = splitcol[splitcol == self.trainVal].index
            train = np.where(y.index.isin(all_train))[0]
            all_test = splitcol[splitcol == self.testVal].index
            test = np.where(y.index.isin(all_test))[0]
            splits.append((train, test))
        return iter(splits)
class TemporalSplit(DataSplit, DataSetDependent):
    """Splits dataset train and test subsets based on a threshold in time.

    Attributes:
        timeSplit (float | list[float]): time point(s) after which a sample is
            moved to the test set
        timeProp (str): name of the column within the dataset with timepoints
    """

    def __init__(
        self,
        timesplit: float | list[float],
        timeprop: str,
        data_set: QSPRDataSet | None = None,
    ):
        """Initialize a TemporalSplit object.

        Args:
            timesplit (float | list[float]):
                time point after which sample is moved to test set. If a list is
                provided, the splitter will split the dataset into multiple subsets
                based on the timepoints in the list.
            timeprop (str):
                name of the column within the dataset with timepoints
            data_set (QSPRDataSet):
                dataset that this splitter will be acting on
        """
        super().__init__(data_set)
        self.timeSplit = timesplit
        self.timeProp = timeprop

    def split(self, X, y):
        """Split the dataset based on one or more time thresholds.

        Args:
            X (np.ndarray | pd.DataFrame): the input data matrix
            y (pd.DataFrame | pd.Series): the target variable(s); its pandas
                index is used to look up the time points in the attached dataset

        Returns:
            a generator over the generated subsets (one per time threshold)
            represented as a tuple of (train_indices, test_indices) where the
            indices are the row positions in the input data matrix

        Raises:
            ValueError: if a threshold leaves the train or the test set empty
        """
        assert self.hasDataSet, (
            "No dataset attached to this splitter, set dataset with setDataSet()"
        )
        timesplits = (
            self.timeSplit if isinstance(self.timeSplit, list) else [self.timeSplit]
        )
        # time points of the samples in y, renumbered to row positions
        timeCol = self.dataSet.getDF()[self.timeProp]
        timeCol = timeCol[y.index].copy().reset_index(drop=True)
        # everything below does not depend on the threshold, so it is prepared
        # once instead of once per threshold
        has_columns = isinstance(y, pd.DataFrame)
        task_names = y.columns if has_columns else [y.name]
        if len(task_names) > 1:
            logger.warning(
                "The TemporalSplit is not recommended for multitask "
                "or PCM datasets might lead to very unbalanced subsets "
                "for some tasks"
            )
        # make indices numeric
        y_copy = y.reset_index(drop=True)
        indices = np.arange(len(X))
        for timesplit in timesplits:
            mask = (timeCol > timesplit).to_numpy()
            # Check if there are any train and test samples for each task
            for task in task_names:
                # a series holds a single task and cannot be indexed by its name
                task_values = y_copy[task] if has_columns else y_copy
                # `task` is a plain column label, so it is formatted directly
                if len(task_values[mask]) == 0:
                    raise ValueError(f"No test samples found for task {task}")
                elif len(task_values[~mask]) == 0:
                    raise ValueError(f"No train samples found for task {task}")
            yield indices[~mask], indices[mask]
class GBMTDataSplit(DataSplit, DataSetDependent):
    """Splits dataset into balanced train and test subsets
    based on an initial clustering algorithm. If `nFolds` is specified,
    the determined clusters will be split into `nFolds` groups of approximately
    equal size, and the splits will be generated by leaving out one group at a time.

    More information on the GBMT algorithm can be found at:
    https://github.com/CDDLeiden/gbmt-splits

    Attributes:
        clustering (MoleculeClusters):
            clustering algorithm to use
        testFraction (float):
            fraction of total dataset to test set, ignored if `nFolds` > 1
        nFolds (int):
            number of folds to split the dataset into
            (this overrides `testFraction` and `customTestList`)
        customTestList (list):
            list of molecule indexes to force in test set,
            ignored if `nFolds` > 1
        splitKwargs (dict):
            additional arguments to be passed to the GloballyBalancedSplit
    """

    def __init__(
        self,
        clustering: MoleculeClusters | None = None,
        test_fraction: float = 0.1,
        n_folds: int = 1,
        custom_test_list: list[str] | None = None,
        data_set: QSPRDataSet | None = None,
        **split_kwargs,
    ):
        """Initialize a GBMTDataSplit object.

        Args:
            clustering (MoleculeClusters | None): clustering algorithm to use,
                a new `FPSimilarityMaxMinClusters` is created if omitted
            test_fraction (float): fraction of total dataset to test set,
                ignored if `n_folds` > 1
            n_folds (int): number of folds to split the dataset into
            custom_test_list (list[str] | None): molecule indexes to force in
                the test set, ignored if `n_folds` > 1
            data_set (QSPRDataSet | None): dataset that this splitter will be
                acting on
            **split_kwargs: additional arguments to be passed to the
                GloballyBalancedSplit
        """
        super().__init__(data_set)
        self.testFraction = test_fraction
        self.customTestList = custom_test_list
        # create the default per instance; a default built in the signature
        # would be one object shared by every splitter
        self.clustering = (
            clustering if clustering is not None else FPSimilarityMaxMinClusters()
        )
        self.splitKwargs = split_kwargs if split_kwargs else {}
        self.nFolds = n_folds
        if self.nFolds > 1:
            self.testFraction = None
            self.customTestList = None

    def split(
        self, X: np.ndarray | pd.DataFrame, y: np.ndarray | pd.DataFrame | pd.Series
    ) -> Iterable[tuple[list[int], list[int]]]:
        """Split dataset into balanced train and test subsets
        based on an initial clustering algorithm.

        Args:
            X (np.ndarray | pd.DataFrame): the input data matrix
            y (pd.DataFrame | pd.Series): the target variable(s); its pandas
                index is used to look up the SMILES in the attached dataset

        Returns:
            a generator over the generated subsets represented as a tuple of
            (train_indices, test_indices) where the indices are the row positions
            in `y`
        """
        assert self.hasDataSet, (
            "No dataset attached to this splitter, set dataset with setDataSet()"
        )
        # on Windows the split may hang, so warn (but still try)
        if platform.system() == "Windows":
            logger.warning(
                "The GBMTDataSplit currently has a problem on Windows: "
                "https://github.com/coin-or/pulp/issues/671 and might hang up..."
            )
        # Work on a frame so that a single-task series is handled as well
        y_frame = y.to_frame() if isinstance(y, pd.Series) else y
        task_names = y_frame.columns
        assert len(task_names) > 0, "No target properties found."
        smiles = self.dataSet.getDF()[self.dataSet.smilesProp][y_frame.index].copy()
        # Get clusters
        clusters = self.clustering.getClusters(smiles.to_list())
        # Pre-assign smiles of custom_test_list to subset 1, which is the
        # subset used as test set when a single split is made
        preassigned_smiles = (
            {
                smiles[qspridx]: 1
                for qspridx in self.customTestList
            } if self.customTestList else None
        )
        logger.debug(f"Split arguments: {self.splitKwargs}")
        # Determine the relative sizes of the subsets
        if self.nFolds == 1:
            sizes = [1 - self.testFraction, self.testFraction]
        else:
            self.testFraction = 1 / self.nFolds
            sizes = [self.testFraction] * self.nFolds
        splitter = GloballyBalancedSplit(
            sizes=sizes,
            clusters=clusters,
            clustering_method=None,  # As precomputed clusters are provided
            **self.splitKwargs,
        )
        # The splitter needs numeric row positions as index. A renumbered copy
        # is used so that the index of the caller's `y` is never touched, not
        # even when the splitter raises.
        y_numeric = y_frame.reset_index(drop=True)
        y_numeric["SMILES"] = smiles.reset_index(drop=True)
        y_split = splitter(
            y_numeric,
            "SMILES",
            task_names,
            preassigned_smiles=preassigned_smiles,
        )
        # Get indices: every subset is the test set once for multiple folds,
        # otherwise only subset 1 is
        n_samples = len(y_frame)
        folds = y_split["Split"].unique() if self.nFolds > 1 else [1]
        for fold in folds:
            in_test = y_split["Split"] == int(fold)
            test_indices = y_split[in_test].index.values
            train_indices = y_split[~in_test].index.values
            assert (
                len(train_indices) + len(test_indices) == n_samples
            ), "Not all samples were assigned to a split"
            yield train_indices, test_indices
class GBMTRandomSplit(GBMTDataSplit, Randomized):
    """
    Splits dataset into balanced random train and test subsets.

    Attributes:
        testFraction (float):
            fraction of total dataset to testset
        customTestList (list):
            list of molecule indexes to force in test set
        initialClusters (int | None):
            value passed to `RandomClusters` as its second argument
        splitKwargs (dict):
            additional arguments to be passed to the GloballyBalancedSplit
    """

    def __init__(
        self,
        test_fraction: float = 0.1,
        n_folds: int = 1,
        seed: int | None = None,
        n_initial_clusters: int | None = None,
        custom_test_list: list[str] | None = None,
        data_set: QSPRDataSet | None = None,
        **split_kwargs,
    ) -> None:
        """Initialize a GBMTRandomSplit object.

        Args:
            test_fraction (float): fraction of total dataset to testset
            n_folds (int): number of folds to split the dataset into
            seed (int | None): seed passed to `RandomClusters`
            n_initial_clusters (int | None): second argument passed to
                `RandomClusters`
            custom_test_list (list[str] | None): molecule indexes to force in
                the test set
            data_set (QSPRDataSet | None): dataset that this splitter will be
                acting on
            **split_kwargs: additional arguments to be passed to the
                GloballyBalancedSplit
        """
        if seed is None:
            logger.info("No random state supplied")
        super().__init__(
            RandomClusters(seed, n_initial_clusters),
            test_fraction,
            n_folds,
            custom_test_list,
            data_set,
            **split_kwargs,
        )
        self.initialClusters = n_initial_clusters
        # the clustering created above already uses this seed, so the seed is
        # stored directly instead of going through the setter
        self._seed = seed

    @property
    def randomState(self) -> int | None:
        """Seed of the random clustering, or `None` if no seed was supplied."""
        return self._seed

    @randomState.setter
    def randomState(self, seed: int | None):
        self._seed = seed
        # Only the clustering depends on the seed, so only the clustering is
        # replaced; the remaining settings of the splitter stay as they are.
        self.clustering = RandomClusters(seed, self.initialClusters)
class ScaffoldSplit(GBMTDataSplit):
    """
    Splits dataset into balanced train and test subsets based on molecular scaffolds.

    Attributes:
        testFraction (float):
            fraction of total dataset to testset
        customTestList (list):
            list of molecule indexes to force in test set
        splitKwargs (dict):
            additional arguments to be passed to the GloballyBalancedSplit
    """

    def __init__(
        self,
        scaffold: Scaffold | None = None,
        test_fraction: float = 0.1,
        n_folds: int = 1,
        custom_test_list: list | None = None,
        data_set: QSPRDataSet | None = None,
        **split_kwargs,
    ) -> None:
        """Initialize a ScaffoldSplit object.

        Args:
            scaffold (Scaffold | None): scaffold definition handed to
                `ScaffoldClusters`, a new `BemisMurckoRDKit` is created if omitted
            test_fraction (float): fraction of total dataset to testset
            n_folds (int): number of folds to split the dataset into
            custom_test_list (list | None): molecule indexes to force in the
                test set
            data_set (QSPRDataSet | None): dataset that this splitter will be
                acting on
            **split_kwargs: additional arguments to be passed to the
                GloballyBalancedSplit
        """
        # create the default per instance; a default built in the signature
        # would be one object shared by every splitter
        if scaffold is None:
            scaffold = BemisMurckoRDKit()
        super().__init__(
            ScaffoldClusters(scaffold),
            test_fraction,
            n_folds,
            custom_test_list,
            data_set,
            **split_kwargs,
        )
class ClusterSplit(GBMTDataSplit, Randomized):
    """Splits dataset into balanced train and test subsets based on clusters of similar
    molecules.

    Attributes:
        testFraction (float):
            fraction of total dataset to testset
        customTestList (list):
            list of molecule indexes to force in test set
        randomState (int | None):
            Random state to use for shuffling and other random operations.
        splitKwargs (dict):
            additional arguments to be passed to the GloballyBalancedSplit
    """

    def __init__(
        self,
        test_fraction: float = 0.1,
        n_folds: int = 1,
        custom_test_list: list[str] | None = None,
        seed: int | None = None,
        clustering: MoleculeClusters | None = None,
        data_set: QSPRDataSet | None = None,
        **split_kwargs,
    ) -> None:
        """Initialize a ClusterSplit object.

        Args:
            test_fraction (float): fraction of total dataset to testset
            n_folds (int): number of folds to split the dataset into
            custom_test_list (list[str] | None): molecule indexes to force in
                the test set
            seed (int | None): random state, also written to the `seed`
                attribute of the clustering if it has one
            clustering (MoleculeClusters | None): clustering algorithm to use,
                a `FPSimilarityMaxMinClusters` with the given seed is created
                if omitted
            data_set (QSPRDataSet | None): dataset that this splitter will be
                acting on
            **split_kwargs: additional arguments to be passed to the
                GloballyBalancedSplit
        """
        if seed is None:
            # the adjacent literals are concatenated, so each piece needs its
            # own trailing space to keep the sentences apart
            logger.info(
                "No random state supplied, "
                "and could not find random state on the dataset. "
                "Random seed will be set randomly."
            )
        # the randomState setter reads self.clustering, so the clustering has
        # to exist before the seed is assigned
        self.clustering = (
            clustering
            if clustering is not None
            else FPSimilarityMaxMinClusters(seed=seed)
        )
        self.randomState = seed
        super().__init__(
            self.clustering,
            test_fraction,
            n_folds,
            custom_test_list,
            data_set,
            **split_kwargs,
        )

    @property
    def randomState(self) -> int | None:
        """Seed of this splitter, or `None` if no seed was supplied."""
        return self._seed

    @randomState.setter
    def randomState(self, seed: int | None):
        self._seed = seed
        # keep the clustering in sync if it exposes a seed of its own
        if hasattr(self.clustering, "seed"):
            self.clustering.seed = seed