# Source code for qsprpred.benchmarks.settings.benchmark

from dataclasses import dataclass
from typing import ClassVar

from qsprpred.models.assessment.methods import ModelAssessor

from ...data.descriptors.sets import DescriptorSet
from ...data.sources.data_source import DataSource
from ...data.sampling.splits import DataSplit
from ...data.processing.pipeline import DatasetPipeline
from ...models.hyperparam_optimization import HyperparameterOptimization
from ...models.model import QSPRModel
from ...tasks import TargetSpec
from ...utils.serialization import JSONSerializable


@dataclass
class BenchmarkSettings(JSONSerializable):
    """Class that determines settings for a benchmarking run.

    Attributes:
        name (str):
            Name of the benchmarking run.
        n_replicas (int):
            Number of replicas to run.
        random_seed (int):
            Random seed to use.
        data_sources (list[DataSource]):
            Data sources to use.
        descriptors (list[list[DescriptorSet]]):
            Descriptor sets to use.
        target_props (list[list[TargetSpec]]):
            Target properties to use.
        pipelines (list[DatasetPipeline]):
            Data preparation settings to use.
        models (list[QSPRModel]):
            Models to use.
        assessors (list[ModelAssessor]):
            Model assessors to use.
        subsets (dict[str, tuple[DataSplit, str, int]]):
            Dictionary mapping assessor names to tuples of data split,
            set (Train/Test), and fold index. Used to apply assessors to
            subsets of the data.
        optimizers (list[HyperparameterOptimization]):
            Hyperparameter optimizers to use.
    """

    # "models" is excluded from the regular JSON handling; the models are
    # converted explicitly in `__getstate__` / `__setstate__` below.
    _notJSON: ClassVar = ["models"]

    name: str
    n_replicas: int
    random_seed: int
    data_sources: list[DataSource]
    descriptors: list[list[DescriptorSet]]
    target_props: list[list[TargetSpec]]
    pipelines: list[DatasetPipeline]
    models: list[QSPRModel]
    assessors: list[ModelAssessor]
    # NOTE(review): the empty tuple is used as the "nothing configured"
    # default (dataclasses reject bare mutable defaults); only `len()` is
    # applied to it before any dict-specific access in this class.
    subsets: dict[str, tuple[DataSplit, str, int]] = ()
    optimizers: list[HyperparameterOptimization] = ()

    def __getstate__(self):
        """Return the serializable state of the settings.

        The state produced by the parent class is extended with a
        ``"models"`` entry holding the result of ``model.save()`` for each
        model (presumably a reference to the saved model file, since
        `QSPRModel.fromFile` consumes it on restore).

        Returns:
            dict: The state dictionary including the ``"models"`` entry.
        """
        o_dict = super().__getstate__()
        o_dict["models"] = [model.save() for model in self.models]
        return o_dict

    def __setstate__(self, state):
        """Restore the settings from a previously saved state.

        Args:
            state (dict):
                State as produced by `__getstate__`; its ``"models"`` entry
                is passed item by item to `QSPRModel.fromFile`.
        """
        super().__setstate__(state)
        self.models = [QSPRModel.fromFile(model) for model in state["models"]]

    def checkConsistency(self):
        """Checks if the settings are consistent.

        The checks are raised explicitly instead of relying on ``assert``
        statements, so they are still enforced when Python runs with
        optimizations enabled (``-O``).

        Raises:
            AssertionError: If the settings are inconsistent.
        """
        # Every one of these collections must contain at least one item.
        required = (
            (self.data_sources, "No data sources defined."),
            (self.descriptors, "No descriptors defined."),
            (self.target_props, "No target properties defined."),
            (self.pipelines, "No data preparation settings defined."),
            (self.models, "No models defined."),
            (self.assessors, "No model assessors defined."),
        )
        for collection, message in required:
            if len(collection) == 0:
                raise AssertionError(message)

        assessor_names = [assessor.name for assessor in self.assessors]
        if len(self.subsets) == 0:
            # No subsets configured (this is also the case for the default).
            return

        for assessor, subset in self.subsets.items():
            if assessor not in assessor_names:
                raise AssertionError(
                    f"Assessor {assessor} in subsets not found in assessors."
                )
            if len(subset) != 3:
                raise AssertionError(
                    "Subsets must be a tuple of DataSplit, set (Train/Test), "
                    "and fold index."
                )
            if not isinstance(subset[0], DataSplit):
                raise AssertionError(
                    "First element of subset must be a DataSplit."
                )
            if subset[1] not in ["Train", "Test"]:
                raise AssertionError(
                    "Second element of subset must be 'Train' or 'Test'."
                )
            if not isinstance(subset[2], int):
                raise AssertionError(
                    "Third element of subset must be an integer, "
                    "the fold index."
                )