dashml
DashML — ML lifecycle management for Databricks: preprocessing, evaluation, drift monitoring, governance, and champion/challenger promotion.
"""DashML: ML lifecycle management on Databricks.

Covers preprocessing, evaluation, drift monitoring, governance, and
champion/challenger promotion.
"""
from dashml.config import (
    CleaningSpec,
    DerivedColumn,
    EvaluationSpec,
    ModelSpec,
    MonitoringSpec,
    RunConfig,
)
from dashml.evaluation import (
    build_model_card,
    check_thresholds,
    explain_features,
    plot_feature_importance,
)
from dashml.governance import build_governance_artifacts
from dashml.monitor import ModelMonitor
from dashml.preprocessing import clean_dataframe
from dashml.registry import RunTracker, promote_challenger, register_model
from dashml.ui import launch

__version__ = "0.1.4"

# Public API re-exported from the submodules above (case-sensitive sort order).
__all__ = [
    "CleaningSpec",
    "DerivedColumn",
    "EvaluationSpec",
    "ModelMonitor",
    "ModelSpec",
    "MonitoringSpec",
    "RunConfig",
    "RunTracker",
    "build_governance_artifacts",
    "build_model_card",
    "check_thresholds",
    "clean_dataframe",
    "explain_features",
    "launch",
    "plot_feature_importance",
    "promote_challenger",
    "register_model",
]
@dataclass
class CleaningSpec:
    """Declarative list of cleaning steps consumed by ``clean_dataframe``.

    Every step is optional. Container fields default to fresh empty
    containers, so two instances never share mutable state.
    """

    # Standardize column names; clean_dataframe does this before any other step.
    standardize_column_names: bool = True
    # Column used as the de-duplication key (None = no de-duplication).
    dedup_by: str | None = None
    # Rows with a null in any of these columns are dropped.
    drop_nulls_in: list[str] = field(default_factory=list)
    # Columns parsed as datetimes.
    date_columns: list[str] = field(default_factory=list)
    # column -> {raw value: replacement value}
    categorical_mappings: dict[str, dict] = field(default_factory=dict)
    # Columns added by clean_dataframe via its _add_derived_column helper.
    derived_columns: list[DerivedColumn] = field(default_factory=list)
    # Row filters; presumably dicts understood by _apply_filter — confirm the schema.
    filters: list[dict] = field(default_factory=list)
@dataclass
class EvaluationSpec:
    """Settings for the evaluation stage of a run."""

    # metric name -> minimum acceptable value (the shape check_thresholds expects).
    min_thresholds: dict[str, float] = field(default_factory=dict)
    # Optional evaluation steps, all enabled by default; presumably read by the
    # run orchestration, which is not visible here — confirm.
    compare_to_champion: bool = True
    run_shap: bool = True
    build_model_card: bool = True
class ModelMonitor:
    """
    Monitor ML model performance and data drift on Databricks.

    Usage::

        monitor = ModelMonitor(model_name="customer_churn", model_version=3)
        monitor.set_baseline(table="catalog.schema.training_data")
        monitor.set_production(table="catalog.schema.inference_logs")
        report = monitor.run()

    All ``set_*`` methods return ``self`` so calls can be chained.
    """

    def __init__(self, model_name: str, model_version: int | None = None, drift_threshold: float = 0.15):
        self.model_name = model_name
        self.model_version = model_version
        # Passed to compute_drift() as its ``threshold`` by run().
        self.drift_threshold = drift_threshold
        self._baseline_df = None
        self._production_df = None
        self._target_col: str | None = None
        self._prediction_col: str | None = None
        self._feature_cols: list[str] = []
        self._retrain_job_name: str | None = None

    def set_baseline(self, df=None, table: str | None = None):
        """Set the reference data: ``df`` if given, otherwise the Spark table ``table``."""
        self._baseline_df = self._load(df, table)
        return self

    def set_production(self, df=None, table: str | None = None):
        """Set the data compared against the baseline: ``df`` if given, otherwise table ``table``."""
        self._production_df = self._load(df, table)
        return self

    def set_target(self, column: str):
        """Name the target column passed to compute_performance()."""
        self._target_col = column
        return self

    def set_prediction(self, column: str):
        """Name the model-output column used for prediction drift and performance."""
        self._prediction_col = column
        return self

    def set_features(self, columns: list[str]):
        """Name the feature columns checked for drift."""
        # Copy so later changes to the caller's list don't silently change what run() checks.
        self._feature_cols = list(columns or [])
        return self

    def set_auto_retrain(self, job_name: str):
        """Trigger this Databricks job by name if drift exceeds the threshold on run()."""
        self._retrain_job_name = job_name
        return self

    def _load(self, df, table):
        """Return ``df`` if given, otherwise read ``table`` from the active Spark session.

        Raises:
            ValueError: if neither ``df`` nor ``table`` is provided.
            RuntimeError: if a table is requested but no Spark session is active.
        """
        if df is not None:
            return df
        if not table:
            raise ValueError("Provide either a DataFrame (df=...) or a table name (table=...).")
        from pyspark.sql import SparkSession

        spark = SparkSession.getActiveSession()
        if spark is None:
            raise RuntimeError(f"No active Spark session; cannot load table {table!r}.")
        return spark.table(table)

    def run(self, save_to: str | None = None) -> MonitorReport:
        """Run every configured check and return the results as a MonitorReport.

        - ``"drift"``: compute_drift over the feature columns (needs baseline + production).
        - ``"prediction_drift"``: compute_prediction_drift on the prediction column
          (needs baseline + production).
        - ``"performance"``: compute_performance of prediction vs. target on production data.
        - ``"retraining"``: added only when set_auto_retrain() was called and at least one
          entry of the drift result reports ``drifted``; holds the job name and
          trigger_job()'s return value.

        If ``save_to`` is truthy, the report is also saved via ``report.save(save_to)``.
        """
        from dashml.metrics import compute_drift, compute_performance, compute_prediction_drift

        results = {}
        if self._baseline_df is not None and self._production_df is not None:
            if self._feature_cols:
                results["drift"] = compute_drift(
                    self._baseline_df, self._production_df, self._feature_cols, threshold=self.drift_threshold
                )
            if self._prediction_col:
                results["prediction_drift"] = compute_prediction_drift(
                    self._baseline_df, self._production_df, self._prediction_col
                )
        if self._target_col and self._prediction_col and self._production_df is not None:
            results["performance"] = compute_performance(
                self._production_df, self._target_col, self._prediction_col
            )

        # Assumes compute_drift returns {feature: {..., "drifted": bool}} — confirm in dashml.metrics.
        if self._retrain_job_name and any(f.get("drifted") for f in results.get("drift", {}).values()):
            from dashml.retraining import trigger_job

            results["retraining"] = {"job": self._retrain_job_name, "status": trigger_job(self._retrain_job_name)}

        report = MonitorReport(self.model_name, self.model_version, results)
        if save_to:
            report.save(save_to)
        return report
Monitor ML model performance and data drift on Databricks.
Usage::

    monitor = ModelMonitor(model_name="customer_churn", model_version=3)
    monitor.set_baseline(table="catalog.schema.training_data")
    monitor.set_production(table="catalog.schema.inference_logs")
    report = monitor.run()
def __init__(self, model_name: str, model_version: int | None = None, drift_threshold: float = 0.15):
    """Record which model is monitored; every data source and column role starts unset.

    ``drift_threshold`` is later handed to ``compute_drift`` by ``run()``.
    """
    # Public identity and configuration.
    self.model_name = model_name
    self.model_version = model_version
    self.drift_threshold = drift_threshold
    # Frames filled in by set_baseline() / set_production().
    self._baseline_df = self._production_df = None
    # Column roles filled in by set_target() / set_prediction() / set_features().
    self._target_col: str | None = None
    self._prediction_col: str | None = None
    self._feature_cols: list[str] = []
    # Optional Databricks job name registered via set_auto_retrain().
    self._retrain_job_name: str | None = None
def set_auto_retrain(self, job_name: str):
    """Register a Databricks job (by name) that run() triggers when any feature
    in the drift result is flagged as drifted.

    Returns ``self`` so the call can be chained.
    """
    self._retrain_job_name = job_name
    return self
Trigger this Databricks job by name if drift exceeds the threshold on run().
def run(self, save_to: str | None = None) -> MonitorReport:
    """Run every configured check and wrap the results in a MonitorReport.

    Result keys, each present only when its inputs are configured:
      * ``"drift"``: compute_drift over the feature columns (baseline + production needed);
      * ``"prediction_drift"``: compute_prediction_drift on the prediction column
        (baseline + production needed);
      * ``"performance"``: compute_performance of prediction vs. target on production data;
      * ``"retraining"``: job name plus trigger_job()'s status, added only when an
        auto-retrain job is set and at least one drift entry reports ``drifted``.

    A truthy ``save_to`` additionally saves the report via ``report.save``.
    """
    from dashml.metrics import compute_drift, compute_performance, compute_prediction_drift

    baseline, production = self._baseline_df, self._production_df
    both_loaded = baseline is not None and production is not None
    results = {}

    if both_loaded and self._feature_cols:
        results["drift"] = compute_drift(baseline, production, self._feature_cols, threshold=self.drift_threshold)
    if both_loaded and self._prediction_col:
        results["prediction_drift"] = compute_prediction_drift(baseline, production, self._prediction_col)
    if production is not None and self._target_col and self._prediction_col:
        results["performance"] = compute_performance(production, self._target_col, self._prediction_col)

    job = self._retrain_job_name
    if job:
        drift_by_feature = results.get("drift", {})
        if any(entry.get("drifted") for entry in drift_by_feature.values()):
            from dashml.retraining import trigger_job

            results["retraining"] = {"job": job, "status": trigger_job(job)}

    report = MonitorReport(self.model_name, self.model_version, results)
    if save_to:
        report.save(save_to)
    return report
12@dataclass 13class ModelSpec: 14 type: str = "classifier" # classifier | regressor | clustering | knn | anomaly 15 target_column: str | None = None 16 feature_columns: list[str] = field(default_factory=list) 17 18 @property 19 def is_supervised(self) -> bool: 20 return self.type in ("classifier", "regressor")
49@dataclass 50class MonitoringSpec: 51 drift_threshold: float = 0.15 # PSI: >0.25 significant, 0.10-0.25 moderate 52 row_count_delta_threshold: float = 0.05 53 auto_retrain: bool = False 54 retrain_job_name: str | None = None
@dataclass
class RunConfig:
    """Configuration for one model's end-to-end lifecycle run.

    ``name``, ``catalog`` and ``schema_name`` are required; each sub-spec
    defaults to a freshly constructed instance.
    """

    name: str
    catalog: str
    schema_name: str
    model: ModelSpec = field(default_factory=ModelSpec)
    cleaning: CleaningSpec = field(default_factory=CleaningSpec)
    evaluation: EvaluationSpec = field(default_factory=EvaluationSpec)
    monitoring: MonitoringSpec = field(default_factory=MonitoringSpec)

    @property
    def registered_name(self) -> str:
        """Three-level Unity Catalog name: ``<catalog>.<schema_name>.<name>``."""
        return "{}.{}.{}".format(self.catalog, self.schema_name, self.name)
Top-level config for one model's lifecycle run.
class RunTracker:
    """Context manager wrapping the lifecycle of one MLflow run.

    Usage::

        with RunTracker("/Users/me/experiments/churn") as tracker:
            tracker.log_params({"max_depth": 6})
            tracker.log_metrics({"auc": 0.91})

    On entry, the MLflow registry URI is pointed at Unity Catalog
    (``databricks-uc``) and ``experiment_path`` is selected. An already-active
    run is reused and left open on exit. Otherwise a new run is started here
    and ended on exit, with status ``FAILED`` if the ``with`` body raised and
    ``FINISHED`` otherwise. After entry, ``run_id`` holds the active run's id.
    """

    def __init__(self, experiment_path: str):
        self.experiment_path = experiment_path
        self.run_id: str | None = None
        self._own_run = False

    def __enter__(self) -> RunTracker:
        import mlflow

        mlflow.set_registry_uri("databricks-uc")
        mlflow.set_experiment(self.experiment_path)
        # Recomputed on every entry so a reused tracker never ends a run it didn't start.
        self._own_run = mlflow.active_run() is None
        if self._own_run:
            mlflow.start_run()
        self.run_id = mlflow.active_run().info.run_id
        return self

    def __exit__(self, *exc):
        import mlflow

        if self._own_run:
            # exc is (exc_type, exc_value, traceback); exc_type is None on a clean exit.
            failed = bool(exc) and exc[0] is not None
            mlflow.end_run(status="FAILED" if failed else "FINISHED")
            self._own_run = False
        # Implicit None: exceptions raised in the with-body propagate.

    def log_params(self, params: dict) -> None:
        """Log ``params`` with every value stringified and cut to 250 characters
        (presumably to stay under MLflow's param-value length limit)."""
        import mlflow

        mlflow.log_params({k: str(v)[:250] for k, v in params.items()})

    def log_metrics(self, metrics: dict) -> None:
        """Log each value as a float metric; a value that can't be logged as a float
        is recorded as the string param ``metric_<key>`` instead."""
        import mlflow

        for key, value in metrics.items():
            try:
                mlflow.log_metric(key, float(value))
            except (TypeError, ValueError):  # noqa: PERF203 — per-item fallback, not a hot loop (a handful of metrics per run)
                mlflow.log_param(f"metric_{key}", str(value))

    def log_text(self, text: str, artifact_path: str) -> None:
        """Store ``text`` as an artifact at ``artifact_path`` in the active run."""
        import mlflow

        mlflow.log_text(text, artifact_path)

    def log_artifact_file(self, path: str, artifact_path: str | None = None) -> None:
        """Upload the local file ``path`` (optionally under ``artifact_path``) to the active run."""
        import mlflow

        mlflow.log_artifact(path, artifact_path)
Thin wrapper around one MLflow run's lifecycle.
def log_metrics(self, metrics: dict) -> None:
    """Send each entry of ``metrics`` to MLflow as a float metric.

    If an entry can't be logged as a float (``float()`` or the MLflow call
    raises TypeError/ValueError), its string form is recorded as the param
    ``metric_<name>`` instead, so nothing is silently dropped.
    """
    import mlflow

    for name, raw in metrics.items():
        try:
            mlflow.log_metric(name, float(raw))
        except (TypeError, ValueError):  # noqa: PERF203 — per-item fallback, not a hot loop (a handful of metrics per run)
            mlflow.log_param(f"metric_{name}", str(raw))
def build_governance_artifacts(
    name: str,
    version: str,
    catalog: str,
    schema_name: str,
    tables: list[str],
    feature_columns: list[str],
    metrics: dict,
    signature: dict | None = None,
    sensitive_columns: list[str] | None = None,
    pii_columns: list[str] | None = None,
    fairness_metrics: dict | None = None,
    baseline_metrics: dict | None = None,
    per_segment_metrics: dict | None = None,
    shap_by_group: dict[str, dict[str, float]] | None = None,
    required_approvers: list[str] | None = None,
    checklist: list[str] | None = None,
) -> dict[str, str]:
    """Render the standard governance file set as in-memory text.

    Nothing is written to disk; the caller decides where each file goes.

    Returns:
        File name -> file content, with keys in this order: ``signature.json``,
        ``features.json``, ``metrics.json``, ``data_sources.yaml``,
        ``fairness_report.md``, ``approval_record.json``. The approval record
        always starts in the ``"pending"`` state with every checklist item False.
    """

    def as_json(payload) -> str:
        return json.dumps(payload, indent=2)

    signature_json = as_json(signature or {})
    features_json = as_json(
        {
            "feature_columns": feature_columns,
            "sensitive_columns": sensitive_columns or [],
            "pii_columns": pii_columns or [],
        }
    )
    metrics_json = as_json(
        {
            "metrics": metrics,
            "per_segment_metrics": per_segment_metrics or {},
            "baseline_comparison": _delta(metrics, baseline_metrics) if baseline_metrics else None,
            "top_features": _top_shap(shap_by_group),
        }
    )
    # Tables are recorded with their fully qualified catalog.schema.table names.
    qualified_tables = [f"{catalog}.{schema_name}.{table}" for table in tables]
    sources_yaml = yaml.dump(
        {"catalog": catalog, "schema": schema_name, "tables": qualified_tables},
        sort_keys=False,
    )
    fairness_md = _fairness_report(sensitive_columns, fairness_metrics)
    approval_json = as_json(
        {
            "model": name,
            "version": version,
            "state": "pending",
            "required_approvers": required_approvers or [],
            "checklist": dict.fromkeys(checklist or [], False),
        }
    )

    return {
        "signature.json": signature_json,
        "features.json": features_json,
        "metrics.json": metrics_json,
        "data_sources.yaml": sources_yaml,
        "fairness_report.md": fairness_md,
        "approval_record.json": approval_json,
    }
Returns {filename: content} for the standard governance file set.
106def build_model_card( 107 name: str, 108 version: str, 109 model_type: str, 110 catalog: str, 111 schema_name: str, 112 tables: list[str], 113 target_column: str | None, 114 feature_columns: list[str], 115 params: dict, 116 metrics: dict, 117 run_id: str | None = None, 118 shap_by_group: dict[str, dict[str, float]] | None = None, 119) -> str: 120 """Render a Markdown model card. Doesn't write or log it — caller decides where it goes.""" 121 lines = [ 122 f"# Model Card — {name}", 123 "", 124 "## Details", 125 f"- **Version:** {version}", 126 f"- **Type:** {model_type}", 127 f"- **Created:** {datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M UTC')}", 128 ] 129 if run_id: 130 lines.append(f"- **MLflow run:** `{run_id}`") 131 132 lines += [ 133 "", 134 "## Data", 135 f"- **Source:** Unity Catalog `{catalog}.{schema_name}`", 136 f"- **Tables:** {', '.join(tables) if tables else '—'}", 137 ] 138 if target_column: 139 lines.append(f"- **Target column:** `{target_column}`") 140 if feature_columns: 141 lines.append(f"- **Feature columns:** {', '.join(feature_columns)}") 142 143 lines += ["", "## Parameters"] 144 lines += [f"- `{k}`: {v}" for k, v in sorted(params.items())] or ["_none logged_"] 145 146 lines += ["", "## Metrics"] 147 lines += [f"- `{k}`: {v}" for k, v in sorted(metrics.items())] or ["_none logged_"] 148 149 if shap_by_group: 150 lines += ["", "## Feature importance (SHAP)"] 151 for group, importances in shap_by_group.items(): 152 top = sorted(importances.items(), key=lambda kv: kv[1], reverse=True)[:10] 153 lines.append(f"**{group}**") 154 lines += [f" - `{feat}`: {val:.4f}" for feat, val in top] 155 156 lines += [ 157 "", 158 "## Limitations", 159 "Evaluated on the configured evaluation split only — verify performance holds", 160 "on your actual production population before relying on this model.", 161 ] 162 return "\n".join(lines)
Render a Markdown model card. Doesn't write or log it — caller decides where it goes.
def check_thresholds(metrics: dict[str, float], min_thresholds: dict[str, float]) -> dict[str, bool]:
    """Map each thresholded metric name to whether it reached its configured minimum.

    A metric absent from ``metrics`` is treated as -inf and therefore fails.
    Only names present in ``min_thresholds`` appear in the result.
    """
    absent = float("-inf")
    verdicts = {}
    for metric_name, floor in min_thresholds.items():
        verdicts[metric_name] = metrics.get(metric_name, absent) >= floor
    return verdicts
Per-metric pass/fail against configured minimums.
def clean_dataframe(df: pd.DataFrame, spec: CleaningSpec) -> pd.DataFrame:
    """Apply the configured cleaning steps, in order, and return a new DataFrame.

    Steps, each skipped when not configured in ``spec``:
      1. standardize column names;
      2. drop rows with nulls in ``spec.drop_nulls_in``;
      3. de-duplicate on ``spec.dedup_by`` (first occurrence kept);
      4. parse ``spec.date_columns`` as datetimes (unparseable values become NaT);
      5. map ``spec.categorical_mappings`` (values missing from a mapping become NaN);
      6. add ``spec.derived_columns``;
      7. apply ``spec.filters``.
    Date and mapping steps skip columns the frame doesn't have.

    The input DataFrame is never modified.
    """
    # Work on a copy: the column assignments below would otherwise write straight
    # into the caller's frame whenever no earlier step produced a new object.
    df = df.copy()

    if spec.standardize_column_names:
        df = _standardize_column_names(df)

    if spec.drop_nulls_in:
        df = df.dropna(subset=spec.drop_nulls_in)

    if spec.dedup_by:
        df = df.drop_duplicates(subset=[spec.dedup_by])

    for col in spec.date_columns:
        if col in df.columns:
            df[col] = pd.to_datetime(df[col], errors="coerce")

    for col, mapping in spec.categorical_mappings.items():
        if col in df.columns:
            df[col] = df[col].map(mapping)

    for derived in spec.derived_columns:
        df = _add_derived_column(df, derived)

    for filt in spec.filters:
        df = _apply_filter(df, filt)

    return df
Apply the configured cleaning steps, in order, and return a new DataFrame.
19def explain_features( 20 model: Any, 21 X: pd.DataFrame | np.ndarray, 22 feature_names: list[str] | None = None, 23 sample_size: int = 100, 24 background_size: int = 50, 25) -> dict[str, float] | None: 26 """Mean absolute SHAP value per feature. Works with any model exposing 27 predict() or scikit-learn-style kneighbors(). Returns None if `shap` isn't 28 installed, the model exposes neither method, or there are too few rows.""" 29 try: 30 import shap 31 except ImportError: 32 return None 33 34 if isinstance(X, pd.DataFrame): 35 feature_names = feature_names or list(X.columns) 36 X_numeric = X.copy() 37 for col in X_numeric.columns: 38 if not pd.api.types.is_numeric_dtype(X_numeric[col]): 39 X_numeric[col] = pd.factorize(X_numeric[col])[0] 40 X_arr = X_numeric.to_numpy(dtype=np.float64) 41 else: 42 X_arr = np.asarray(X, dtype=np.float64) 43 feature_names = feature_names or [f"feature_{i}" for i in range(X_arr.shape[1])] 44 45 X_arr = np.nan_to_num(X_arr) 46 if len(X_arr) < 5: 47 return None 48 49 predict_fn = _predict_function(model, len(X_arr)) 50 if predict_fn is None: 51 return None 52 53 try: 54 background = shap.sample(X_arr, min(background_size, len(X_arr))) 55 explainer = shap.KernelExplainer(predict_fn, background) 56 shap_values = explainer.shap_values(X_arr[: min(sample_size, len(X_arr))]) 57 return dict(zip(feature_names, np.abs(shap_values).mean(axis=0).tolist())) 58 except Exception: 59 return None
Mean absolute SHAP value per feature. Works with any model exposing
`predict()` or scikit-learn-style `kneighbors()`. Returns `None` if `shap` isn't
installed, the model exposes neither method, there are fewer than 5 rows, or the
SHAP computation fails.
def launch():
    """Display the DashML notebook UI, with Monitor, Preprocess, Evaluate and Promote tabs.

    Raises:
        RuntimeError: if ``ipywidgets`` or ``IPython.display`` cannot be imported.
    """
    try:
        import ipywidgets as w
        from IPython.display import display
    except ImportError:
        raise RuntimeError("ipywidgets required. Run: %pip install ipywidgets") from None

    import dashui

    # (title, tab builder) pairs, in display order.
    tab_specs = (
        ("Monitor", _monitor_tab),
        ("Preprocess", _preprocess_tab),
        ("Evaluate", _evaluate_tab),
        ("Promote", _promote_tab),
    )
    tabs = w.Tab(children=[build(w, dashui) for _, build in tab_specs])
    for position, (title, _) in enumerate(tab_specs):
        tabs.set_title(position, title)

    display(dashui.card([dashui.header("DashML — ML Lifecycle Management", library="dashml"), tabs]))
73def plot_feature_importance(shap_by_group: dict[str, dict[str, float]], output_dir: str | None = None) -> list[str]: 74 """One horizontal-bar PNG per group in shap_by_group. Returns saved file paths.""" 75 import os 76 import tempfile 77 import matplotlib 78 matplotlib.use("Agg") 79 import matplotlib.pyplot as plt 80 81 output_dir = output_dir or tempfile.mkdtemp() 82 paths = [] 83 for group, importances in shap_by_group.items(): 84 if not importances: 85 continue 86 top = sorted(importances.items(), key=lambda kv: kv[1], reverse=True)[:15] 87 names, values = [t[0] for t in top], [t[1] for t in top] 88 89 fig, ax = plt.subplots(figsize=(10, max(4, len(names) * 0.4))) 90 ax.barh(range(len(names)), values, color="#2A9D90") 91 ax.set_yticks(range(len(names))) 92 ax.set_yticklabels(names) 93 ax.invert_yaxis() 94 ax.set_xlabel("Mean |SHAP value|") 95 ax.set_title(f"Feature importance — {group}") 96 plt.tight_layout() 97 98 safe_name = "".join(c if c.isalnum() else "_" for c in group)[:30] 99 path = os.path.join(output_dir, f"importance_{safe_name}.png") 100 plt.savefig(path, dpi=150, bbox_inches="tight") 101 plt.close(fig) 102 paths.append(path) 103 return paths
Save one horizontal-bar PNG per non-empty group in `shap_by_group`. Returns the saved file paths.
92def promote_challenger( 93 catalog: str, 94 schema_name: str, 95 model_name: str, 96 run_id: str, 97 metric_key: str = "accuracy", 98 min_threshold: float | None = None, 99 direction: str = "max", 100 dry_run: bool = False, 101) -> str: 102 """Compare the given run's metric against the current @champion and, if it wins, 103 promote it. Returns one of: promoted, no_champion, retained, below_threshold, 104 no_metric, would_promote, no_version, error.""" 105 import mlflow 106 from mlflow.tracking import MlflowClient 107 108 mlflow.set_registry_uri("databricks-uc") 109 client = MlflowClient(registry_uri="databricks-uc") 110 uc_name = registered_model_name(catalog, schema_name, model_name) 111 112 run = client.get_run(run_id) 113 new_value = run.data.metrics.get(metric_key) 114 if new_value is None: 115 return "no_metric" 116 if min_threshold is not None: 117 beats_floor = new_value >= min_threshold if direction == "max" else new_value <= min_threshold 118 if not beats_floor: 119 return "below_threshold" 120 121 try: 122 champion = client.get_model_version_by_alias(uc_name, "champion") 123 champion_value = client.get_run(champion.run_id).data.metrics.get(metric_key) 124 improved = new_value > champion_value if direction == "max" else new_value < champion_value 125 has_champion = True 126 except Exception: 127 improved, has_champion = True, False 128 129 if not improved: 130 return "retained" 131 if dry_run: 132 return "would_promote" 133 134 versions = client.search_model_versions(f"name='{uc_name}'") 135 match = next((v for v in versions if v.run_id == run_id), None) 136 if match is None: 137 return "no_version" 138 139 try: 140 client.set_registered_model_alias(uc_name, "champion", match.version) 141 return "promoted" if has_champion else "no_champion" 142 except Exception: 143 return "error"
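Compare the given run's metric against the current `@champion` and, if it wins, promote it. Returns one of: `promoted` (alias moved from the previous champion), `no_champion` (alias set where no existing champion could be resolved), `retained`, `below_threshold`, `no_metric`, `would_promote` (dry run), `no_version`, or `error`.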
def register_model(catalog: str, schema_name: str, model_name: str, model: Any, run_id: str | None = None) -> bool:
    """Log ``model`` as an MLflow pyfunc and register it in Unity Catalog.

    The model is wrapped in ``_PyfuncWrapper`` and registered as
    ``registered_model_name(catalog, schema_name, model_name)``, with the
    MLflow registry pointed at ``databricks-uc``.

    Args:
        catalog: Unity Catalog catalog name.
        schema_name: Unity Catalog schema name.
        model_name: model name within the schema.
        model: fitted model to wrap and log.
        run_id: MLflow run to log the model artifact under. If it is not the
            active run, that run is resumed for the duration of this call
            (nested under the active run if one exists). When omitted, MLflow's
            usual active-run handling applies.

    Returns:
        True if logging and registration succeeded. False otherwise; the
        exception is logged, not raised.
    """
    import contextlib
    import logging

    import mlflow

    mlflow.set_registry_uri("databricks-uc")
    uc_name = registered_model_name(catalog, schema_name, model_name)
    try:
        active = mlflow.active_run()
        if run_id is None or (active is not None and active.info.run_id == run_id):
            run_ctx = contextlib.nullcontext()
        else:
            # MLflow refuses to start a second top-level run while one is active,
            # so resume the requested run nested under it in that case.
            run_ctx = mlflow.start_run(run_id=run_id, nested=active is not None)
        with run_ctx:
            mlflow.pyfunc.log_model(
                artifact_path="model",
                python_model=_PyfuncWrapper(model),
                registered_model_name=uc_name,
            )
    except Exception:
        logging.getLogger(__name__).exception("Failed to log/register model %s", uc_name)
        return False
    return True
Register a fitted pyfunc-compatible model into UC via MLflow.