Source code for qsprpred.plotting.classification

"""Plotting functions for classification models."""
import os.path
import re
from abc import ABC
from copy import deepcopy
from typing import List, Literal

import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib import pyplot as plt
from sklearn.calibration import CalibrationDisplay
from sklearn.metrics import (
    PrecisionRecallDisplay,
    RocCurveDisplay,
    accuracy_score,
    auc,
    f1_score,
    matthews_corrcoef,
    precision_score,
    recall_score,
    roc_auc_score,
    confusion_matrix,
)

from ..models.assessment.metrics.classification import CalibrationError
from ..models.model import QSPRModel
from ..plotting.base_plot import ModelPlot
from ..tasks import ModelTasks


class ClassifierPlot(ModelPlot, ABC):
    """Base class for plots of classification models."""

    def getSupportedTasks(self) -> List[ModelTasks]:
        """Return a list of tasks supported by this plotter."""
        return [
            ModelTasks.SINGLECLASS,
            ModelTasks.MULTICLASS,
            ModelTasks.MULTITASK_SINGLECLASS,
            ModelTasks.MULTITASK_MULTICLASS,
        ]

    def prepareAssessment(self, assessment_df: pd.DataFrame) -> pd.DataFrame:
        """Prepare an assessment dataframe for plotting.

        Args:
            assessment_df (pd.DataFrame):
                the assessment dataframe containing the experimental and predicted
                values for each property. The dataframe should have the following
                columns: QSPRID, Fold (opt.),
                <property_name>_<suffixes>_<Label/Prediction/ProbabilityClass_X>

        Returns:
            pd.DataFrame:
                The dataframe containing the assessment results, with columns:
                QSPRID, Fold, Property, Label, Prediction, Class, Set
        """
        # change all property columns into one column
        id_vars = ["QSPRID", "Fold"] if "Fold" in assessment_df.columns else ["QSPRID"]
        df = assessment_df.melt(id_vars=id_vars)
        # split the variable column
        # (<property_name>_<suffixes>_<Label/Prediction/ProbabilityClass_X>)
        # into the property name and the type (Label, Prediction or ProbabilityClass_X)
        pattern = re.compile(
            r"^(?P<Property>.*?)_(?P<type>Label|Prediction|ProbabilityClass_\d+)$"
        )
        df[["Property", "type"]] = df["variable"].apply(
            lambda x: pd.Series(pattern.match(x).groupdict())
        )
        df.drop("variable", axis=1, inplace=True)
        # pivot the dataframe so that Label and Prediction are separate columns
        df = df.pivot_table(
            index=[*id_vars, "Property"], columns="type", values="value"
        )
        df.reset_index(inplace=True)
        df.columns.name = None
        df["Label"] = df["Label"].astype(int)
        df["Prediction"] = df["Prediction"].astype(int)
        # add the Fold column if it doesn't exist (for the independent test set)
        if "Fold" not in df.columns:
            df["Fold"] = "Independent Test"
            df["Set"] = "Independent Test"
        else:
            df["Set"] = "Cross Validation"
        return df

    def prepareClassificationResults(self) -> pd.DataFrame:
        """Prepare the classification results dataframe for plotting.

        Returns:
            pd.DataFrame:
                the dataframe containing the classification results, with columns:
                Model, QSPRID, Fold, Property, Label, Prediction, Set
        """
        model_results = {}
        for model in self.models:
            # read in and prepare the cross-validation and independent test set results
            df_cv = self.prepareAssessment(pd.read_table(self.cvPaths[model]))
            df_ind = self.prepareAssessment(pd.read_table(self.indPaths[model]))
            # concatenate the cross-validation and independent test set results
            df = pd.concat([df_cv, df_ind])
            model_results[model.name] = df
        # concatenate the results from all models and add the model name as a column
        df = (
            pd.concat(
                model_results.values(), keys=model_results.keys(), names=["Model"]
            )
            .reset_index(level=1, drop=True)
            .reset_index()
        )
        self.results = df
        return df

    def calculateMultiClassMetrics(self, df, average_type, n_classes):
        """Calculate multi-class metrics for a given dataframe."""
        # check if the ProbabilityClass_X columns exist
        proba = all([f"ProbabilityClass_{i}" in df.columns for i in range(n_classes)])
        if average_type == "All":
            metrics = {
                "accuracy": accuracy_score(df.Label, df.Prediction),
                "matthews_corrcoef": matthews_corrcoef(df.Label, df.Prediction),
            }
            if proba:
                # convert y_pred to a list of arrays with shape (n_samples, n_classes)
                y_pred = [
                    df[[f"ProbabilityClass_{i}" for i in range(n_classes)]].values
                ]
                metrics["calibration_error"] = CalibrationError()(
                    df.Label.values, y_pred
                )
            return pd.Series(metrics)
        # check if the average type is an int (i.e. a class number, so metrics can be
        # calculated for a single class)
        if isinstance(average_type, int):
            class_num = average_type
            average_type = None
        metrics = {
            "precision": precision_score(df.Label, df.Prediction, average=average_type),
            "recall": recall_score(df.Label, df.Prediction, average=average_type),
            "f1": f1_score(df.Label, df.Prediction, average=average_type),
        }
        if proba:
            metrics["roc_auc_ovr"] = roc_auc_score(
                df.Label,
                df[[f"ProbabilityClass_{i}" for i in range(n_classes)]],
                multi_class="ovr",
                average=average_type,
            )
        # FIXME: metrics are only returned for class "class_num", but calculated for
        # all classes, as returning a list of metrics for each class gives a dataframe
        # with lists as values, which is difficult to explode. Need to find a better
        # way to do this.
        if average_type is None:
            metrics = {k: v[class_num] for k, v in metrics.items()}
        # conditionally include roc_auc_ovo for non-micro averages
        if average_type != "micro" and average_type is not None and proba:
            metrics["roc_auc_ovo"] = roc_auc_score(
                df.Label,
                df[[f"ProbabilityClass_{i}" for i in range(n_classes)]],
                multi_class="ovo",
                average=average_type,
            )
        return pd.Series(metrics)

    def calculateSingleClassMetrics(self, df):
        """Calculate single-class metrics for a given dataframe."""
        # check if the ProbabilityClass_1 column exists
        proba = "ProbabilityClass_1" in df.columns
        metrics = {
            "accuracy": accuracy_score(df.Label, df.Prediction),
            "precision": precision_score(df.Label, df.Prediction),
            "recall": recall_score(df.Label, df.Prediction),
            "f1": f1_score(df.Label, df.Prediction),
            "matthews_corrcoef": matthews_corrcoef(df.Label, df.Prediction),
        }
        if proba:
            # convert y_pred to a list of arrays with shape (n_samples, 2)
            y_pred = [
                np.column_stack([1 - df.ProbabilityClass_1, df.ProbabilityClass_1])
            ]
            metrics["calibration_error"] = CalibrationError()(df.Label.values, y_pred)
            metrics["roc_auc"] = roc_auc_score(df.Label, df.ProbabilityClass_1)
        return pd.Series(metrics)

    def getSummary(self):
        """Get summary statistics for the classification results."""
        if not hasattr(self, "results"):
            self.prepareClassificationResults()
        df = deepcopy(self.results)
        summary_list = {}
        # make a summary for each model and property
        for model_name in df.Model.unique():
            for property_name in df.Property.unique():
                df_subset = df[
                    (df.Model == model_name) & (df.Property == property_name)
                ]
                # get the number of classes
                n_classes = df_subset["Label"].nunique()
                # calculate metrics for binary and multi-class properties
                if n_classes == 2:
                    summary_list[f"{model_name}_{property_name}_Binary"] = (
                        df_subset.groupby(["Model", "Fold", "Property"]).apply(
                            lambda x: self.calculateSingleClassMetrics(x)
                        )
                    ).reset_index()
                    summary_list[f"{model_name}_{property_name}_Binary"][
                        "Class"
                    ] = "Binary"
                else:
                    # calculate metrics for each class, average type and
                    # non-average type metrics
                    class_list = [
                        *["macro", "micro", "weighted", "All"],
                        *list(range(n_classes)),
                    ]
                    for class_type in class_list:
                        summary_list[f"{model_name}_{property_name}_{class_type}"] = (
                            df_subset.groupby(["Model", "Fold", "Property"]).apply(
                                lambda x: self.calculateMultiClassMetrics(
                                    x, class_type, n_classes
                                )
                            )
                        ).reset_index()
                        summary_list[f"{model_name}_{property_name}_{class_type}"][
                            "Class"
                        ] = class_type
        df_summary = pd.concat(summary_list.values(), ignore_index=True)
        df_summary["Set"] = df_summary["Fold"].apply(
            lambda x: "Independent Test"
            if x == "Independent Test"
            else "Cross Validation"
        )
        self.summary = df_summary
        return df_summary
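

# A minimal sketch of the tab-separated assessment table layout that
# ``ClassifierPlot.prepareAssessment`` expects: one ``<property>_Label``,
# ``<property>_Prediction`` and optional ``<property>_ProbabilityClass_<i>``
# column per target property, plus ``QSPRID`` and (for cross-validation results)
# ``Fold``. The helper name and the property name "CLS" below are purely
# illustrative assumptions.
def _example_assessment_table() -> pd.DataFrame:
    # two folds of a binary classification property named "CLS"
    return pd.DataFrame(
        {
            "QSPRID": ["mol_0", "mol_1", "mol_2", "mol_3"],
            "Fold": [0, 0, 1, 1],
            "CLS_Label": [0, 1, 1, 0],
            "CLS_Prediction": [0, 1, 0, 0],
            "CLS_ProbabilityClass_0": [0.8, 0.3, 0.6, 0.7],
            "CLS_ProbabilityClass_1": [0.2, 0.7, 0.4, 0.3],
        }
    )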


class ROCPlot(ClassifierPlot):
    """Plot of the ROC curve (receiver operating characteristic curve)
    for a given classification model.
    """

    def getSupportedTasks(self) -> List[ModelTasks]:
        """Return a list of tasks supported by this plotter."""
        return [ModelTasks.SINGLECLASS, ModelTasks.MULTITASK_SINGLECLASS]

    def makeCV(self, model: QSPRModel, property_name: str) -> plt.Axes:
        """Make the plot for a given model using cross-validation data.

        Many thanks to the scikit-learn documentation since the code below
        borrows heavily from the example at:
        https://scikit-learn.org/stable/auto_examples/model_selection/plot_roc_crossval.html

        Args:
            model (QSPRModel): the model to plot the data from.
            property_name (str): name of the property to plot (should correspond
                to the prefix of the column names in the data files).

        Returns:
            ax (matplotlib.axes.Axes): the axes object containing the plot.
        """
        df = pd.read_table(self.cvPaths[model])
        # get the true positive rate and false positive rate for each fold
        tprs = []
        aucs = []
        mean_fpr = np.linspace(0, 1, 100)
        ax = plt.gca()
        for fold in df.Fold.unique():
            # get labels
            y_pred = df[f"{property_name}_ProbabilityClass_1"][df.Fold == fold]
            y_true = df[f"{property_name}_Label"][df.Fold == fold]
            # do plotting
            viz = RocCurveDisplay.from_predictions(
                y_true,
                y_pred,
                name="ROC fold {}".format(fold + 1),
                alpha=0.3,
                lw=1,
                ax=ax,
            )
            interp_tpr = np.interp(mean_fpr, viz.fpr, viz.tpr)
            interp_tpr[0] = 0.0
            tprs.append(interp_tpr)
            aucs.append(viz.roc_auc)
        # plot the chance line
        ax.plot(
            [0, 1], [0, 1], linestyle="--", lw=2, color="r", label="Chance", alpha=0.8
        )
        # plot the mean ROC across folds
        mean_tpr = np.mean(tprs, axis=0)
        mean_tpr[-1] = 1.0
        mean_auc = auc(mean_fpr, mean_tpr)
        std_auc = np.std(aucs)
        ax.plot(
            mean_fpr,
            mean_tpr,
            color="b",
            label=r"Mean ROC (AUC = %0.2f $\pm$ %0.2f)" % (mean_auc, std_auc),
            lw=2,
            alpha=0.8,
        )
        # plot the standard deviation across folds
        std_tpr = np.std(tprs, axis=0)
        tprs_upper = np.minimum(mean_tpr + std_tpr, 1)
        tprs_lower = np.maximum(mean_tpr - std_tpr, 0)
        ax.fill_between(
            mean_fpr,
            tprs_lower,
            tprs_upper,
            color="grey",
            alpha=0.2,
            label=r"$\pm$ 1 std. dev.",
        )
        # set axes limits and labels
        ax.set(
            xlim=[-0.05, 1.05],
            ylim=[-0.05, 1.05],
            title=f"Receiver Operating Characteristic ({self.modelNames[model]})",
        )
        ax.legend(loc="lower right")
        return ax

    def makeInd(self, model: QSPRModel, property_name: str) -> plt.Axes:
        """Make the ROC plot for a given model using independent test data.

        Args:
            model (QSPRModel): the model to plot the data from.
            property_name (str): name of the property to plot (should correspond
                to the prefix of the column names in the data files).

        Returns:
            ax (matplotlib.axes.Axes): the axes object containing the plot.
        """
        df = pd.read_table(self.indPaths[model])
        y_pred = df[f"{property_name}_ProbabilityClass_1"]
        y_true = df[f"{property_name}_Label"]
        ax = plt.gca()
        RocCurveDisplay.from_predictions(
            y_true,
            y_pred,
            name="ROC",
            ax=ax,
        )
        ax.plot([0, 1], [0, 1], linestyle="--", lw=2, color="r", label="Chance")
        ax.set(
            xlim=[-0.05, 1.05],
            ylim=[-0.05, 1.05],
            title=f"Receiver Operating Characteristic ({self.modelNames[model]})",
        )
        ax.legend(loc="lower right")
        return ax

    def make(
        self,
        save: bool = True,
        show: bool = False,
        property_name: str | None = None,
        validation: str = "cv",
        fig_size: tuple = (6, 6),
    ) -> list[plt.Axes]:
        """Make the ROC plot for the given validation sets.

        Args:
            save (bool): Whether to save the plot to a file.
            show (bool): Whether to display the plot.
            property_name (str): name of the predicted property to plot (should
                correspond to the prefix of the column names in the `cvPaths` or
                `indPaths` files). If `None`, the first property in the model's
                `targetProperties` list is used.
            validation (str): The type of validation set to read data for. Can be
                either 'cv' for cross-validation or 'ind' for the independent
                test set.
            fig_size (tuple): The size of the figure to create.

        Returns:
            axes (list[plt.Axes]): A list of matplotlib axes objects containing
                the plots.
        """
        if property_name is None:
            property_name = self.models[0].targetProperties[0].name
        # fetch the correct plotting function based on the validation set type
        # and make the plot for each model
        choices = {"cv": self.makeCV, "ind": self.makeInd}
        axes = []
        for model in self.models:
            fig, ax = plt.subplots(figsize=fig_size)
            ax = choices[validation](model, property_name)
            axes.append(ax)
            if save:
                fig.savefig(f"{self.modelOuts[model]}.{validation}.png")
            if show:
                plt.show()
                plt.clf()
        return axes
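

# A minimal usage sketch for ``ROCPlot``: the helper below is hypothetical and
# assumes ``model`` is a fitted single-class ``QSPRModel`` whose cross-validation
# and independent test predictions have already been written to the files
# referenced by ``cvPaths``/``indPaths`` during model assessment.
def _example_roc_plot(model: QSPRModel) -> list[plt.Axes]:
    plot = ROCPlot([model])
    # per-fold and mean ROC curves from the cross-validation predictions
    cv_axes = plot.make(validation="cv", save=True, show=False)
    # ROC curve on the independent test set
    ind_axes = plot.make(validation="ind", save=True, show=False)
    return cv_axes + ind_axes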


class PRCPlot(ClassifierPlot):
    """Plot of the precision-recall curve for a given model."""

    def getSupportedTasks(self) -> List[ModelTasks]:
        """Return a list of tasks supported by this plotter."""
        return [ModelTasks.SINGLECLASS, ModelTasks.MULTITASK_SINGLECLASS]

    def makeCV(self, model: QSPRModel, property_name: str) -> plt.Axes:
        """Make the plot for a given model using cross-validation data.

        Args:
            model (QSPRModel): the model to plot the data from.
            property_name (str): name of the property to plot (should correspond
                to the prefix of the column names in the data files).

        Returns:
            ax (matplotlib.axes.Axes): the axes object containing the plot.
        """
        # read data from file for each fold
        df = pd.read_table(self.cvPaths[model])
        y_real = []
        y_predproba = []
        ax = plt.gca()
        for fold in df.Fold.unique():
            # get labels
            y_pred = df[f"{property_name}_ProbabilityClass_1"][df.Fold == fold]
            y_true = df[f"{property_name}_Label"][df.Fold == fold]
            y_predproba.append(y_pred)
            y_real.append(y_true)
            # do plotting
            PrecisionRecallDisplay.from_predictions(
                y_true,
                y_pred,
                name="PRC fold {}".format(fold + 1),
                ax=ax,
                alpha=0.3,
                lw=1,
            )
        # Linear interpolation of the PR curve is not recommended, so no "chance"
        # line is plotted: https://dl.acm.org/doi/10.1145/1143844.1143874
        # plot the average precision-recall curve over the cross-validation runs
        y_real = np.concatenate(y_real)
        y_predproba = np.concatenate(y_predproba)
        PrecisionRecallDisplay.from_predictions(
            y_real,
            y_predproba,
            name="Mean PRC",
            color="b",
            ax=ax,
            lw=1.2,
            alpha=0.8,
        )
        ax.set(
            xlim=[-0.05, 1.05],
            ylim=[-0.05, 1.05],
            title=f"Precision-Recall Curve ({self.modelNames[model]})",
        )
        ax.legend(loc="best")
        return ax

    def makeInd(self, model: QSPRModel, property_name: str) -> plt.Axes:
        """Make the plot for a given model using independent test data.

        Args:
            model (QSPRModel): the model to plot the data from.
            property_name (str): name of the property to plot (should correspond
                to the prefix of the column names in the data files).

        Returns:
            ax (matplotlib.axes.Axes): the axes object containing the plot.
        """
        # read data from file
        df = pd.read_table(self.indPaths[model])
        y_pred = df[f"{property_name}_ProbabilityClass_1"]
        y_true = df[f"{property_name}_Label"]
        # do plotting
        ax = plt.gca()
        PrecisionRecallDisplay.from_predictions(
            y_true,
            y_pred,
            name="PRC",
            ax=ax,
        )
        ax.set(
            xlim=[-0.05, 1.05],
            ylim=[-0.05, 1.05],
            title=f"Precision-Recall Curve ({self.modelNames[model]})",
        )
        ax.legend(loc="best")
        return ax

    def make(
        self,
        save: bool = True,
        show: bool = False,
        property_name: str | None = None,
        validation: str = "cv",
        fig_size: tuple = (6, 6),
    ):
        """Make the plot for a given validation type.

        Args:
            save (bool): Whether to save the plot to a file.
            show (bool): Whether to display the plot.
            property_name (str): name of the property to plot (should correspond
                to the prefix of the column names in the data files). If `None`,
                the first property in the model's `targetProperties` list is used.
            validation (str): The type of validation data to use. Can be either
                'cv' for cross-validation or 'ind' for the independent test set.
            fig_size (tuple): The size of the figure to create.

        Returns:
            axes (list): A list of matplotlib axes objects containing the plots.
        """
        if property_name is None:
            property_name = self.models[0].targetProperties[0].name
        choices = {"cv": self.makeCV, "ind": self.makeInd}
        axes = []
        for model in self.models:
            fig, ax = plt.subplots(figsize=fig_size)
            ax = choices[validation](model, property_name)
            axes.append(ax)
            if save:
                fig.savefig(f"{self.modelOuts[model]}.{validation}.png")
            if show:
                plt.show()
                plt.clf()
        return axes
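

# A minimal usage sketch for ``PRCPlot``; the helper name is hypothetical and
# ``models`` is assumed to be a list of fitted single-class ``QSPRModel``
# instances with saved assessment results.
def _example_prc_plot(models: list[QSPRModel]) -> list[plt.Axes]:
    plot = PRCPlot(models)
    # precision-recall curves on the independent test set, one figure per model
    return plot.make(validation="ind", save=True, show=False)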


class CalibrationPlot(ClassifierPlot):
    """Plot of the calibration curve for a given model."""

    def getSupportedTasks(self) -> List[ModelTasks]:
        """Return a list of tasks supported by this plotter."""
        return [ModelTasks.SINGLECLASS, ModelTasks.MULTITASK_SINGLECLASS]

    def makeCV(
        self, model: QSPRModel, property_name: str, n_bins: int = 10
    ) -> plt.Axes:
        """Make the plot for a given model using cross-validation data.

        Args:
            model (QSPRModel): the model to plot the data from.
            property_name (str): name of the property to plot (should correspond
                to the prefix of the column names in the data files).
            n_bins (int): The number of bins to use for the calibration curve.

        Returns:
            ax (matplotlib.axes.Axes): the axes object containing the plot.
        """
        # read data from file for each fold and plot
        df = pd.read_table(self.cvPaths[model])
        y_real = []
        y_pred_proba = []
        ax = plt.gca()
        for fold in df.Fold.unique():
            # get labels
            y_pred = df[f"{property_name}_ProbabilityClass_1"][df.Fold == fold]
            y_true = df[f"{property_name}_Label"][df.Fold == fold]
            y_pred_proba.append(y_pred)
            y_real.append(y_true)
            # do plotting
            CalibrationDisplay.from_predictions(
                y_true,
                y_pred,
                n_bins=n_bins,
                name="Fold: {}".format(fold + 1),
                ax=ax,
                alpha=0.3,
                lw=1,
            )
        # plot the average calibration curve over the cross-validation runs
        y_real = np.concatenate(y_real)
        y_pred_proba = np.concatenate(y_pred_proba)
        CalibrationDisplay.from_predictions(
            y_real,
            y_pred_proba,
            n_bins=n_bins,
            name="Mean",
            color="b",
            ax=ax,
            lw=1.2,
            alpha=0.8,
        )
        ax.set(
            xlim=[-0.05, 1.05],
            ylim=[-0.05, 1.05],
            title=f"Calibration Curve ({self.modelNames[model]})",
        )
        ax.legend(loc="best")
        return ax

    def makeInd(
        self, model: QSPRModel, property_name: str, n_bins: int = 10
    ) -> plt.Axes:
        """Make the plot for a given model using independent test data.

        Args:
            model (QSPRModel): the model to plot the data from.
            property_name (str): name of the property to plot (should correspond
                to the prefix of the column names in the data files).
            n_bins (int): The number of bins to use for the calibration curve.

        Returns:
            ax (matplotlib.axes.Axes): the axes object containing the plot.
        """
        df = pd.read_table(self.indPaths[model])
        y_pred = df[f"{property_name}_ProbabilityClass_1"]
        y_true = df[f"{property_name}_Label"]
        ax = plt.gca()
        CalibrationDisplay.from_predictions(
            y_true,
            y_pred,
            n_bins=n_bins,
            name="Calibration",
            ax=ax,
        )
        ax.set(
            xlim=[-0.05, 1.05],
            ylim=[-0.05, 1.05],
            title=f"Calibration Curve ({self.modelNames[model]})",
        )
        ax.legend(loc="best")
        return ax

    def make(
        self,
        save: bool = True,
        show: bool = False,
        property_name: str | None = None,
        validation: str = "cv",
        fig_size: tuple = (6, 6),
    ) -> list[plt.Axes]:
        """Make the plot for a given validation type.

        Args:
            save (bool): Whether to save the plot to a file.
            show (bool): Whether to display the plot.
            property_name (str): name of the property to plot (should correspond
                to the prefix of the column names in the data files). If `None`,
                the first property in the model's `targetProperties` list is used.
            validation (str): The type of validation data to use. Can be either
                'cv' for cross-validation or 'ind' for the independent test set.
            fig_size (tuple): The size of the figure to create.

        Returns:
            axes (list[plt.Axes]): A list of matplotlib axes objects containing
                the plots.
        """
        if property_name is None:
            property_name = self.models[0].targetProperties[0].name
        choices = {"cv": self.makeCV, "ind": self.makeInd}
        axes = []
        for model in self.models:
            fig, ax = plt.subplots(figsize=fig_size)
            # use the default number of calibration bins (n_bins=10)
            ax = choices[validation](model, property_name)
            axes.append(ax)
            if save:
                fig.savefig(f"{self.modelOuts[model]}.{validation}.png")
            if show:
                plt.show()
                plt.clf()
        return axes
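

# A minimal usage sketch for ``CalibrationPlot``; the helper name is hypothetical
# and ``model`` is assumed to be a fitted single-class ``QSPRModel`` with saved
# cross-validation predictions. ``make`` uses the default of 10 calibration bins;
# call ``makeCV``/``makeInd`` directly to control ``n_bins``.
def _example_calibration_plot(model: QSPRModel) -> list[plt.Axes]:
    plot = CalibrationPlot([model])
    return plot.make(validation="cv", save=True, show=False)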


class MetricsPlot(ClassifierPlot):
    """Plot of metrics for a given model.

    Attributes:
        models (list): A list of QSPRModel objects to plot the data from.
        metrics (list): A list of metrics to plot, chosen from: f1,
            matthews_corrcoef, precision, recall, accuracy, roc_auc, roc_auc_ovr,
            roc_auc_ovo and calibration_error
    """

    def __init__(
        self,
        models: List[QSPRModel],
        metrics: List[
            Literal[
                "f1",
                "matthews_corrcoef",
                "precision",
                "recall",
                "accuracy",
                "calibration_error",
                "roc_auc",
                "roc_auc_ovr",
                "roc_auc_ovo",
            ]
        ] = [
            "f1",
            "matthews_corrcoef",
            "precision",
            "recall",
            "accuracy",
            "calibration_error",
            "roc_auc",
            "roc_auc_ovr",
            "roc_auc_ovo",
        ],
    ):
        """Initialise the metrics plot.

        Args:
            models (list): A list of QSPRModel objects to plot the data from.
            metrics (list): A list of metrics to plot.
        """
        super().__init__(models)
        self.metrics = metrics

    def make(
        self,
        save: bool = True,
        show: bool = False,
        out_path: str | None = None,
    ) -> tuple[List[sns.FacetGrid], pd.DataFrame]:
        """Make the metrics plots.

        Args:
            save (bool): Whether to save the plot to a file.
            show (bool): Whether to display the plot.
            out_path (str | None): Path to save the plots to, e.g.
                "results/plot.png"; each plot will be saved to this path with the
                metric name appended before the extension, e.g.
                "results/plot_roc_auc.png". If `None`, the plots will be saved to
                each model's output directory.

        Returns:
            figures (list[sns.FacetGrid]): the seaborn FacetGrid objects used to
                make the plots
            pd.DataFrame: A dataframe containing the summary data generated.
        """
        # get the summary with the calculated metrics
        if not hasattr(self, "summary"):
            self.getSummary()
        figures = []
        for metric in self.metrics:
            # check if the metric is in the summary dataframe
            if metric not in self.summary.columns:
                print(f"Metric {metric} not in summary dataframe, skipping")
                continue
            # plot the results
            g = sns.catplot(
                self.summary,
                x="Class",
                y=metric,
                hue="Set",
                col="Property",
                row="Model",
                kind="bar",
                margin_titles=True,
                sharex=False,
                sharey=False,
            )
            # set the y-axis maximum to 1 for each plot, but don't set the minimum
            # to 0 as some metrics can be negative
            for ax in g.axes_dict.values():
                y_min, y_max = ax.get_ylim()
                ax.set_ylim(y_min, 1)
            figures.append(g)
            if save:
                # add the metric to out_path before the extension, if a path is given
                if out_path is not None:
                    plt.savefig(
                        os.path.splitext(out_path)[0] + f"_{metric}.png", dpi=300
                    )
                else:
                    for model in self.models:
                        plt.savefig(f"{model.outPrefix}_{metric}.png", dpi=300)
            if show:
                plt.show()
                plt.clf()
        return figures, self.summary
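

# A minimal usage sketch for ``MetricsPlot``, restricted to a few metrics; the
# helper name is hypothetical and ``models`` is assumed to be a list of fitted
# classification ``QSPRModel`` instances with saved assessment results.
def _example_metrics_plot(models: list[QSPRModel]) -> pd.DataFrame:
    plot = MetricsPlot(models, metrics=["f1", "matthews_corrcoef", "roc_auc"])
    # one seaborn catplot per metric; the summary dataframe is returned as well
    figures, summary = plot.make(save=True, show=False)
    return summary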


class ConfusionMatrixPlot(ClassifierPlot):
    """Plot of the confusion matrix for a given model as a heatmap."""

    def getConfusionMatrixDict(self, df: pd.DataFrame) -> dict:
        """Create a dictionary of confusion matrices for each model, property and fold.

        Args:
            df (pd.DataFrame): the dataframe containing the classification results,
                with columns: Model, QSPRID, Fold, Property, Label, Prediction, Set

        Returns:
            dict: dictionary of confusion matrices for each model, property and fold
        """
        conf_dict = {}
        for model in df.Model.unique():
            for property in df.Property.unique():
                for fold in df.Fold.unique():
                    df_subset = df[
                        (df.Model == model)
                        & (df.Property == property)
                        & (df.Fold == fold)
                    ]
                    conf_dict[(model, property, fold)] = confusion_matrix(
                        df_subset.Label, df_subset.Prediction
                    )
        return conf_dict

    def make(
        self,
        save: bool = True,
        show: bool = False,
        out_path: str | None = None,
    ) -> tuple[list[plt.Axes], dict]:
        """Make a confusion matrix heatmap for each model, property and fold.

        Args:
            save (bool): whether to save the plot
            show (bool): whether to show the plot
            out_path (str | None): path to save the plots to, e.g.
                "results/plot.png"; the plots will be saved to this path with the
                plot identifier appended before the extension. If `None`, the plots
                will be saved to each model's output directory.

        Returns:
            list[plt.Axes]: a list of matplotlib axes objects containing the plots
            dict: dictionary of confusion matrices for each model, property and fold
        """
        df = self.prepareClassificationResults()
        # get the dictionary of confusion matrices
        conf_dict = self.getConfusionMatrixDict(df)
        # create a heatmap for each model, property and fold
        axes = []
        for model in df.Model.unique():
            for property in df.Property.unique():
                for fold in df.Fold.unique():
                    fig, ax = plt.subplots()
                    sns.heatmap(
                        conf_dict[(model, property, fold)],
                        annot=True,
                        fmt="g",
                        cmap="Blues",
                        ax=ax,
                    )
                    ax.set_title(f"Confusion Matrix ({model}_{property}_fold_{fold})")
                    ax.set_xlabel("Predicted label")
                    ax.set_ylabel("True label")
                    axes.append(ax)
                    if save:
                        if out_path is not None:
                            # add the identifier to out_path before the extension
                            plt.savefig(
                                os.path.splitext(out_path)[0] +
                                f"_{model}_{property}_{fold}_confusion_matrix.png",
                                dpi=300,
                            )
                        else:
                            # reverse the self.modelNames dictionary to get
                            # the model output prefix
                            modelNames = {v: k for k, v in self.modelNames.items()}
                            plt.savefig(
                                f"{self.modelOuts[modelNames[model]]}_"
                                f"{property}_{fold}_confusion_matrix.png",
                                dpi=300,
                            )
                    if show:
                        plt.show()
                        plt.clf()
                        plt.close()
        return axes, conf_dict
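

# A minimal usage sketch for ``ConfusionMatrixPlot``; the helper name is
# hypothetical and ``models`` is assumed to be a list of fitted classification
# ``QSPRModel`` instances with saved assessment results. One heatmap is produced
# per model, property and fold (plus the independent test set).
def _example_confusion_matrices(models: list[QSPRModel]) -> dict:
    plot = ConfusionMatrixPlot(models)
    axes, conf_dict = plot.make(save=True, show=False, out_path="confusion.png")
    return conf_dict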