# Source code for drugex.utils.pareto

import numpy as np
import torch

def _non_dominated_mask(scores):
    """Return a boolean mask of the rows of ``scores`` not dominated by any other row.

    Row ``j`` dominates row ``i`` when it is at least as large as row ``i`` in
    every score and strictly larger in at least one score (larger is better).

    Parameters
    ----------
    scores : numpy.ndarray
        An (n_points, n_scores) array of scores.

    Returns
    -------
    numpy.ndarray
        Boolean array of length n_points; ``True`` marks non-dominated points.
    """
    dominated = np.zeros(scores.shape[0], dtype=bool)
    for i, point in enumerate(scores):
        # Test every candidate j against point i in one vectorized step.
        # A row never dominates itself (it is not strictly larger anywhere).
        at_least_as_good = np.all(scores >= point, axis=1)
        strictly_better = np.any(scores > point, axis=1)
        dominated[i] = np.any(at_least_as_good & strictly_better)
    return ~dominated


def get_Pareto_fronts(scores):
    """Identify the Pareto fronts from a given set of scores.

    Points are peeled off in successive non-dominated layers: the first front
    holds every point that no other point dominates, the second front holds the
    points that become non-dominated once the first front is removed, and so on.
    Higher scores are considered better. Point ``j`` dominates point ``i`` when
    ``j`` is at least as good in every score and strictly better in at least one.

    Parameters
    ----------
    scores : numpy.ndarray, torch.Tensor or array-like
        An (n_points, n_scores) array of scores. A 1-D input is treated as a
        single score per point.

    Returns
    -------
    list of numpy.ndarray
        A list containing the indices (into the original ``scores``) of points
        belonging to each Pareto front, best front first. Indices within a
        front are in ascending order. An empty input yields an empty list.

    Raises
    ------
    ValueError
        If ``scores`` has more than two dimensions.
    """
    if isinstance(scores, torch.Tensor):
        # Drop autograd history and move to host memory so NumPy can read it.
        scores = scores.detach().cpu().numpy()
    scores = np.asarray(scores)
    if scores.ndim == 1:
        # Single objective: one score column per point.
        scores = scores[:, np.newaxis]
    if scores.ndim != 2:
        raise ValueError(
            f"scores must be a 1-D or 2-D array, got shape {scores.shape}"
        )

    population_ids = np.arange(scores.shape[0])
    all_fronts = []
    # Dominance is a strict partial order, so every non-empty remaining set has
    # at least one non-dominated point and the loop always makes progress.
    while scores.shape[0] > 0:
        front = _non_dominated_mask(scores)
        all_fronts.append(population_ids[front])
        # Remove the current front from consideration in later iterations.
        scores = scores[~front]
        population_ids = population_ids[~front]
    return all_fronts