1import math
2from typing import Callable
3from typing import cast
4from typing import List
5from typing import Optional
6from typing import Tuple
7from typing import Union
8
9from optuna._experimental import experimental
10from optuna.logging import get_logger
11from optuna.study import Study
12from optuna.trial import FrozenTrial
13from optuna.trial import TrialState
14from optuna.visualization._utils import _check_plot_args
15from optuna.visualization.matplotlib._matplotlib_imports import _imports
16from optuna.visualization.matplotlib._utils import _is_log_scale
17from optuna.visualization.matplotlib._utils import _is_numerical
18
19
# The matplotlib-backed names below are only imported when the guarded import in
# `_matplotlib_imports` succeeded; `plot_slice` calls `_imports.check()` before any
# of them is used. Annotations referring to them ("Axes", "Colormap",
# "PathCollection") are quoted so this module still imports when they are absent.
if _imports.is_successful():
    from optuna.visualization.matplotlib._matplotlib_imports import Axes
    from optuna.visualization.matplotlib._matplotlib_imports import Colormap
    from optuna.visualization.matplotlib._matplotlib_imports import matplotlib
    from optuna.visualization.matplotlib._matplotlib_imports import PathCollection
    from optuna.visualization.matplotlib._matplotlib_imports import plt

# Module-level logger (used for the "no completed trials" warning in `_get_slice_plot`).
_logger = get_logger(__name__)
28
29
@experimental("2.2.0")
def plot_slice(
    study: Study,
    params: Optional[List[str]] = None,
    *,
    target: Optional[Callable[[FrozenTrial], float]] = None,
    target_name: str = "Objective Value",
) -> "Axes":
    """Draw a slice plot of the parameter relationships in a study using Matplotlib.

    .. seealso::
        See :func:`optuna.visualization.plot_slice` for an example.

    Example:

        The snippet below draws a slice plot of the parameter relationships.

        .. plot::

            import optuna


            def objective(trial):
                x = trial.suggest_float("x", -100, 100)
                y = trial.suggest_categorical("y", [-1, 0, 1])
                return x ** 2 + y


            sampler = optuna.samplers.TPESampler(seed=10)
            study = optuna.create_study(sampler=sampler)
            study.optimize(objective, n_trials=10)

            optuna.visualization.matplotlib.plot_slice(study, params=["x", "y"])

    Args:
        study:
            A :class:`~optuna.study.Study` object whose trials are plotted for their target values.
        params:
            Names of the parameters to visualize. Defaults to every parameter that appears in
            a completed trial.
        target:
            A function computing the value to display from a trial. If it is :obj:`None` and
            ``study`` is being used for single-objective optimization, the objective values are
            plotted.

            .. note::
                Specify this argument if ``study`` is being used for multi-objective optimization.
        target_name:
            Target's name to display on the axis label.

    Returns:
        A :class:`matplotlib.axes.Axes` object. When more than one parameter is plotted, the
        array of Axes created by ``plt.subplots`` for the row of subplots is returned.

    Raises:
        :exc:`ValueError`:
            If ``target`` is :obj:`None` and ``study`` is being used for multi-objective
            optimization, or if a name in ``params`` is not a parameter of any completed trial.
    """
    # Fail early with a helpful message when matplotlib could not be imported.
    _imports.check()
    _check_plot_args(study, target, target_name)

    axes = _get_slice_plot(study, params, target=target, target_name=target_name)
    return axes
91
92
def _get_slice_plot(
    study: Study,
    params: Optional[List[str]] = None,
    target: Optional[Callable[[FrozenTrial], float]] = None,
    target_name: str = "Objective Value",
) -> "Axes":
    """Build the slice-plot figure for ``study``.

    One scatter subplot is drawn per selected parameter (x: parameter value, y: target
    value, colour: trial number), with a shared colorbar labelled ``#Trials``.

    Args:
        study:
            The study whose completed trials are plotted.
        params:
            Parameter names to plot. :obj:`None` selects every parameter that appears in a
            completed trial. Duplicates are ignored and the subplots are ordered by name.
        target:
            Function computing the plotted value of a trial. :obj:`None` plots ``trial.value``.
        target_name:
            Label of the y axis.

    Returns:
        A single Axes when there is nothing to plot or only one parameter is plotted;
        otherwise the array of Axes returned by ``plt.subplots`` for the row of subplots.

    Raises:
        ValueError: If a name in ``params`` is not a parameter of any completed trial.
    """
    # Calculate basic numbers for plotting; only completed trials are considered.
    trials = [trial for trial in study.trials if trial.state == TrialState.COMPLETE]

    if len(trials) == 0:
        _logger.warning("Your study does not have any completed trials.")
        _, ax = plt.subplots()
        return ax

    all_params = {p_name for t in trials for p_name in t.params.keys()}
    if params is None:
        sorted_params = sorted(all_params)
    else:
        for input_p_name in params:
            if input_p_name not in all_params:
                raise ValueError("Parameter {} does not exist in your study.".format(input_p_name))
        sorted_params = sorted(set(params))

    n_params = len(sorted_params)

    # An explicitly empty ``params`` list leaves nothing to draw. Without this guard the
    # code below would call ``plt.subplots(1, 0, ...)``, which matplotlib rejects.
    if n_params == 0:
        _, ax = plt.subplots()
        return ax

    # Set up the graph style.
    cmap = plt.get_cmap("Blues")
    padding_ratio = 0.05
    plt.style.use("ggplot")  # Use ggplot style sheet for similar outputs to plotly.

    # Prepare data.
    if target is None:
        obj_values = [cast(float, t.value) for t in trials]
    else:
        obj_values = [target(t) for t in trials]

    if n_params == 1:
        fig, axs = plt.subplots()
        axs.set_title("Slice Plot")
        subplot_axes = [axs]
    else:
        min_figwidth = matplotlib.rcParams["figure.figsize"][0] / 2
        figheight = matplotlib.rcParams["figure.figsize"][1]
        # Ensure that each subplot has a minimum width without relying on auto-sizing.
        fig, axs = plt.subplots(
            1, n_params, sharey=True, figsize=(min_figwidth * n_params, figheight)
        )
        fig.suptitle("Slice Plot")
        subplot_axes = list(axs)

    # Draw one scatter plot per parameter.
    scatters = [
        _generate_slice_subplot(trials, param, ax, cmap, padding_ratio, obj_values, target_name)
        for param, ax in zip(sorted_params, subplot_axes)
    ]

    # Each scatter autoscales its colour norm to the trial numbers it contains. When a
    # parameter is missing from some trials (conditional search spaces), subplots would
    # map the same colour to different trial numbers while the colorbar reflects only the
    # last one. Use one range covering every plotted trial instead. The list is non-empty
    # because every selected parameter occurs in at least one trial.
    plotted_numbers = [
        t.number for t in trials if any(p_name in t.params for p_name in sorted_params)
    ]
    vmin, vmax = min(plotted_numbers), max(plotted_numbers)
    for sc in scatters:
        sc.set_clim(vmin, vmax)

    axcb = fig.colorbar(scatters[-1], ax=axs)
    axcb.set_label("#Trials")

    return axs
160
161
def _generate_slice_subplot(
    trials: List[FrozenTrial],
    param: str,
    ax: "Axes",
    cmap: "Colormap",
    padding_ratio: float,
    obj_values: List[Union[int, float]],
    target_name: str,
) -> "PathCollection":
    """Scatter one parameter's values against the target values on ``ax``.

    Args:
        trials:
            Trials to plot; zipped element-wise with ``obj_values``.
        param:
            Name of the parameter shown on the x axis. Trials without it are skipped.
        ax:
            Axes to draw on.
        cmap:
            Colormap used to colour points by trial number.
        padding_ratio:
            Relative margin added around the x values when setting the x limits.
        obj_values:
            Target value of each trial, in the same order as ``trials``.
        target_name:
            Label of the y axis.

    Returns:
        The collection created by ``ax.scatter``.
    """
    xs = []
    ys = []
    numbers = []
    for trial, value in zip(trials, obj_values):
        if param not in trial.params:
            continue  # this trial has no value for the parameter
        xs.append(trial.params[param])
        ys.append(value)
        numbers.append(trial.number)

    ax.set(xlabel=param, ylabel=target_name)

    scale: Optional[str]
    if _is_log_scale(trials, param):
        scale = "log"
        ax.set_xscale(scale)
    elif _is_numerical(trials, param):
        scale = None
    else:
        # Non-numerical choices are drawn as string categories.
        scale = "categorical"
        xs = [str(x) for x in xs]

    lower, upper = _calc_lim_with_padding(xs, padding_ratio, scale)
    ax.set_xlim(lower, upper)
    collection = ax.scatter(xs, ys, c=numbers, cmap=cmap, edgecolors="grey")
    # Keep only the outer axis labels / tick labels when part of a subplot grid.
    ax.label_outer()

    return collection
193
194
def _calc_lim_with_padding(
    values: List[Union[int, float]], padding_ratio: float, scale: Optional[str] = None
) -> Tuple[Union[int, float], Union[int, float]]:
    """Return axis limits enclosing ``values`` with a relative margin on each side.

    Args:
        values:
            The plotted coordinates. For ``scale="categorical"`` these are category labels,
            and only the number of distinct labels is used.
        padding_ratio:
            Fraction of the value span added as margin below and above.
        scale:
            ``"log"`` pads in log10 space (values must be positive); ``"categorical"`` pads
            around the positions ``0 .. k-1`` of ``k`` distinct labels; anything else pads
            linearly.

    Returns:
        A ``(lower, upper)`` tuple.

    Raises:
        ValueError: If ``values`` is empty.
    """
    if not values:
        raise ValueError("Cannot compute axis limits for an empty list of values.")

    if scale == "log":
        lower, upper = math.log10(min(values)), math.log10(max(values))
    elif scale == "categorical":
        # Categories occupy the integer positions 0 .. (number of distinct labels - 1).
        lower, upper = 0, len(set(values)) - 1
    else:
        lower, upper = min(values), max(values)

    span = upper - lower
    if span == 0:
        # All values coincide (e.g. a single trial or a single sampled category). A zero
        # padding would yield identical limits, which matplotlib warns about; pad relative
        # to the value's magnitude instead, or by one unit when the value is zero.
        span = abs(upper) if upper != 0 else 1
    padding = span * padding_ratio

    if scale == "log":
        return math.pow(10, lower - padding), math.pow(10, upper + padding)
    return lower - padding, upper + padding
213