MLOps Best Practices¶
Home | Best Practices | Databricks | MLOps
Best practices for MLOps on Azure Databricks.
Experiment Tracking¶
MLflow Configuration¶
import mlflow

# Group every run below under a single named experiment (created if absent).
mlflow.set_experiment("/Experiments/customer-churn")

# Autologging captures params, metrics, and models from supported libraries.
mlflow.autolog()

# Manual logging with an explicit run context; the `with` block guarantees
# the run is closed even if training raises.
with mlflow.start_run(run_name="random_forest_v1") as run:
    # Log hyperparameters up front so even failed runs remain comparable.
    mlflow.log_params({
        "n_estimators": 100,
        "max_depth": 10
    })

    # Train model (model / X_train / y_train are defined elsewhere in the notebook).
    model.fit(X_train, y_train)

    # Log evaluation metrics (accuracy_score / f1_score from sklearn.metrics;
    # NOTE(review): y_pred is assumed to be computed elsewhere — it is not shown here).
    mlflow.log_metrics({
        "accuracy": accuracy_score(y_test, y_pred),
        "f1_score": f1_score(y_test, y_pred)
    })

    # Attach supporting artifacts (plots, reports) to the run.
    mlflow.log_artifact("feature_importance.png")

    # Persist the trained model under the run's "model" artifact path.
    mlflow.sklearn.log_model(model, "model")
Model Registry¶
Model Lifecycle¶
from mlflow.tracking import MlflowClient

client = MlflowClient()

# Register the model logged in the run above under a registry name.
model_uri = f"runs:/{run.info.run_id}/model"
mv = mlflow.register_model(model_uri, "customer_churn_model")

# Add a human-readable description to the registered model.
client.update_registered_model(
    name="customer_churn_model",
    description="Predicts customer churn probability"
)

# Stage transitions: None -> Staging -> Production -> Archived.
# NOTE(review): registry stages are deprecated on Unity Catalog workspaces
# in favor of model aliases — confirm which registry this targets.
client.transition_model_version_stage(
    name="customer_churn_model",
    version=mv.version,
    stage="Staging"
)
Model Approval Workflow¶
def approve_model(model_name: str, version: int) -> bool:
    """Promote a registered model version to Production if it passes validation.

    Args:
        model_name: Registered model name in the MLflow Model Registry.
        version: Model version number to evaluate.

    Returns:
        True if the version was promoted to Production, False otherwise.
    """
    client = MlflowClient()

    # Fetch the metrics logged on the run that produced this version.
    model_version = client.get_model_version(model_name, str(version))
    run = client.get_run(model_version.run_id)
    metrics = run.data.metrics

    # Validation gate: a missing metric defaults to 0 and therefore fails.
    if metrics.get("accuracy", 0) > 0.85 and metrics.get("f1_score", 0) > 0.80:
        client.transition_model_version_stage(
            name=model_name,
            version=version,
            stage="Production",
            # Archive whatever version currently holds the Production stage.
            archive_existing_versions=True
        )
        return True
    return False
Feature Store¶
Feature Table Management¶
from databricks.feature_engineering import FeatureEngineeringClient

fe = FeatureEngineeringClient()

# Create a feature table keyed by customer_id and timestamped so it can
# participate in point-in-time joins.
fe.create_table(
    name="feature_store.customer_features",
    primary_keys=["customer_id"],
    timestamp_keys=["update_timestamp"],
    df=features_df,
    description="Customer transaction features"
)

# Refresh the table contents; overwrite replaces the existing rows.
fe.write_table(
    name="feature_store.customer_features",
    df=updated_features,
    mode="overwrite"
)
Point-in-Time Training¶
from databricks.feature_engineering import FeatureLookup

# Build a training set with point-in-time correctness: for each label row,
# the lookup returns feature values as of event_timestamp, preventing
# label leakage from features computed after the event.
training_set = fe.create_training_set(
    df=labels_df,
    feature_lookups=[
        FeatureLookup(
            table_name="feature_store.customer_features",
            lookup_key=["customer_id"],
            timestamp_lookup_key="event_timestamp"
        )
    ],
    label="is_churned"
)
Model Serving¶
Model Serving Endpoint¶
from databricks.sdk import WorkspaceClient

w = WorkspaceClient()

# Create a serving endpoint and block until it is ready.
endpoint = w.serving_endpoints.create_and_wait(
    name="churn-prediction",
    config={
        "served_entities": [{
            "entity_name": "customer_churn_model",
            "entity_version": "1",
            # Small workload with scale-to-zero keeps idle cost at zero.
            "workload_size": "Small",
            "scale_to_zero_enabled": True
        }]
    }
)

# A/B testing: serve two versions of the same model behind one endpoint
# with a 90/10 traffic split. The route names match the served models,
# which are presumably named "<entity_name>-<entity_version>" — verify
# against the endpoint's actual served model names.
config = {
    "served_entities": [
        {"entity_name": "model", "entity_version": "1", "workload_size": "Small"},
        {"entity_name": "model", "entity_version": "2", "workload_size": "Small"}
    ],
    "traffic_config": {
        "routes": [
            {"served_model_name": "model-1", "traffic_percentage": 90},
            {"served_model_name": "model-2", "traffic_percentage": 10}
        ]
    }
}
Monitoring¶
Model Performance Monitoring¶
from evidently.metric_preset import DataDriftPreset, TargetDriftPreset
from evidently.report import Report


def monitor_model_drift(reference_data, production_data):
    """Run data- and target-drift checks on recent production data.

    Args:
        reference_data: Baseline dataset (e.g. the training window).
        production_data: Recent production dataset to compare against it.

    Returns:
        Tuple of (drift_detected, report) where drift_detected is the
        boolean dataset-level drift flag and report is the full
        evidently Report for inspection.
    """
    report = Report(metrics=[
        DataDriftPreset(),
        TargetDriftPreset()
    ])
    report.run(
        reference_data=reference_data,
        current_data=production_data
    )
    # Index 0 corresponds to DataDriftPreset, the first entry in `metrics`.
    drift_detected = report.as_dict()["metrics"][0]["result"]["dataset_drift"]
    return drift_detected, report
Inference Logging¶
# Log inference requests and predictions
import json
from datetime import datetime
def log_inference(request, prediction, model_version):
"""Log inference for monitoring."""
log_entry = {
"timestamp": datetime.utcnow().isoformat(),
"model_version": model_version,
"request": request,
"prediction": prediction
}
# Write to Delta table for analysis
spark.createDataFrame([log_entry]).write \
.format("delta") \
.mode("append") \
.save("/delta/inference_logs")
Best Practices Summary¶
| Area | Best Practice |
|---|---|
| Experiments | Use descriptive run names and tags |
| Models | Version all models in registry |
| Features | Use Feature Store for consistency |
| Serving | Enable auto-scaling and monitoring |
| Monitoring | Track drift and performance metrics |
Related Documentation¶
Last Updated: January 2025