import math
from typing import Callable
from typing import cast
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union

from optuna._experimental import experimental
from optuna.logging import get_logger
from optuna.study import Study
from optuna.trial import FrozenTrial
from optuna.trial import TrialState
from optuna.visualization._utils import _check_plot_args
from optuna.visualization.matplotlib._matplotlib_imports import _imports
from optuna.visualization.matplotlib._utils import _is_log_scale
from optuna.visualization.matplotlib._utils import _is_numerical


if _imports.is_successful():
    from optuna.visualization.matplotlib._matplotlib_imports import Axes
    from optuna.visualization.matplotlib._matplotlib_imports import Colormap
    from optuna.visualization.matplotlib._matplotlib_imports import matplotlib
    from optuna.visualization.matplotlib._matplotlib_imports import PathCollection
    from optuna.visualization.matplotlib._matplotlib_imports import plt

_logger = get_logger(__name__)


@experimental("2.2.0")
def plot_slice(
    study: Study,
    params: Optional[List[str]] = None,
    *,
    target: Optional[Callable[[FrozenTrial], float]] = None,
    target_name: str = "Objective Value",
) -> "Axes":
    """Plot the parameter relationship as slice plot in a study with Matplotlib.

    .. seealso::
        Please refer to :func:`optuna.visualization.plot_slice` for an example.

    Example:

        The following code snippet shows how to plot the parameter relationship as slice plot.

        .. plot::

            import optuna


            def objective(trial):
                x = trial.suggest_float("x", -100, 100)
                y = trial.suggest_categorical("y", [-1, 0, 1])
                return x ** 2 + y


            sampler = optuna.samplers.TPESampler(seed=10)
            study = optuna.create_study(sampler=sampler)
            study.optimize(objective, n_trials=10)

            optuna.visualization.matplotlib.plot_slice(study, params=["x", "y"])

    Args:
        study:
            A :class:`~optuna.study.Study` object whose trials are plotted for their target values.
        params:
            Parameter list to visualize. The default is all parameters.
        target:
            A function to specify the value to display. If it is :obj:`None` and ``study`` is being
            used for single-objective optimization, the objective values are plotted.

            .. note::
                Specify this argument if ``study`` is being used for multi-objective optimization.
        target_name:
            Target's name to display on the axis label.

    Returns:
        A :class:`matplotlib.axes.Axes` object. When more than one parameter is plotted, a
        one-dimensional array of :class:`matplotlib.axes.Axes` objects (one per parameter, in
        sorted parameter-name order) is returned instead.

    Raises:
        :exc:`ValueError`:
            If ``target`` is :obj:`None` and ``study`` is being used for multi-objective
            optimization, or if ``params`` contains a name that does not appear in any
            completed trial.
    """

    _imports.check()
    _check_plot_args(study, target, target_name)
    return _get_slice_plot(study, params, target, target_name)


def _get_slice_plot(
    study: Study,
    params: Optional[List[str]] = None,
    target: Optional[Callable[[FrozenTrial], float]] = None,
    target_name: str = "Objective Value",
) -> "Axes":
    """Build the slice plot(s) for :func:`plot_slice`; see that function for the arguments.

    Returns an empty Axes (after logging a warning) when there are no completed trials or no
    parameters to plot.
    """

    # Only completed trials are plotted.
    trials = [trial for trial in study.trials if trial.state == TrialState.COMPLETE]

    if len(trials) == 0:
        _logger.warning("Your study does not have any completed trials.")
        _, ax = plt.subplots()
        return ax

    all_params = {p_name for t in trials for p_name in t.params.keys()}
    if params is None:
        sorted_params = sorted(all_params)
    else:
        for input_p_name in params:
            if input_p_name not in all_params:
                raise ValueError("Parameter {} does not exist in your study.".format(input_p_name))
        sorted_params = sorted(set(params))

    n_params = len(sorted_params)
    if n_params == 0:
        # Either ``params=[]`` was passed or the completed trials carry no parameters.
        # ``plt.subplots(1, 0)`` would raise, so fall back to an empty plot, mirroring the
        # no-completed-trials case above.
        _logger.warning("There are no parameters to plot.")
        _, ax = plt.subplots()
        return ax

    # Set up the graph style.
    cmap = plt.get_cmap("Blues")
    padding_ratio = 0.05
    plt.style.use("ggplot")  # Use ggplot style sheet for similar outputs to plotly.

    # Prepare data.
    if target is None:
        obj_values = [cast(float, t.value) for t in trials]
    else:
        obj_values = [target(t) for t in trials]

    # Use one colour range for every subplot so that a trial number maps to the same colour in
    # all of them and the single colorbar below is valid for each subplot. Otherwise each scatter
    # autoscales to only the trials that contain its parameter (which differ under conditional
    # search spaces) and the colorbar describes only the last subplot.
    all_trial_numbers = [t.number for t in trials]
    vmin = min(all_trial_numbers)
    vmax = max(all_trial_numbers)

    if n_params == 1:
        fig, axs = plt.subplots()
        axs.set_title("Slice Plot")

        sc = _generate_slice_subplot(
            trials,
            sorted_params[0],
            axs,
            cmap,
            padding_ratio,
            obj_values,
            target_name,
            vmin=vmin,
            vmax=vmax,
        )
    else:
        # Ensure that each subplot has a minimum width without relying on auto-sizing.
        min_figwidth = matplotlib.rcParams["figure.figsize"][0] / 2
        fig_height = matplotlib.rcParams["figure.figsize"][1]
        fig, axs = plt.subplots(
            1, n_params, sharey=True, figsize=(min_figwidth * n_params, fig_height)
        )
        fig.suptitle("Slice Plot")

        # One scatter per parameter, left to right in sorted order.
        for ax, param in zip(axs, sorted_params):
            sc = _generate_slice_subplot(
                trials,
                param,
                ax,
                cmap,
                padding_ratio,
                obj_values,
                target_name,
                vmin=vmin,
                vmax=vmax,
            )

    axcb = fig.colorbar(sc, ax=axs)
    axcb.set_label("#Trials")

    return axs


def _generate_slice_subplot(
    trials: List[FrozenTrial],
    param: str,
    ax: "Axes",
    cmap: "Colormap",
    padding_ratio: float,
    obj_values: List[Union[int, float]],
    target_name: str,
    vmin: Optional[float] = None,
    vmax: Optional[float] = None,
) -> "PathCollection":
    """Draw a scatter of ``param`` against the target values on ``ax``.

    Only trials that contain ``param`` are plotted. Points are coloured by trial number using
    ``cmap``. ``vmin``/``vmax`` fix the colour range; ``None`` lets Matplotlib autoscale from
    the plotted trial numbers.

    Returns:
        The scatter artist, usable as the mappable of a colorbar.
    """
    x_values = []
    y_values = []
    trial_numbers = []
    for t, obj_v in zip(trials, obj_values):
        if param in t.params:
            x_values.append(t.params[param])
            y_values.append(obj_v)
            trial_numbers.append(t.number)

    ax.set(xlabel=param, ylabel=target_name)
    scale = None
    if _is_log_scale(trials, param):
        ax.set_xscale("log")
        scale = "log"
    elif not _is_numerical(trials, param):
        # String data makes Matplotlib treat the x values as categories.
        x_values = [str(x) for x in x_values]
        scale = "categorical"
    xlim = _calc_lim_with_padding(x_values, padding_ratio, scale)
    ax.set_xlim(xlim[0], xlim[1])
    sc = ax.scatter(
        x_values,
        y_values,
        c=trial_numbers,
        cmap=cmap,
        vmin=vmin,
        vmax=vmax,
        edgecolors="grey",
    )
    # Keep tick labels only on the outer edges of the subplot grid.
    ax.label_outer()

    return sc


def _calc_lim_with_padding(
    values: List[Union[int, float]], padding_ratio: float, scale: Optional[str] = None
) -> Tuple[Union[int, float], Union[int, float]]:
    """Return ``(low, high)`` axis limits enclosing ``values`` plus relative padding.

    Args:
        values:
            The plotted x values (strings when ``scale`` is ``"categorical"``). Must be non-empty.
        padding_ratio:
            Fraction of the value range added on each side.
        scale:
            ``"log"`` pads in log10 space (values must be positive); ``"categorical"`` pads
            around the index range ``0 .. k - 1`` of the ``k`` distinct categories; any other
            value (default ``None``) pads the linear range.
    """
    value_max = max(values)
    value_min = min(values)
    if scale == "log":
        padding = (math.log10(value_max) - math.log10(value_min)) * padding_ratio
        return (
            math.pow(10, math.log10(value_min) - padding),
            math.pow(10, math.log10(value_max) + padding),
        )
    elif scale == "categorical":
        width = len(set(values)) - 1
        padding = width * padding_ratio
        return -padding, width + padding
    else:
        padding = (value_max - value_min) * padding_ratio
        return value_min - padding, value_max + padding