Source code for qsprpred.models.early_stopping

"""Early stopping for training of models."""
from enum import Enum
from functools import wraps
from typing import Any, Callable

import numpy as np
import pandas as pd

from ..data.tables.qspr import QSPRDataset
from ..logs import logger
from ..utils.serialization import JSONSerializable


class EarlyStoppingMode(Enum):
    """Enum representing the type of early stopping to use.

    The truth value of a member tells whether early stopping is active:
    ``NOT_RECORDING`` and ``RECORDING`` are truthy, ``FIXED`` and ``OPTIMAL``
    are falsy.

    Attributes:
        NOT_RECORDING (str): early stopping, not recording number of epochs
        RECORDING (str): early stopping, recording number of epochs
        FIXED (str): no early stopping, specified number of epochs
        OPTIMAL (str): no early stopping, optimal number of epochs determined by
            previous training runs with early stopping (e.g. average number of
            epochs trained in cross validation with early stopping)
    """

    NOT_RECORDING = "NOT_RECORDING"
    RECORDING = "RECORDING"
    FIXED = "FIXED"
    OPTIMAL = "OPTIMAL"

    def __str__(self) -> str:
        """Return the name of the early stopping mode."""
        return self.name

    def __bool__(self) -> bool:
        """Return whether early stopping is used in this mode."""
        # Only the two "early stopping" modes are truthy; FIXED and OPTIMAL
        # train for a predetermined number of epochs.
        return self in (EarlyStoppingMode.NOT_RECORDING, EarlyStoppingMode.RECORDING)
class EarlyStopping(JSONSerializable):
    """Early stopping tracker for training of QSPRpred models.

    An instance of this class is used to track the number of epochs trained in a
    model when early stopping (mode RECORDING) is used. This information can then
    be used to determine the optimal number of epochs to train in a model training
    without early stopping (mode OPTIMAL). The optimal number of epochs is
    determined by aggregating the number of epochs trained in previous model
    trainings with early stopping. The aggregation function can be specified by the
    user. The number of epochs to train in a model training without early stopping
    can also be specified manually (mode FIXED). Models can also be trained with
    early stopping without recording the number of epochs trained (mode
    NOT_RECORDING), e.g. useful when hyperparameter tuning is performed with early
    stopping.

    Attributes:
        mode (EarlyStoppingMode): early stopping mode
        numEpochs (int): number of epochs to train in FIXED mode.
        aggregateFunc (function): numpy function to aggregate trained epochs in
            OPTIMAL mode. Defaults to np.mean.
        trainedEpochs (list[int]): list of number of epochs trained in a model
            training with early stopping on RECORDING mode.
    """

    def __init__(
        self,
        mode: EarlyStoppingMode = EarlyStoppingMode.NOT_RECORDING,
        num_epochs: int | None = None,
        aggregate_func: Callable[[list[int]], int] = np.mean,
    ):
        """Initialize early stopping.

        Args:
            mode (EarlyStoppingMode): early stopping mode
            num_epochs (int, optional): number of epochs to train in FIXED mode.
            aggregate_func (function, optional): numpy function to aggregate
                trained epochs in OPTIMAL mode. Note, non-numpy functions are not
                supported (the function is serialized by name and looked up on
                the ``numpy`` module when the state is restored).
        """
        self.mode = mode
        self.numEpochs = num_epochs
        self._trainedEpochs = []
        self.aggregateFunc = aggregate_func

    def __getstate__(self):
        """Return the parent's state with ``aggregateFunc`` stored by name."""
        state = super().__getstate__()
        # The function object is replaced by its ``__name__`` in the state.
        state["aggregateFunc"] = self.aggregateFunc.__name__
        return state

    def __setstate__(self, state):
        """Restore the state and resolve ``aggregateFunc`` back to a function."""
        super().__setstate__(state)
        # After the parent call ``aggregateFunc`` is expected to hold the name
        # written by ``__getstate__``; resolve it on the numpy module.
        self.aggregateFunc = getattr(np, self.aggregateFunc)

    @property
    def optimalEpochs(self) -> int:
        """Return number of epochs to train in OPTIMAL mode.

        The recorded epochs are aggregated with ``aggregateFunc`` and the result
        is rounded to the nearest integer.

        Raises:
            ValueError: if no epochs have been recorded yet.
        """
        if len(self._trainedEpochs) == 0:
            raise ValueError(
                "No number of epochs have been recorded yet, first run fit with early "
                "stopping mode set to RECORDING or set the optimal number of epochs "
                "manually."
            )
        optimal_epochs = int(np.round(self.aggregateFunc(self._trainedEpochs)))
        logger.debug(f"Optimal number of epochs: {optimal_epochs}")
        return optimal_epochs

    @property
    def trainedEpochs(self) -> list[int]:
        """Return a copy of the epochs recorded in early stopping mode RECORDING."""
        return self._trainedEpochs.copy()

    @trainedEpochs.setter
    def trainedEpochs(self, epochs: list[int]):
        """Set list of number of epochs trained in early stopping mode RECORDING.

        Args:
            epochs (list[int]): list of number of epochs
        """
        # Store a private list copy: the getter hands out copies as well, and
        # `recordEpochs` appends to this list, which must not mutate (or depend
        # on the type of) the object passed in by the caller.
        self._trainedEpochs = list(epochs)

    def recordEpochs(self, epochs: int):
        """Record number of epochs.

        Args:
            epochs (int): number of epochs
        """
        logger.debug(f"Recorded best epoch: {epochs}")
        self._trainedEpochs.append(epochs)

    def getEpochs(self) -> int:
        """Get the number of epochs to train in a non-early stopping mode.

        Returns:
            int: ``numEpochs`` when the mode is FIXED (this is ``None`` if no
                ``num_epochs`` was given at construction), otherwise
                ``optimalEpochs``.

        Raises:
            ValueError: in any non-FIXED mode when no epochs have been recorded.
        """
        if self.mode == EarlyStoppingMode.FIXED:
            return self.numEpochs
        return self.optimalEpochs

    def __str__(self) -> str:
        """Return the name of the current early stopping mode."""
        return self.mode.name

    def __bool__(self) -> bool:
        """Return whether early stopping is used."""
        return bool(self.mode)

    def clean(self):
        """Clean early stopping object by discarding all recorded epochs."""
        self._trainedEpochs = []
def early_stopping(func: Callable) -> Callable:
    """Early stopping decorator for fit method of models that support early stopping.

    The decorated method must return a tuple ``(estimator, best_epoch)``; the
    wrapper returns only the estimator and, in RECORDING mode, records
    ``best_epoch + 1`` on the model's ``earlyStopping`` tracker.

    Args:
        func (Callable): fit method to decorate

    Returns:
        function: decorated fit method
    """

    # `wraps` keeps the decorated method's name, docstring and other metadata
    # so that introspection and generated documentation show the real method
    # instead of `wrapper_fit`.
    @wraps(func)
    def wrapper_fit(
        self,
        X: pd.DataFrame | np.ndarray | QSPRDataset,
        y: pd.DataFrame | np.ndarray | QSPRDataset,
        estimator: Any | None = None,
        mode: EarlyStoppingMode | None = None,
        split: "DataSplit" = None,
        monitor: "FitMonitor" = None,
        **kwargs,
    ) -> Any:
        """Wrapper for fit method of models that support early stopping.

        Args:
            X (pd.DataFrame, np.ndarray, QSPRDataset): data matrix to fit
            y (pd.DataFrame, np.ndarray, QSPRDataset): target matrix to fit
            estimator (Any): estimator instance to use for fitting
            mode (EarlyStoppingMode): early stopping mode, if None the mode
                already set on ``self.earlyStopping`` is kept
            split (DataSplit): data split to use for early stopping, if None, a
                ShuffleSplit with 10% validation set size is used
            monitor (FitMonitor): monitor to use for fitting, if None, a
                BaseMonitor is used
            kwargs (dict): additional keyword arguments for the estimator's fit
                method

        Returns:
            Any: fitted estimator instance
        """
        assert self.supportsEarlyStopping, (
            "early_stopping decorator can only be used for models that support"
            " early stopping."
        )
        # An explicitly passed mode overrides the one stored on the tracker.
        if mode is not None:
            self.earlyStopping.mode = mode
        estimator, best_epoch = func(
            self, X, y, estimator, mode, split, monitor, **kwargs
        )
        if self.earlyStopping.mode == EarlyStoppingMode.RECORDING:
            self.earlyStopping.recordEpochs(best_epoch + 1)  # +1 for 0-indexing
        return estimator

    return wrapper_fit