Source code for optimization.optuna_base_optimizer

from . import OptunaOptimizer

from typing import Callable, Dict
import json

import optuna
from optuna.trial import TrialState


class OptunaBaseOptimizer(OptunaOptimizer):
    """
    Run an Optuna study for a user-supplied objective and report the result.

    Args:
        optimization_params (Dict): A dictionary containing the parameters to optimize.
        n_trials (int): The number of trials to run. Default is `100`.
        direction (str): The direction to optimize. Can be 'minimize' or 'maximize'.
            Default is `'minimize'`.
        pruner (optuna.pruners.BasePruner): The pruner to use. Default is `None`.
        save_dir (str): The directory to save the best parameters. Default is `None`,
            in which case nothing is written to disk.
    """

    def __init__(
        self,
        optimization_params: Dict,
        n_trials: int = 100,
        direction: str = 'minimize',
        pruner: optuna.pruners.BasePruner = None,
        save_dir: str = None
    ):
        self.num_trials = n_trials
        self.direction = direction
        self.pruner = pruner
        self._optimization_params = optimization_params
        self.save_dir = save_dir

    @property
    def optimization_params(self) -> Dict:
        """
        Get the optimization parameters.
        """
        return self._optimization_params

    def optimize(
        self,
        objective_function: Callable[[optuna.Trial], float],
    ) -> Dict:
        """
        Create a study, run the configured number of trials and return the best parameters.

        When `save_dir` is set, the best parameters are also written as JSON to
        `<save_dir>/best_params.json`. The directory is not created here, so it
        must already exist.

        Args:
            objective_function (Callable[[optuna.Trial], float]): The objective passed
                to `study.optimize`.

        Returns:
            Dict: The `best_params` of the finished study.
        """
        study = optuna.create_study(direction=self.direction, pruner=self.pruner)
        study.optimize(objective_function, n_trials=self.num_trials)

        best_params = study.best_params

        if self.save_dir is not None:
            file_path = str(self.save_dir) + '/best_params.json'
            # Use a context manager so the handle is flushed and closed
            # deterministically instead of relying on garbage collection.
            with open(file_path, 'w') as params_file:
                json.dump(best_params, params_file)

        print("Best parameters: ", best_params)
        self._print_optimization_report(study)

        return best_params

    def _print_optimization_report(self, study):
        """
        Print trial counts and the best trial (value and parameters) of `study`.

        Args:
            study: The study to summarize; it is queried through `get_trials`,
                `trials` and `best_trial`.
        """
        pruned_trials = study.get_trials(deepcopy=False, states=[TrialState.PRUNED])
        completed_trials = study.get_trials(deepcopy=False, states=[TrialState.COMPLETE])

        # Optimization report
        print("\nStudy statistics: ")
        print(" Number of finished trials: ", len(study.trials))
        print(" Number of pruned trials: ", len(pruned_trials))
        print(" Number of completed trials: ", len(completed_trials))

        trial = study.best_trial
        print("Best trial:")
        print(" Value: ", trial.value)
        print(" Params: ")
        for key, value in trial.params.items():
            print(" {}: {}".format(key, value))
        print("\n")