dashml

DashML — ML lifecycle management for Databricks: preprocessing, evaluation, drift monitoring, governance, and champion/challenger promotion.

 1"""DashML — ML lifecycle management for Databricks: preprocessing, evaluation,
 2drift monitoring, governance, and champion/challenger promotion."""
 3from dashml.config import CleaningSpec, DerivedColumn, EvaluationSpec, ModelSpec, MonitoringSpec, RunConfig
 4from dashml.evaluation import build_model_card, check_thresholds, explain_features, plot_feature_importance
 5from dashml.governance import build_governance_artifacts
 6from dashml.monitor import ModelMonitor
 7from dashml.preprocessing import clean_dataframe
 8from dashml.registry import RunTracker, promote_challenger, register_model
 9from dashml.ui import launch
10
11__version__ = "0.1.4"
12__all__ = [
13    "CleaningSpec",
14    "DerivedColumn",
15    "EvaluationSpec",
16    "ModelMonitor",
17    "ModelSpec",
18    "MonitoringSpec",
19    "RunConfig",
20    "RunTracker",
21    "build_governance_artifacts",
22    "build_model_card",
23    "check_thresholds",
24    "clean_dataframe",
25    "explain_features",
26    "launch",
27    "plot_feature_importance",
28    "promote_challenger",
29    "register_model",
30]
@dataclass
class CleaningSpec:
    """Declarative description of the preprocessing steps to apply to a DataFrame.

    Every field has a default, so ``CleaningSpec()`` on its own is valid and
    only standardizes column names.
    """

    # Normalize column names (on by default).
    standardize_column_names: bool = True
    # Column whose duplicate values are dropped; None disables deduplication.
    dedup_by: str | None = None
    # Rows with a null in any of these columns are dropped.
    drop_nulls_in: list[str] = field(default_factory=list)
    # Columns to parse as datetimes.
    date_columns: list[str] = field(default_factory=list)
    # {column: {old_value: new_value}} remappings.
    categorical_mappings: dict[str, dict] = field(default_factory=dict)
    # Computed columns to add (see DerivedColumn).
    derived_columns: list[DerivedColumn] = field(default_factory=list)
    # Row filters; the dict format is defined by the filter helper (not shown here).
    filters: list[dict] = field(default_factory=list)
@dataclass
class DerivedColumn:
    """A computed column added during cleaning."""

    # Name of the column to create.
    name: str
    # Evaluated with DataFrame.eval().
    expression: str
    # Presumably converts the result to a 0/1 flag — TODO confirm against the consuming helper.
    as_binary: bool = False
@dataclass
class EvaluationSpec:
    """Evaluation settings for one lifecycle run."""

    # {metric name: minimum acceptable value} — the shape check_thresholds() takes.
    min_thresholds: dict[str, float] = field(default_factory=dict)
    # On/off switches for optional evaluation steps, all enabled by default;
    # presumably read by the pipeline driver (not shown here) — TODO confirm.
    compare_to_champion: bool = True
    run_shap: bool = True
    build_model_card: bool = True
class ModelMonitor:
    """
    Monitor ML model performance and data drift on Databricks.

    Usage::
        monitor = ModelMonitor(model_name="customer_churn", model_version=3)
        monitor.set_baseline(table="catalog.schema.training_data")
        monitor.set_production(table="catalog.schema.inference_logs")
        report = monitor.run()

    Every ``set_*`` method returns ``self``, so configuration calls can be chained.
    """

    def __init__(self, model_name: str, model_version: int | None = None, drift_threshold: float = 0.15):
        self.model_name = model_name
        self.model_version = model_version
        # Forwarded to compute_drift() as its ``threshold`` argument.
        self.drift_threshold = drift_threshold
        self._baseline_df = None
        self._production_df = None
        self._target_col: str | None = None
        self._prediction_col: str | None = None
        self._feature_cols: list[str] = []
        self._retrain_job_name: str | None = None

    def set_baseline(self, df=None, table: str | None = None):
        """Set the reference data: a DataFrame via ``df`` or a table name via ``table``."""
        self._baseline_df = self._load(df, table)
        return self

    def set_production(self, df=None, table: str | None = None):
        """Set the production data compared against the baseline (``df`` or ``table``)."""
        self._production_df = self._load(df, table)
        return self

    def set_target(self, column: str):
        """Name the target (label) column in the production data; needed for performance metrics."""
        self._target_col = column
        return self

    def set_prediction(self, column: str):
        """Name the model-output column; used for prediction drift and performance metrics."""
        self._prediction_col = column
        return self

    def set_features(self, columns: list[str]):
        """Choose the feature columns checked for drift. The list is copied."""
        # Copy so later mutation of the caller's list can't silently change what run() checks.
        self._feature_cols = list(columns)
        return self

    def set_auto_retrain(self, job_name: str):
        """Trigger this Databricks job by name if drift exceeds the threshold on run()."""
        self._retrain_job_name = job_name
        return self

    def _load(self, df, table):
        """Return ``df`` when given; otherwise read ``table`` from the active Spark session.

        Raises:
            ValueError: neither ``df`` nor ``table`` was provided.
            RuntimeError: a table was requested but no Spark session is active.
        """
        if df is not None:
            return df
        if table is None:
            raise ValueError("Provide either a DataFrame (df=...) or a table name (table=...).")
        from pyspark.sql import SparkSession

        # getActiveSession() returns None outside a Spark context instead of raising.
        spark = SparkSession.getActiveSession()
        if spark is None:
            raise RuntimeError(f"No active Spark session; cannot load table {table!r}. Pass df=... instead.")
        return spark.table(table)

    def run(self, save_to: str | None = None) -> MonitorReport:
        """Run every check the current configuration allows and return a MonitorReport.

        Result keys, each present only when its inputs are configured:

        - ``drift``: per-feature drift (needs baseline, production and feature columns).
        - ``prediction_drift``: drift of the prediction column (needs baseline,
          production and the prediction column).
        - ``performance``: metrics on production data (needs target and prediction
          columns and production data).
        - ``retraining``: job name and trigger status. Present when auto-retrain is
          configured and any feature's drift entry is flagged ``drifted``.

        If ``save_to`` is given, the report is also saved there with ``report.save``.
        """
        from dashml.metrics import compute_drift, compute_performance, compute_prediction_drift

        results = {}
        have_both_frames = self._baseline_df is not None and self._production_df is not None
        if have_both_frames:
            if self._feature_cols:
                results["drift"] = compute_drift(
                    self._baseline_df, self._production_df, self._feature_cols, threshold=self.drift_threshold
                )
            if self._prediction_col:
                results["prediction_drift"] = compute_prediction_drift(
                    self._baseline_df, self._production_df, self._prediction_col
                )
        if self._target_col and self._prediction_col and self._production_df is not None:
            results["performance"] = compute_performance(
                self._production_df, self._target_col, self._prediction_col
            )

        # NOTE(review): assumes compute_drift returns {feature: {"drifted": bool, ...}}.
        drift = results.get("drift", {})
        if self._retrain_job_name and any(entry.get("drifted") for entry in drift.values()):
            from dashml.retraining import trigger_job

            results["retraining"] = {"job": self._retrain_job_name, "status": trigger_job(self._retrain_job_name)}

        report = MonitorReport(self.model_name, self.model_version, results)
        if save_to:
            report.save(save_to)
        return report
@dataclass
class ModelSpec:
    """What kind of model is trained, and on which columns."""

    type: str = "classifier"  # classifier | regressor | clustering | knn | anomaly
    target_column: str | None = None
    feature_columns: list[str] = field(default_factory=list)

    @property
    def is_supervised(self) -> bool:
        """True when ``type`` is "classifier" or "regressor"."""
        supervised_kinds = ("classifier", "regressor")
        return self.type in supervised_kinds
@dataclass
class MonitoringSpec:
    """Drift-monitoring and auto-retraining settings."""

    # PSI: >0.25 significant, 0.10-0.25 moderate
    drift_threshold: float = 0.15
    # Presumably the tolerated relative change in row count between runs — TODO confirm.
    row_count_delta_threshold: float = 0.05
    # Whether to trigger a retraining job on drift, and which Databricks job to run.
    auto_retrain: bool = False
    retrain_job_name: str | None = None
@dataclass
class RunConfig:
    """Top-level config for one model's lifecycle run.

    ``name``/``catalog``/``schema_name`` locate the model in Unity Catalog. The
    nested specs default to their own defaults, built fresh per instance.
    """

    name: str
    catalog: str
    schema_name: str
    model: ModelSpec = field(default_factory=ModelSpec)
    cleaning: CleaningSpec = field(default_factory=CleaningSpec)
    evaluation: EvaluationSpec = field(default_factory=EvaluationSpec)
    monitoring: MonitoringSpec = field(default_factory=MonitoringSpec)

    @property
    def registered_name(self) -> str:
        """Three-part name ``<catalog>.<schema_name>.<name>``."""
        return "{}.{}.{}".format(self.catalog, self.schema_name, self.name)
class RunTracker:
    """Thin wrapper around one MLflow run's lifecycle.

    Use as a context manager.

    On entry it sets the MLflow registry URI to "databricks-uc", selects the
    experiment, and starts a run only if none is already active.

    On exit it ends the run only if this tracker started it. The run is marked
    FAILED when the ``with`` block raised and FINISHED otherwise. Exceptions are
    never suppressed.
    """

    def __init__(self, experiment_path: str):
        self.experiment_path = experiment_path
        # True only while this tracker owns (started) the active run.
        self._own_run = False
        # ID of the active run; populated by __enter__.
        self.run_id: str | None = None

    def __enter__(self) -> RunTracker:
        import mlflow

        mlflow.set_registry_uri("databricks-uc")
        mlflow.set_experiment(self.experiment_path)
        # Re-evaluate ownership on every entry so a reused tracker never ends a run
        # that outer code started.
        self._own_run = mlflow.active_run() is None
        if self._own_run:
            mlflow.start_run()
        self.run_id = mlflow.active_run().info.run_id
        return self

    def __exit__(self, *exc):
        if not self._own_run:
            return
        import mlflow

        # exc is (exc_type, exc_value, traceback); exc_type is None on a clean exit.
        failed = bool(exc) and exc[0] is not None
        mlflow.end_run(status="FAILED" if failed else "FINISHED")
        self._own_run = False

    def log_params(self, params: dict) -> None:
        """Log params to the active run, stringifying values and truncating them to 250 characters."""
        import mlflow

        mlflow.log_params({k: str(v)[:250] for k, v in params.items()})

    def log_metrics(self, metrics: dict) -> None:
        """Log numeric metrics; values that can't be converted to float are logged as param ``metric_<key>``."""
        import mlflow

        for key, value in metrics.items():
            try:
                mlflow.log_metric(key, float(value))
            except (TypeError, ValueError):  # noqa: PERF203 — per-item fallback, not a hot loop (a handful of metrics per run)
                mlflow.log_param(f"metric_{key}", str(value))

    def log_text(self, text: str, artifact_path: str) -> None:
        """Log ``text`` as an artifact file at ``artifact_path`` in the active run."""
        import mlflow

        mlflow.log_text(text, artifact_path)

    def log_artifact_file(self, path: str, artifact_path: str | None = None) -> None:
        """Upload the local file ``path`` as an artifact, optionally under ``artifact_path``."""
        import mlflow

        mlflow.log_artifact(path, artifact_path)
def build_governance_artifacts(
    name: str,
    version: str,
    catalog: str,
    schema_name: str,
    tables: list[str],
    feature_columns: list[str],
    metrics: dict,
    signature: dict | None = None,
    sensitive_columns: list[str] | None = None,
    pii_columns: list[str] | None = None,
    fairness_metrics: dict | None = None,
    baseline_metrics: dict | None = None,
    per_segment_metrics: dict | None = None,
    shap_by_group: dict[str, dict[str, float]] | None = None,
    required_approvers: list[str] | None = None,
    checklist: list[str] | None = None,
) -> dict[str, str]:
    """Returns {filename: content} for the standard governance file set.

    Files, in order:

    - signature.json
    - features.json (feature, sensitive and PII columns)
    - metrics.json (with an optional baseline comparison and top SHAP features)
    - data_sources.yaml (tables qualified as ``catalog.schema.table``)
    - fairness_report.md
    - approval_record.json

    The approval record starts in state "pending" with every checklist item
    False. Nothing is written to disk.
    """
    feature_manifest = {
        "feature_columns": feature_columns,
        "sensitive_columns": sensitive_columns or [],
        "pii_columns": pii_columns or [],
    }

    # An absent or empty baseline means there is nothing to compare against.
    comparison = _delta(metrics, baseline_metrics) if baseline_metrics else None
    metrics_payload = {
        "metrics": metrics,
        "per_segment_metrics": per_segment_metrics or {},
        "baseline_comparison": comparison,
        "top_features": _top_shap(shap_by_group),
    }

    qualified_tables = [f"{catalog}.{schema_name}.{table}" for table in tables]
    sources = {"catalog": catalog, "schema": schema_name, "tables": qualified_tables}

    approval = {
        "model": name,
        "version": version,
        "state": "pending",
        "required_approvers": required_approvers or [],
        "checklist": {item: False for item in (checklist or [])},
    }

    return {
        "signature.json": json.dumps(signature or {}, indent=2),
        "features.json": json.dumps(feature_manifest, indent=2),
        "metrics.json": json.dumps(metrics_payload, indent=2),
        "data_sources.yaml": yaml.dump(sources, sort_keys=False),
        "fairness_report.md": _fairness_report(sensitive_columns, fairness_metrics),
        "approval_record.json": json.dumps(approval, indent=2),
    }

def build_model_card(
    name: str,
    version: str,
    model_type: str,
    catalog: str,
    schema_name: str,
    tables: list[str],
    target_column: str | None,
    feature_columns: list[str],
    params: dict,
    metrics: dict,
    run_id: str | None = None,
    shap_by_group: dict[str, dict[str, float]] | None = None,
) -> str:
    """Render a Markdown model card. Doesn't write or log it — caller decides where it goes.

    Sections, in order:

    - Details
    - Data
    - Parameters (sorted by key)
    - Metrics (sorted by key)
    - Feature importance (SHAP), only when ``shap_by_group`` is non-empty: the top
      10 features per group, highest first
    - Limitations

    The "Created" line is the current UTC time, so output differs between calls.
    """

    def bullets(mapping: dict) -> list[str]:
        # One "- `key`: value" line per entry, or a placeholder when empty.
        rendered = [f"- `{key}`: {value}" for key, value in sorted(mapping.items())]
        return rendered if rendered else ["_none logged_"]

    created = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M UTC")
    out: list[str] = [
        f"# Model Card — {name}",
        "",
        "## Details",
        f"- **Version:** {version}",
        f"- **Type:** {model_type}",
        f"- **Created:** {created}",
    ]
    if run_id:
        out.append(f"- **MLflow run:** `{run_id}`")

    tables_text = ", ".join(tables) if tables else "—"
    out.extend(["", "## Data", f"- **Source:** Unity Catalog `{catalog}.{schema_name}`", f"- **Tables:** {tables_text}"])
    if target_column:
        out.append(f"- **Target column:** `{target_column}`")
    if feature_columns:
        out.append(f"- **Feature columns:** {', '.join(feature_columns)}")

    out.extend(["", "## Parameters", *bullets(params)])
    out.extend(["", "## Metrics", *bullets(metrics)])

    if shap_by_group:
        out.extend(["", "## Feature importance (SHAP)"])
        for group, importances in shap_by_group.items():
            ranked = sorted(importances.items(), key=lambda item: item[1], reverse=True)
            out.append(f"**{group}**")
            out.extend(f"  - `{feat}`: {score:.4f}" for feat, score in ranked[:10])

    out.extend(
        [
            "",
            "## Limitations",
            "Evaluated on the configured evaluation split only — verify performance holds",
            "on your actual production population before relying on this model.",
        ]
    )
    return "\n".join(out)

def check_thresholds(metrics: dict[str, float], min_thresholds: dict[str, float]) -> dict[str, bool]:
    """Per-metric pass/fail against configured minimums.

    A metric passes when its value is >= its configured minimum.

    A metric listed in ``min_thresholds`` but absent from ``metrics`` fails: its
    value is treated as -inf. Metrics without a configured minimum are not
    included in the result.
    """
    outcome: dict[str, bool] = {}
    for metric_name, floor in min_thresholds.items():
        observed = metrics.get(metric_name, float("-inf"))
        outcome[metric_name] = observed >= floor
    return outcome

def clean_dataframe(df: pd.DataFrame, spec: CleaningSpec) -> pd.DataFrame:
    """Apply the configured cleaning steps, in order, and return a new DataFrame.

    Steps, each skipped when not configured:

    1. standardize column names;
    2. drop rows with a null in any of ``spec.drop_nulls_in``;
    3. drop duplicate rows by ``spec.dedup_by`` — a column name, or a list of
       column names — keeping the first occurrence;
    4. parse ``spec.date_columns`` with ``pd.to_datetime(errors="coerce")``
       (unparseable values become NaT);
    5. remap ``spec.categorical_mappings`` ({column: {old: new}}) with
       ``Series.map``; values missing from a mapping become NaN;
    6. add ``spec.derived_columns``;
    7. apply ``spec.filters``.

    Date and categorical columns absent from the frame are skipped silently.
    The caller's DataFrame is never modified.
    """
    # Work on a private copy. Without it, when none of the row-level steps run,
    # the column assignments below write straight into the caller's frame.
    df = df.copy()

    if spec.standardize_column_names:
        df = _standardize_column_names(df)

    if spec.drop_nulls_in:
        df = df.dropna(subset=spec.drop_nulls_in)

    if spec.dedup_by:
        # Accept the documented single column name as well as a list of columns.
        dedup_cols = [spec.dedup_by] if isinstance(spec.dedup_by, str) else list(spec.dedup_by)
        df = df.drop_duplicates(subset=dedup_cols)

    for col in spec.date_columns:
        if col in df.columns:
            df[col] = pd.to_datetime(df[col], errors="coerce")

    for col, mapping in spec.categorical_mappings.items():
        if col in df.columns:
            df[col] = df[col].map(mapping)

    for derived in spec.derived_columns:
        df = _add_derived_column(df, derived)

    for filt in spec.filters:
        df = _apply_filter(df, filt)

    return df

def explain_features(
    model: Any,
    X: pd.DataFrame | np.ndarray,
    feature_names: list[str] | None = None,
    sample_size: int = 100,
    background_size: int = 50,
) -> dict[str, float] | None:
    """Mean absolute SHAP value per feature. Works with any model exposing
    predict() or scikit-learn-style kneighbors(). Returns None if `shap` isn't
    installed, the model exposes neither method, or there are too few rows.

    Preparation:

    - Non-numeric DataFrame columns are integer-encoded with ``pd.factorize``.
    - NaNs are replaced via ``np.nan_to_num``.
    - Fewer than 5 rows returns None.

    Explanation:

    - A model-agnostic ``shap.KernelExplainer`` runs over a background sample of
      at most ``background_size`` rows.
    - At most the first ``sample_size`` rows are explained.
    - Any exception raised while explaining also yields None.

    Feature names default to the DataFrame's columns, or ``feature_<i>`` for arrays.
    """
    try:
        import shap
    except ImportError:
        return None

    if isinstance(X, pd.DataFrame):
        if not feature_names:
            feature_names = list(X.columns)
        encoded = X.copy()
        for column in encoded.columns:
            if pd.api.types.is_numeric_dtype(encoded[column]):
                continue
            # factorize() returns (codes, uniques); only the integer codes are needed.
            encoded[column] = pd.factorize(encoded[column])[0]
        matrix = encoded.to_numpy(dtype=np.float64)
    else:
        matrix = np.asarray(X, dtype=np.float64)
        if not feature_names:
            feature_names = [f"feature_{i}" for i in range(matrix.shape[1])]

    matrix = np.nan_to_num(matrix)
    n_rows = len(matrix)
    if n_rows < 5:
        return None

    predict_fn = _predict_function(model, n_rows)
    if predict_fn is None:
        return None

    try:
        background = shap.sample(matrix, min(background_size, n_rows))
        explainer = shap.KernelExplainer(predict_fn, background)
        values = explainer.shap_values(matrix[: min(sample_size, n_rows)])
        mean_abs = np.abs(values).mean(axis=0)
        return dict(zip(feature_names, mean_abs.tolist()))
    except Exception:
        return None

def launch():
    """Display the DashML notebook UI.

    The UI is a dashui card holding Monitor / Preprocess / Evaluate / Promote tabs.

    Raises:
        RuntimeError: ipywidgets or IPython is not installed.
    """
    try:
        import ipywidgets as w
        from IPython.display import display
    except ImportError:
        raise RuntimeError("ipywidgets required. Run: %pip install ipywidgets") from None

    import dashui

    # (title, builder) pairs, in tab order.
    tab_specs = [
        ("Monitor", _monitor_tab),
        ("Preprocess", _preprocess_tab),
        ("Evaluate", _evaluate_tab),
        ("Promote", _promote_tab),
    ]
    tabs = w.Tab(children=[build(w, dashui) for _, build in tab_specs])
    for index, (title, _) in enumerate(tab_specs):
        tabs.set_title(index, title)

    header = dashui.header("DashML — ML Lifecycle Management", library="dashml")
    display(dashui.card([header, tabs]))
def plot_feature_importance( shap_by_group: dict[str, dict[str, float]], output_dir: str | None = None) -> list[str]:
 73def plot_feature_importance(shap_by_group: dict[str, dict[str, float]], output_dir: str | None = None) -> list[str]:
 74    """One horizontal-bar PNG per group in shap_by_group. Returns saved file paths."""
 75    import os
 76    import tempfile
 77    import matplotlib
 78    matplotlib.use("Agg")
 79    import matplotlib.pyplot as plt
 80
 81    output_dir = output_dir or tempfile.mkdtemp()
 82    paths = []
 83    for group, importances in shap_by_group.items():
 84        if not importances:
 85            continue
 86        top = sorted(importances.items(), key=lambda kv: kv[1], reverse=True)[:15]
 87        names, values = [t[0] for t in top], [t[1] for t in top]
 88
 89        fig, ax = plt.subplots(figsize=(10, max(4, len(names) * 0.4)))
 90        ax.barh(range(len(names)), values, color="#2A9D90")
 91        ax.set_yticks(range(len(names)))
 92        ax.set_yticklabels(names)
 93        ax.invert_yaxis()
 94        ax.set_xlabel("Mean |SHAP value|")
 95        ax.set_title(f"Feature importance — {group}")
 96        plt.tight_layout()
 97
 98        safe_name = "".join(c if c.isalnum() else "_" for c in group)[:30]
 99        path = os.path.join(output_dir, f"importance_{safe_name}.png")
100        plt.savefig(path, dpi=150, bbox_inches="tight")
101        plt.close(fig)
102        paths.append(path)
103    return paths

One horizontal-bar PNG per group in shap_by_group. Returns saved file paths.

def promote_challenger(
    catalog: str,
    schema_name: str,
    model_name: str,
    run_id: str,
    metric_key: str = "accuracy",
    min_threshold: float | None = None,
    direction: str = "max",
    dry_run: bool = False,
) -> str:
    """Compare the given run's metric against the current @champion and, if it wins,
    promote it. Returns one of: promoted, no_champion, retained, below_threshold,
    no_metric, would_promote, no_version, error.

    ``direction`` is "max" (higher is better) or "min" (lower is better). Any
    other value raises ValueError. ``min_threshold``, when given, is an inclusive
    floor ("max") or ceiling ("min").

    Status meanings:

    - no_metric: the run has no ``metric_key`` metric.
    - below_threshold: the run fails ``min_threshold``.
    - retained: a champion exists and the run does not strictly beat it.
    - would_promote: ``dry_run`` is set and the run would have been promoted.
    - no_version: no registered version of the model comes from ``run_id``.
    - promoted: the @champion alias moved from an existing champion to this run's version.
    - no_champion: there was no @champion; this run's version now holds the alias.
    - error: setting the alias failed.

    If the current champion's run lacks ``metric_key`` or cannot be read, the two
    cannot be compared. In that case the challenger is promoted and the status is
    "promoted".
    """
    if direction not in ("max", "min"):
        raise ValueError(f"direction must be 'max' or 'min', got {direction!r}")

    import mlflow
    from mlflow.tracking import MlflowClient

    mlflow.set_registry_uri("databricks-uc")
    client = MlflowClient(registry_uri="databricks-uc")
    uc_name = registered_model_name(catalog, schema_name, model_name)

    def beats(candidate: float, reference: float) -> bool:
        # Strict improvement in the configured direction.
        return candidate > reference if direction == "max" else candidate < reference

    new_value = client.get_run(run_id).data.metrics.get(metric_key)
    if new_value is None:
        return "no_metric"
    if min_threshold is not None:
        beats_floor = new_value >= min_threshold if direction == "max" else new_value <= min_threshold
        if not beats_floor:
            return "below_threshold"

    try:
        champion = client.get_model_version_by_alias(uc_name, "champion")
    except Exception:
        # Treated as "no champion yet". NOTE(review): this also swallows transient
        # registry errors; narrowing to MLflow's not-found error would be safer.
        champion = None

    has_champion = champion is not None
    if has_champion:
        try:
            champion_value = client.get_run(champion.run_id).data.metrics.get(metric_key)
        except Exception:
            champion_value = None
        # Comparing against None would raise TypeError; a champion without the
        # metric can't be compared, so the challenger wins.
        improved = champion_value is None or beats(new_value, champion_value)
    else:
        improved = True

    if not improved:
        return "retained"
    if dry_run:
        return "would_promote"

    versions = client.search_model_versions(f"name='{uc_name}'")
    match = next((v for v in versions if v.run_id == run_id), None)
    if match is None:
        return "no_version"

    try:
        client.set_registered_model_alias(uc_name, "champion", match.version)
    except Exception:
        return "error"
    return "promoted" if has_champion else "no_champion"

def register_model(catalog: str, schema_name: str, model_name: str, model: Any, run_id: str | None = None) -> bool:
    """Register a fitted pyfunc-compatible model into UC via MLflow.

    ``model`` is wrapped in ``_PyfuncWrapper`` and logged with
    ``mlflow.pyfunc.log_model`` under the artifact path "model". It is registered
    under the Unity Catalog name built from ``catalog``, ``schema_name`` and
    ``model_name``.

    Returns:
        True if logging and registration succeeded, False otherwise. Registration
        is best-effort: a failure is logged with its traceback, not raised.

    NOTE(review): ``run_id`` is accepted but never used here. The model is logged
    without reference to it. Confirm whether it should select the target run.
    """
    import logging

    import mlflow

    mlflow.set_registry_uri("databricks-uc")
    uc_name = registered_model_name(catalog, schema_name, model_name)
    try:
        mlflow.pyfunc.log_model(
            artifact_path="model",
            python_model=_PyfuncWrapper(model),
            registered_model_name=uc_name,
        )
    except Exception:
        # Keep the best-effort contract (return False) but don't hide why it failed.
        logging.getLogger(__name__).exception("Failed to log/register model %s", uc_name)
        return False
    return True