#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Timestamp: "2025-10-02 18:46:00 (ywatanabe)"
# File: /home/ywatanabe/proj/scitex_repo/src/scitex/ml/plt/plot_optuna_study.py
# ----------------------------------------
from __future__ import annotations
import os
import pathlib
import scitex_io
from scitex_dev import try_import_optional
# `figrecipe` is the underlying impl that the umbrella exposes as
# `scitex.plt` — `ax`, `configure_mpl`, `_subplots`, etc. live there.
# Use the peer standalone directly (PA304-clean); fall back to None
# when the optional plotting layer isn't installed.
_umbrella_plt = try_import_optional("figrecipe", extra="plt", pkg="scitex-ml")
__FILE__ = __file__
__DIR__ = os.path.dirname(__FILE__)
# ----------------------------------------
"""
Functionalities:
  - Loads an Optuna study and generates various visualizations
  - Creates optimization history, parameter importance, slice, contour,
    and parallel-coordinate plots
  - Saves study history and visualization results

Dependencies:
  - packages:
    - optuna
    - pandas
    - scitex

IO:
  - input-files:
    - Optuna study database (.db file)
  - output-files:
    - study_history.csv
    - optimization_history.png/html
    - param_importances.png/html
    - slice.png/html
    - contour.png/html
    - parallel_coordinate.png/html
"""
"""Imports"""
import argparse
import scitex_logging as logging
logger = logging.getLogger(__name__)
def plot_optuna_study(lpath, value_str, sort=False):
    """
    Load an Optuna study and save its history table and standard plots.

    The output directory is derived from ``lpath`` by replacing the
    ``sqlite:///`` prefix with ``./`` and the ``.db`` suffix with ``/``.
    Into that directory the function writes ``study_history.csv`` and a
    ``.png``/``.html`` pair for each of: optimization history, parameter
    importances, slice, contour, and parallel coordinate. When the trials
    carry an ``SDIR`` user attribute, a ``best_trial`` symlink pointing at
    the best trial's ``SDIR`` is (re)created on a best-effort basis.

    Parameters:
    - lpath (str): Storage URL of the Optuna study database
      (e.g. ``sqlite:///path/to/study.db``).
    - value_str (str): Name given to the objective-value column of the
      history table and used as the target name in every plot.
    - sort (bool): When True, sort the saved history by ``value_str``
      (ascending for a minimized study, descending otherwise).

    Returns:
    - None
    """
    import matplotlib

    matplotlib.use("Agg")
    import matplotlib.pyplot as plt
    import optuna
    import pandas as pd

    # The optional plotting layer is imported with a None fallback at module
    # level; only apply its matplotlib configuration when it is available.
    if _umbrella_plt is not None:
        plt, _ = _umbrella_plt.configure_mpl(plt, fig_scale=3)

    # NOTE(review): this rewrites every "./" occurrence (including the one
    # inside "../") — confirm that is the intended normalization.
    lpath = lpath.replace("./", "/")
    study = optuna.load_study(study_name=None, storage=lpath)
    sdir = lpath.replace("sqlite:///", "./").replace(".db", "/")

    # Report the best trial.
    best_trial = study.best_trial
    print(f"Best trial number: {best_trial.number}")
    print(f"Best trial value: {best_trial.value}")
    print(f"Best trial parameters: {best_trial.params}")
    print(f"Best trial user attributes: {best_trial.user_attrs}")

    # History table, with the objective column renamed to the caller's label.
    study_history = study.trials_dataframe().rename(columns={"value": value_str})
    if sort:
        # Best-first ordering: ascending when minimizing, descending otherwise.
        ascending = "MINIMIZE" in str(study.directions[0])
        study_history = study_history.sort_values([value_str], ascending=ascending)

    # One row of raw user attributes per trial, keyed by trial number.
    attrs_df = pd.DataFrame(
        [{"number": trial.number, **trial.user_attrs} for trial in study.trials]
    ).set_index("number")

    # Attach the user attributes and index the history by trial number.
    study_history = study_history.merge(
        attrs_df, left_index=True, right_index=True, how="left"
    ).set_index("number")

    # Best-effort bookkeeping: any failure here (e.g. no SDIR attribute,
    # filesystem errors) is reported and must not prevent the outputs below.
    try:
        # Move the SDIR column to position 1.
        if "SDIR" in study_history.columns:
            cols = list(study_history.columns)
            cols.remove("SDIR")
            cols.insert(1, "SDIR")
            study_history = study_history[cols]
        study_history["SDIR"] = study_history["SDIR"].apply(
            lambda x: str(x).replace("RUNNING", "FINISHED")
        )
        # Look the directory up by the best trial's number: the first row is
        # the best trial only when the table was sorted, so positional access
        # would link the wrong trial when sort=False.
        best_trial_dir = study_history.loc[best_trial.number, "SDIR"]

        # Replace any existing link/file, then create the symlink.
        target = pathlib.Path(sdir + "best_trial")
        if target.exists() or target.is_symlink():
            target.unlink()
        target.parent.mkdir(parents=True, exist_ok=True)
        target.symlink_to(best_trial_dir)
    except Exception as e:
        print(e)

    # scitex_io.save is called directly from this function (not via a helper)
    # because it is passed use_caller_path=True.
    scitex_io.save(study_history, sdir + "study_history.csv", use_caller_path=True)
    print(study_history)

    # Each visualization is saved as both a static image and an HTML page.
    plot_builders = (
        ("optimization_history", optuna.visualization.plot_optimization_history),
        ("param_importances", optuna.visualization.plot_param_importances),
        ("slice", optuna.visualization.plot_slice),
        ("contour", optuna.visualization.plot_contour),
        ("parallel_coordinate", optuna.visualization.plot_parallel_coordinate),
    )
    for stem, build_figure in plot_builders:
        fig = build_figure(study, target_name=value_str)
        for ext in (".png", ".html"):
            scitex_io.save(fig, sdir + stem + ext, use_caller_path=True)
        plt.close()
# Backward-compatibility alias: ``optuna_study`` is bound to the same
# function object as ``plot_optuna_study``.
optuna_study = plot_optuna_study
"""Functions & Classes"""
def main(args):
    """
    Log a short usage guide for the Optuna study visualization.

    No study is loaded here; the function only emits informational
    log messages showing how to call ``plot_optuna_study``.

    Parameters:
    - args: Parsed command-line arguments (not used by this function).

    Returns:
    - int: Always 0.
    """
    usage_lines = (
        "This script requires an existing Optuna study database.",
        "Usage example:",
        ' lpath = "sqlite:///path/to/optuna_study.db"',
        ' plot_optuna_study(lpath, "Validation bACC", sort=True)',
    )
    for line in usage_lines:
        logger.info(line)
    return 0
def parse_args() -> argparse.Namespace:
    """
    Build the command-line parser and parse the process arguments.

    Recognized options:
    - ``--lpath``: study database location (string, defaults to None).
    - ``--value_str``: target metric name (string, defaults to "value").
    - ``--sort``: flag; when given, the study history is sorted.

    Returns:
    - argparse.Namespace: The parsed options.
    """
    cli = argparse.ArgumentParser(description="Visualize Optuna study results")

    # (flag, add_argument keyword arguments) for every supported option.
    option_specs = (
        (
            "--lpath",
            {
                "type": str,
                "default": None,
                "help": "Path to Optuna study database (e.g., sqlite:///study.db)",
            },
        ),
        (
            "--value_str",
            {
                "type": str,
                "default": "value",
                "help": "Target metric name (default: %(default)s)",
            },
        ),
        (
            "--sort",
            {
                "action": "store_true",
                "default": False,
                "help": "Sort study history by target metric (default: %(default)s)",
            },
        ),
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)

    return cli.parse_args()
def run_main() -> None:
    """
    Run the script inside a scitex session.

    Parses the command line, starts a session via ``stx.session.start``
    (rebinding ``sys.stdout``/``sys.stderr`` and the module-level globals
    to the values it returns), then either plots the study given by
    ``--lpath`` or falls back to the usage demo in ``main``. The session
    is closed with the resulting exit status.
    """
    global CONFIG, CC, sys, plt, rng
    import sys

    import matplotlib.pyplot as plt
    import scitex as stx

    args = parse_args()

    session = stx.session.start(
        sys,
        plt,
        args=args,
        file=__FILE__,
        sdir_suffix=None,
        verbose=False,
        agg=True,
    )
    CONFIG, sys.stdout, sys.stderr, plt, CC, rng = session

    if not args.lpath:
        # No database given: only show the usage guidance.
        exit_status = main(args)
    else:
        plot_optuna_study(args.lpath, args.value_str, args.sort)
        exit_status = 0

    stx.session.close(
        CONFIG,
        verbose=False,
        notify=False,
        message="",
        exit_status=exit_status,
    )
# Run the session-wrapped entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    run_main()
# EOF