Source code for qsprpred.tasks

from enum import Enum
from typing import Literal, Optional, Callable

from qsprpred.utils.serialization import (
    JSONSerializable,
    function_as_string,
    function_from_string,
)


class TargetTasks(Enum):
    """Kinds of task a model can perform for one individual target property."""

    REGRESSION = "REGRESSION"
    SINGLECLASS = "SINGLECLASS"
    MULTICLASS = "MULTICLASS"

    def isClassification(self):
        """Return True when this member is SINGLECLASS or MULTICLASS."""
        return self is TargetTasks.SINGLECLASS or self is TargetTasks.MULTICLASS

    def isRegression(self):
        """Return True when this member is REGRESSION."""
        return self is TargetTasks.REGRESSION

    def __str__(self):
        """Use the bare member name (e.g. ``REGRESSION``) as the string form."""
        return self.name
class ModelTasks(Enum):
    """Kinds of task a model performs across all of its target properties."""

    REGRESSION = "REGRESSION"
    SINGLECLASS = "SINGLECLASS"
    MULTICLASS = "MULTICLASS"
    MULTITASK_REGRESSION = "MULTITASK_REGRESSION"
    MULTITASK_SINGLECLASS = "MULTITASK_SINGLECLASS"
    MULTITASK_MULTICLASS = "MULTITASK_MULTICLASS"
    MULTITASK_MIXED = "MULTITASK_MIXED"

    def isClassification(self):
        """Return True for the single- and multi-task classification members.

        ``MULTITASK_MIXED`` is not counted as classification.
        """
        classification_members = (
            ModelTasks.SINGLECLASS,
            ModelTasks.MULTICLASS,
            ModelTasks.MULTITASK_SINGLECLASS,
            ModelTasks.MULTITASK_MULTICLASS,
        )
        return self in classification_members

    def isRegression(self):
        """Return True for ``REGRESSION`` and ``MULTITASK_REGRESSION``.

        ``MULTITASK_MIXED`` is not counted as regression.
        """
        return self is ModelTasks.REGRESSION or self is ModelTasks.MULTITASK_REGRESSION

    def isMixed(self):
        """Return True only for ``MULTITASK_MIXED``."""
        return self is ModelTasks.MULTITASK_MIXED

    def isMultiTask(self):
        """Return True for every ``MULTITASK_*`` member."""
        multitask_members = (
            ModelTasks.MULTITASK_REGRESSION,
            ModelTasks.MULTITASK_SINGLECLASS,
            ModelTasks.MULTITASK_MULTICLASS,
            ModelTasks.MULTITASK_MIXED,
        )
        return self in multitask_members

    def __str__(self):
        """Use the bare member name as the string form."""
        return self.name

    @staticmethod
    def getModelTask(target_properties: list):
        """Derive the model-level task from a list of target properties.

        Each element must expose a ``task`` attribute. A single property maps
        to the member with the same name as its task. Several properties map
        to one of the ``MULTITASK_*`` members, depending on whether they are
        all regression, all single-class, all classification, or a mixture.

        Note that an empty list yields ``MULTITASK_REGRESSION``, because
        ``all()`` over an empty sequence is true.

        Args:
            target_properties (list): target properties to inspect

        Returns:
            ModelTasks: the task type covering all given target properties
        """
        if len(target_properties) == 1:
            sole_task = target_properties[0].task
            return ModelTasks[sole_task.name]

        if all(prop.task.isRegression() for prop in target_properties):
            return ModelTasks.MULTITASK_REGRESSION

        if not all(prop.task.isClassification() for prop in target_properties):
            # Neither purely regression nor purely classification.
            return ModelTasks.MULTITASK_MIXED

        every_single_class = all(
            prop.task == TargetTasks.SINGLECLASS for prop in target_properties
        )
        if every_single_class:
            return ModelTasks.MULTITASK_SINGLECLASS
        return ModelTasks.MULTITASK_MULTICLASS
class TargetProperty(JSONSerializable):
    """Target property for QSPRmodelling class.

    Attributes:
        name (str): name of the target property
        task (TargetTasks): task type for the target property
        th (list[float] | str): threshold(s) for the target property, or the
            string "precomputed"; only defined for classification tasks
        nClasses (int): number of classes for the target property; only
            defined for classification tasks
        transformer (Callable): function to transform the target property
        imputer (Callable): function to impute the target property
    """

    # `transformer` is listed here and written separately, as a string,
    # by `__getstate__` below.
    _notJSON = ["transformer", *JSONSerializable._notJSON]

    def __init__(
        self,
        name: str,
        task: Literal[
            TargetTasks.REGRESSION, TargetTasks.SINGLECLASS, TargetTasks.MULTICLASS
        ],
        th: Optional[list[float] | str] = None,
        n_classes: Optional[int] = None,
        transformer: Optional[Callable] = None,
        imputer: Optional[Callable] = None,
    ):
        """Initialize a TargetProperty object.

        For regression tasks `th` and `n_classes` are ignored and the `th` /
        `nClasses` properties are left unset.

        Args:
            name (str): name of the target property
            task (TargetTasks): task type for the target property
            th (list[float] | str): threshold for the target property, only
                used for classification tasks. If th is precomputed, set it to
                "precomputed". If th is precomputed, n_classes must be
                specified.
            n_classes (int): number of classes for the target property. Must
                be specified if th is precomputed, otherwise it is inferred
                from th.
            transformer (Callable): function to transform the target property
            imputer (Callable): function to impute the target property

        Raises:
            AssertionError: if `task` is a classification task and `th` is
                None, or if `th` is a string other than "precomputed".
        """
        self.name = name
        self.task = task
        if task.isClassification():
            assert th is not None, (
                f"Threshold not specified for classification task `{name}`. "
                "If the task is already precomputed, set `th` to `precomputed`, and "
                "define the correct number of classes with `n_classes`."
            )
            # The `th` setter validates the value and, for explicit
            # thresholds, derives the number of classes.
            self.th = th
            if isinstance(th, str) and th == "precomputed":
                self.nClasses = n_classes
        self.transformer = transformer
        self.imputer = imputer

    def __getstate__(self):
        """Return the parent state with the transformer stored as a string.

        A falsy transformer is stored as None.
        """
        o_dict = super().__getstate__()
        o_dict["transformer"] = (
            function_as_string(self.transformer) if self.transformer else None
        )
        return o_dict

    def __setstate__(self, state):
        """Restore the object and rebuild the transformer from its string form."""
        super().__setstate__(state)
        if state["transformer"] is not None:
            # NOTE(review): relies on the parent `__setstate__` having placed
            # the serialized string on `self.transformer` - confirm.
            self.transformer = function_from_string(self.transformer)

    @property
    def th(self):
        """Get the threshold for the target property.

        Returns:
            th (list[float] | str): threshold for the target property
        """
        return self._th

    @th.setter
    def th(self, th: list[float] | str):
        """Set the threshold for the target property and the number of classes
        if th is not precomputed.

        A single threshold gives 2 classes; a longer list gives
        ``len(th) - 1`` classes.

        Args:
            th (list[float] | str): threshold for the target property

        Raises:
            AssertionError: if the task is not a classification task, or if
                `th` is a string other than "precomputed".
        """
        assert (
            self.task.isClassification()
        ), "Threshold can only be set for classification tasks"
        self._th = th
        if isinstance(th, str):
            assert th == "precomputed", f"Invalid threshold {th}"
        else:
            self._nClasses = len(self.th) - 1 if len(self.th) > 1 else 2

    @th.deleter
    def th(self):
        """Delete the threshold for the target property and the number of classes."""
        del self._th
        del self._nClasses

    @property
    def nClasses(self):
        """Get the number of classes for the target property.

        Returns:
            nClasses (int): number of classes
        """
        return self._nClasses

    @nClasses.setter
    def nClasses(self, nClasses: int):
        """Set the number of classes for the target property if th is precomputed.

        Args:
            nClasses (int): number of classes

        Raises:
            AssertionError: if the threshold is not "precomputed".
        """
        assert (
            self.th == "precomputed"
        ), "Number of classes can only be set if threshold is precomputed"
        self._nClasses = nClasses

    def __repr__(self):
        """Representation of the TargetProperty object."""
        if self.task.isClassification() and self.th is not None:
            return f"TargetProperty(name={self.name}, task={self.task}, th={self.th})"
        else:
            return f"TargetProperty(name={self.name}, task={self.task})"

    def __str__(self):
        """Return string identifier of the TargetProperty object."""
        return self.name

    @classmethod
    def fromDict(cls, d: dict):
        """Create a TargetProperty object from a dictionary.

        task can be specified as a string (case-insensitive member name) or as
        a TargetTasks object. The input dictionary is not modified.

        Args:
            d (dict): dictionary containing the target property information

        Example:
            >>> TargetProperty.fromDict({"name": "property_name", "task": "regression"})
            TargetProperty(name=property_name, task=REGRESSION)

        Returns:
            TargetProperty: TargetProperty object (an instance of `cls`)
        """
        task = d["task"]
        if isinstance(task, str):
            d = {**d, "task": TargetTasks[task.upper()]}
        # Build through `cls` so subclasses get instances of their own type.
        return cls(**d)

    @classmethod
    def fromList(cls, _list: list[dict]):
        """Create a list of TargetProperty objects from a list of dictionaries.

        Args:
            _list (list): list of dictionaries containing the target property
                information

        Returns:
            list[TargetProperty]: list of TargetProperty objects
        """
        return [cls.fromDict(d) for d in _list]

    @staticmethod
    def toList(_list: list, task_as_str: bool = False, drop_transformer: bool = True):
        """Convert a list of TargetProperty objects to a list of dictionaries.

        Each dictionary holds "name" and "task". For classification tasks it
        also holds "th" and "n_classes". The imputer is never included.

        Args:
            _list (list): list of TargetProperty objects
            task_as_str (bool): whether to convert the task to a string
            drop_transformer (bool): if False, add the transformer to each
                dictionary under "transformer"

        Returns:
            list[dict]: list of dictionaries containing the target property
                information
        """
        target_props = []
        for target_prop in _list:
            entry = {
                "name": target_prop.name,
                "task": target_prop.task.name if task_as_str else target_prop.task,
            }
            if target_prop.task.isClassification():
                entry["th"] = target_prop.th
                entry["n_classes"] = target_prop.nClasses
            if not drop_transformer:
                entry["transformer"] = target_prop.transformer
            target_props.append(entry)
        return target_props

    @staticmethod
    def selectFromList(_list: list, names: list):
        """Select a subset of TargetProperty objects from a list of
        TargetProperty objects.

        The order of `_list` is preserved.

        Args:
            _list (list): list of TargetProperty objects
            names (list): list of names of the target properties to be selected

        Returns:
            list[TargetProperty]: list of TargetProperty objects
        """
        return [t for t in _list if t.name in names]

    @staticmethod
    def getNames(_list: list):
        """Get the names of the target properties from a list of
        TargetProperty objects.

        Args:
            _list (list): list of TargetProperty objects

        Returns:
            list[str]: list of names of the target properties
        """
        return [t.name for t in _list]