🤖 MLflow Integration

End-to-end machine learning lifecycle management with MLflow for experiment tracking, model registry, and deployment on Azure Databricks.


🌟 Overview

MLflow is an open-source platform for managing the complete machine learning lifecycle, including experimentation, reproducibility, deployment, and a central model registry. Azure Databricks provides a fully managed and integrated MLflow experience with enhanced security, scalability, and enterprise features.

🔥 Key Benefits

  • Experiment Tracking: Log parameters, metrics, and artifacts automatically
  • Model Registry: Centralized model management with versioning and lineage
  • Reproducibility: Track code, data, and environment for every run
  • Deployment: Deploy models to various endpoints (batch, real-time, streaming)
  • Collaboration: Share experiments and models across teams
  • Integration: Native integration with Azure ML and other Azure services

🏗️ Architecture

MLflow Components

graph TB
    subgraph "Development"
        Notebook[Notebooks]
        IDE[IDEs]
        Auto[AutoML]
    end

    subgraph "MLflow Tracking"
        Exp[Experiments]
        Runs[Runs]
        Metrics[Metrics &<br/>Parameters]
        Artifacts[Artifacts &<br/>Models]
    end

    subgraph "MLflow Models"
        Registry[Model<br/>Registry]
        Versions[Model<br/>Versions]
        Stages[Lifecycle<br/>Stages]
    end

    subgraph "Deployment"
        Batch[Batch<br/>Inference]
        RT[Real-Time<br/>Serving]
        Stream[Streaming<br/>Inference]
    end

    subgraph "Monitoring"
        Perf[Performance<br/>Monitoring]
        Drift[Data Drift<br/>Detection]
        Alerts[Alerting]
    end

    Notebook --> Exp
    IDE --> Exp
    Auto --> Exp

    Exp --> Runs
    Runs --> Metrics
    Runs --> Artifacts

    Artifacts --> Registry
    Registry --> Versions
    Versions --> Stages

    Stages --> Batch
    Stages --> RT
    Stages --> Stream

    Batch --> Perf
    RT --> Perf
    Stream --> Perf
    Perf --> Drift
    Drift --> Alerts

ML Lifecycle with MLflow

graph LR
    A[Data<br/>Preparation] --> B[Model<br/>Training]
    B --> C[Experiment<br/>Tracking]
    C --> D[Model<br/>Evaluation]
    D --> E{Good<br/>Enough?}
    E -->|No| B
    E -->|Yes| F[Register<br/>Model]
    F --> G[Stage to<br/>Production]
    G --> H[Deploy<br/>Model]
    H --> I[Monitor<br/>Performance]
    I --> J{Drift<br/>Detected?}
    J -->|Yes| A
    J -->|No| I

🚀 Getting Started

MLflow Tracking

Basic Experiment Tracking

import mlflow
import mlflow.sklearn
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_score, recall_score

# Set experiment name
mlflow.set_experiment("/Users/data-scientist@company.com/customer-churn")

# Enable autologging (automatically logs parameters, metrics, and model)
mlflow.sklearn.autolog()

# Load data
df = spark.table("production.ml.customer_features").toPandas()
X = df.drop('churn', axis=1)
y = df['churn']

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Start MLflow run
with mlflow.start_run(run_name="random-forest-baseline"):
    # Log data-split settings manually (autolog records the estimator's parameters, not these)
    mlflow.log_param("test_size", 0.2)
    mlflow.log_param("random_state", 42)

    # Train model
    model = RandomForestClassifier(
        n_estimators=100,
        max_depth=10,
        random_state=42
    )
    model.fit(X_train, y_train)

    # Make predictions
    y_pred = model.predict(X_test)

    # Log metrics
    accuracy = accuracy_score(y_test, y_pred)
    precision = precision_score(y_test, y_pred)
    recall = recall_score(y_test, y_pred)

    mlflow.log_metric("accuracy", accuracy)
    mlflow.log_metric("precision", precision)
    mlflow.log_metric("recall", recall)

    # Log additional artifacts
    import matplotlib.pyplot as plt
    from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

    cm = confusion_matrix(y_test, y_pred)
    disp = ConfusionMatrixDisplay(cm, display_labels=['Not Churned', 'Churned'])
    disp.plot()
    plt.savefig("confusion_matrix.png")
    mlflow.log_artifact("confusion_matrix.png")

    # Log feature importance
    import pandas as pd
    feature_importance = pd.DataFrame({
        'feature': X.columns,
        'importance': model.feature_importances_
    }).sort_values('importance', ascending=False)

    feature_importance.to_csv("feature_importance.csv", index=False)
    mlflow.log_artifact("feature_importance.csv")

    # Log and register the model explicitly (with autolog enabled, a copy is also logged automatically)
    mlflow.sklearn.log_model(
        model,
        "model",
        registered_model_name="customer-churn-predictor"
    )

    print(f"Run ID: {mlflow.active_run().info.run_id}")
    print(f"Accuracy: {accuracy:.4f}")
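For reference, the three metrics logged above reduce to simple ratios over the confusion matrix. The following is a minimal pure-Python sketch, assuming binary labels where 1 means churned; the function name `binary_metrics` and the toy labels are illustrative and not part of the original example.

```python
def binary_metrics(y_true, y_pred, positive=1):
    """Compute accuracy, precision, and recall for binary labels."""
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length")

    # Count the confusion-matrix cells
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == positive and p == positive)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t != positive and p == positive)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == positive and p != positive)
    tn = len(y_true) - tp - fp - fn

    # Guard every ratio against an empty denominator
    accuracy = (tp + tn) / len(y_true) if len(y_true) else 0.0
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    return {"accuracy": accuracy, "precision": precision, "recall": recall}


# Toy labels: 2 true positives, 1 false positive, 1 false negative, 2 true negatives
metrics = binary_metrics([1, 0, 1, 1, 0, 0], [1, 0, 0, 1, 1, 0])
print(metrics)
```

A dictionary shaped like this can be passed to `mlflow.log_metrics` in one call instead of logging each value separately.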

Advanced Tracking with Custom Metrics

import mlflow
import numpy as np
from sklearn.metrics import roc_auc_score, roc_curve
import matplotlib.pyplot as plt

with mlflow.start_run(run_name="advanced-metrics"):
    # ... training code ...
    # Assumes `model` exposes per-epoch loss histories (train_loss_history, val_loss_history)

    # Log custom metrics over time (e.g., learning curves)
    for epoch, (train_loss, val_loss) in enumerate(
        zip(model.train_loss_history, model.val_loss_history)
    ):
        mlflow.log_metric("train_loss", train_loss, step=epoch)
        mlflow.log_metric("val_loss", val_loss, step=epoch)

    # Log probability predictions for AUC
    y_proba = model.predict_proba(X_test)[:, 1]
    auc = roc_auc_score(y_test, y_proba)
    mlflow.log_metric("auc", auc)

    # Log ROC curve
    fpr, tpr, thresholds = roc_curve(y_test, y_proba)
    plt.figure(figsize=(8, 6))
    plt.plot(fpr, tpr, label=f'ROC Curve (AUC = {auc:.4f})')
    plt.plot([0, 1], [0, 1], 'k--', label='Random')
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('ROC Curve')
    plt.legend()
    plt.savefig("roc_curve.png")
    mlflow.log_artifact("roc_curve.png")

    # Log dataset
    mlflow.log_input(
        mlflow.data.from_spark(
            spark.table("production.ml.customer_features"),
            table_name="production.ml.customer_features",
            version="v1.0"
        ),
        context="training"
    )

📦 Model Registry

Register Models

import mlflow
from mlflow.tracking import MlflowClient

client = MlflowClient()

# Option 1: Register during training
with mlflow.start_run():
    # ... training code ...
    mlflow.sklearn.log_model(
        model,
        "model",
        registered_model_name="customer-churn-predictor"
    )

# Option 2: Register existing run
run_id = "abc123xyz"
model_uri = f"runs:/{run_id}/model"

mlflow.register_model(
    model_uri=model_uri,
    name="customer-churn-predictor",
    tags={"team": "data-science", "algorithm": "random-forest"}
)

# Option 3: Register from artifact location
model_uri = "dbfs:/mnt/models/churn_model"
result = mlflow.register_model(model_uri, "customer-churn-predictor")

print(f"Model registered: {result.name}")
print(f"Version: {result.version}")
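The registration options above differ mainly in the model URI they pass. As a small sketch of the two URI schemes used throughout this page (`runs:/` and `models:/`), the helpers below build them from their parts; the function names are hypothetical and exist only for illustration.

```python
def run_model_uri(run_id, artifact_path="model"):
    """Build a URI pointing at a model logged under a run."""
    if not run_id:
        raise ValueError("run_id must be a non-empty string")
    return f"runs:/{run_id}/{artifact_path.strip('/')}"


def registry_model_uri(name, version_or_stage):
    """Build a URI pointing at a registered model version or stage."""
    if not name:
        raise ValueError("name must be a non-empty string")
    return f"models:/{name}/{version_or_stage}"


run_uri = run_model_uri("abc123xyz")
stage_uri = registry_model_uri("customer-churn-predictor", "Production")
version_uri = registry_model_uri("customer-churn-predictor", 3)
print(run_uri, stage_uri, version_uri, sep="\n")
```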

Manage Model Versions

from mlflow.tracking import MlflowClient

client = MlflowClient()
model_name = "customer-churn-predictor"

# List all versions
versions = client.search_model_versions(f"name='{model_name}'")
for v in versions:
    print(f"Version {v.version}: Stage={v.current_stage}, Run ID={v.run_id}")

# Get specific version
version = client.get_model_version(model_name, version=3)
print(f"Version 3 description: {version.description}")

# Update version description
client.update_model_version(
    name=model_name,
    version=3,
    description="Random Forest model with optimized hyperparameters. Accuracy: 0.92"
)

# Add tags to version
client.set_model_version_tag(
    name=model_name,
    version=3,
    key="validation_accuracy",
    value="0.92"
)

# Delete version (if not in production)
client.delete_model_version(
    name=model_name,
    version=1
)

Model Stages & Lifecycle

from mlflow.tracking import MlflowClient

client = MlflowClient()
model_name = "customer-churn-predictor"

# Note: stage-based transitions are deprecated as of MLflow 2.9 in favor of model version aliases
# Transition to Staging
client.transition_model_version_stage(
    name=model_name,
    version=3,
    stage="Staging",
    archive_existing_versions=False
)

# Transition to Production (after validation)
client.transition_model_version_stage(
    name=model_name,
    version=3,
    stage="Production",
    archive_existing_versions=True  # Archive previous production models
)

# Archive old version
client.transition_model_version_stage(
    name=model_name,
    version=2,
    stage="Archived"
)

# Get current production model
production_versions = client.get_latest_versions(model_name, stages=["Production"])
if production_versions:
    prod_version = production_versions[0]
    print(f"Production model: Version {prod_version.version}")
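To make the effect of `archive_existing_versions` concrete, here is a toy in-memory model of the stage bookkeeping. It is a sketch of the semantics only, not the registry's implementation; the `transition` function and the `registry` dictionary are hypothetical.

```python
VALID_STAGES = ("None", "Staging", "Production", "Archived")


def transition(stages, version, stage, archive_existing_versions=False):
    """Return a new {version: stage} mapping after moving `version` to `stage`."""
    if stage not in VALID_STAGES:
        raise ValueError(f"unknown stage: {stage}")
    if version not in stages:
        raise KeyError(f"unknown version: {version}")
    if archive_existing_versions and stage not in ("Staging", "Production"):
        raise ValueError("archiving existing versions only applies to Staging or Production")

    updated = dict(stages)
    if archive_existing_versions:
        # Every other version currently in the target stage is moved to Archived
        for other_version, other_stage in stages.items():
            if other_version != version and other_stage == stage:
                updated[other_version] = "Archived"
    updated[version] = stage
    return updated


registry = {1: "Archived", 2: "Production", 3: "Staging"}
after = transition(registry, 3, "Production", archive_existing_versions=True)
print(after)
```

With the flag set, version 2 leaves Production as version 3 enters it, so at most one version holds the stage. Without the flag, both versions would be in Production at once.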

🚀 Model Deployment

Batch Inference

import mlflow

# Load production model
model_name = "customer-churn-predictor"
model_uri = f"models:/{model_name}/Production"
model = mlflow.pyfunc.load_model(model_uri)

# Load data for scoring
input_data = spark.table("production.ml.customer_features_latest")

# Apply model as UDF
from pyspark.sql.functions import struct
from mlflow.pyfunc import spark_udf

predict_udf = spark_udf(
    spark,
    model_uri=model_uri,
    result_type="double"
)

# Score data
predictions = input_data.withColumn(
    "churn_probability",
    predict_udf(struct(*input_data.columns))
)

# Save predictions
predictions.write.format("delta").mode("overwrite").saveAsTable(
    "production.ml.churn_predictions"
)

display(predictions)

Real-Time Model Serving

# Create model serving endpoint
from databricks.sdk import WorkspaceClient
from databricks.sdk.service.serving import EndpointCoreConfigInput, ServedModelInput

w = WorkspaceClient()

# Create endpoint
endpoint_name = "churn-predictor-endpoint"

w.serving_endpoints.create(
    name=endpoint_name,
    config=EndpointCoreConfigInput(
        served_models=[
            ServedModelInput(
                model_name="customer-churn-predictor",
                model_version="3",
                workload_size="Small",
                scale_to_zero_enabled=True
            )
        ]
    )
)

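# Wait for endpoint to be ready
import time
from databricks.sdk.service.serving import EndpointStateReady

while True:
    endpoint = w.serving_endpoints.get(endpoint_name)
    # state.ready is an enum value, so compare against the enum member rather than a string
    if endpoint.state.ready == EndpointStateReady.READY:
        print(f"Endpoint {endpoint_name} is ready!")
        break
    time.sleep(30)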
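The polling loop above runs until the endpoint reports ready, which means it never returns if the deployment fails. One way to bound the wait is a small helper with a deadline, sketched below. `wait_until` is a hypothetical helper, and the `sleep` and `clock` parameters are injection points added here so the logic can be exercised without real delays.

```python
import time


def wait_until(check, timeout_seconds=1200.0, poll_seconds=30.0,
               sleep=time.sleep, clock=time.monotonic):
    """Poll `check()` until it returns True or the timeout elapses.

    Returns True if the condition was met, False on timeout.
    """
    deadline = clock() + timeout_seconds
    while True:
        if check():
            return True
        if clock() >= deadline:
            return False
        sleep(poll_seconds)


# Demonstration with a fake readiness check that succeeds on the third poll
calls = {"n": 0}


def ready_on_third_call():
    calls["n"] += 1
    return calls["n"] >= 3


became_ready = wait_until(
    ready_on_third_call, timeout_seconds=5, poll_seconds=0, sleep=lambda seconds: None
)
print(became_ready, calls["n"])
```

In the serving example, `check` would wrap the call that fetches the endpoint and inspects its ready state.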

Invoke endpoint:

import requests
import json

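# Build the endpoint URL
# workspace_url is a placeholder for your workspace hostname; endpoint_name is the name used when creating the endpoint
endpoint_url = f"https://{workspace_url}/serving-endpoints/{endpoint_name}/invocations"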

# Prepare input
input_data = {
    "dataframe_records": [
        {
            "age": 35,
            "tenure": 24,
            "monthly_charges": 79.99,
            "total_charges": 1919.76
        }
    ]
}

# Get token
token = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiToken().get()

# Make prediction
headers = {
    "Authorization": f"Bearer {token}",
    "Content-Type": "application/json"
}

response = requests.post(endpoint_url, headers=headers, json=input_data)
response.raise_for_status()  # Fail fast on HTTP errors instead of parsing an error body
predictions = response.json()

print(f"Prediction: {predictions}")
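The request body above uses the `dataframe_records` layout, where each record is one row keyed by feature name. The sketch below builds that payload from a list of rows and checks that every row has the same fields before anything is sent. Both helper names are hypothetical, and the response shape (a JSON object with a `predictions` list) is an assumption to verify against your endpoint.

```python
import json


def build_payload(records):
    """Wrap rows in the dataframe_records layout after a basic consistency check."""
    records = list(records)
    if not records:
        raise ValueError("at least one record is required")
    expected_fields = set(records[0])
    for index, record in enumerate(records):
        if set(record) != expected_fields:
            raise ValueError(f"record {index} has different fields than record 0")
    return {"dataframe_records": records}


def extract_predictions(response_body):
    """Pull the prediction list out of a parsed response body."""
    # Assumption: the endpoint answers with a JSON object holding a "predictions" list
    if "predictions" not in response_body:
        raise KeyError("response does not contain a 'predictions' field")
    return list(response_body["predictions"])


payload = build_payload([
    {"age": 35, "tenure": 24, "monthly_charges": 79.99, "total_charges": 1919.76}
])
body = json.dumps(payload)

# Stand-in for response.json(); the value is made up for the demonstration
fake_response = json.loads('{"predictions": [0.27]}')
scores = extract_predictions(fake_response)
print(body)
print(scores)
```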

Streaming Inference

import mlflow
from pyspark.sql.functions import struct

# Load model
model_uri = "models:/customer-churn-predictor/Production"
predict_udf = mlflow.pyfunc.spark_udf(spark, model_uri, result_type="double")

# Read streaming data
streaming_df = spark.readStream \
    .format("delta") \
    .table("production.streaming.customer_events")

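# Apply model
# Assumption: feature_columns lists the model's input columns; here every column in the stream is used
feature_columns = streaming_df.columns
predictions = streaming_df.withColumn(
    "churn_risk",
    predict_udf(struct(*feature_columns))
)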

# Write predictions to Delta (DataStreamWriter starts the query with toTable, not table)
query = predictions.writeStream \
    .format("delta") \
    .outputMode("append") \
    .option("checkpointLocation", "/mnt/checkpoints/churn_predictions") \
    .toTable("production.streaming.churn_predictions")

query.awaitTermination()

🤖 AutoML Integration

Databricks AutoML

import mlflow
from databricks import automl

# Prepare data
df = spark.table("production.ml.customer_features")

# Run AutoML
summary = automl.classify(
    dataset=df,
    target_col="churn",
    primary_metric="f1",
    timeout_minutes=30,
    max_trials=20
)

# Access best run
best_run = summary.best_trial

print(f"Best run ID: {best_run.mlflow_run_id}")
print(f"Best F1 score: {best_run.metrics['val_f1_score']}")

# Register best model
model_uri = f"runs:/{best_run.mlflow_run_id}/model"
mlflow.register_model(
    model_uri=model_uri,
    name="customer-churn-predictor-automl"
)

Custom AutoML Pipeline

import mlflow
from hyperopt import fmin, tpe, hp, STATUS_OK, Trials
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import f1_score

# Define search space
search_space = {
    'n_estimators': hp.choice('n_estimators', [50, 100, 200, 300]),
    'max_depth': hp.choice('max_depth', [5, 10, 15, 20, None]),
    'min_samples_split': hp.uniform('min_samples_split', 0.01, 0.1),
    'min_samples_leaf': hp.uniform('min_samples_leaf', 0.01, 0.1)
}

# Objective function
def objective(params):
    with mlflow.start_run(nested=True):
        # Log parameters
        mlflow.log_params(params)

        # Train model
        model = RandomForestClassifier(**params, random_state=42)
        model.fit(X_train, y_train)

        # Evaluate
        y_pred = model.predict(X_test)
        f1 = f1_score(y_test, y_pred)

        # Log metrics
        mlflow.log_metric("f1_score", f1)

        # Log model
        mlflow.sklearn.log_model(model, "model")

        return {'loss': -f1, 'status': STATUS_OK}

# Run hyperparameter tuning
with mlflow.start_run(run_name="hyperparameter-tuning"):
    trials = Trials()
    best_params = fmin(
        fn=objective,
        space=search_space,
        algo=tpe.suggest,
        max_evals=50,
        trials=trials
    )

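    # fmin returns option indices for hp.choice parameters; space_eval maps them back to values
    from hyperopt import space_eval
    print(f"Best parameters: {space_eval(search_space, best_params)}")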
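One detail of the tuning loop above is easy to miss: for parameters defined with `hp.choice`, `fmin` reports the index of the selected option rather than the option itself. The sketch below shows that mapping by hand in plain Python, which is the same translation Hyperopt's `space_eval` performs. The `decode_choices` helper and the sample result are hypothetical.

```python
# Option lists copied from the search space above
CHOICE_OPTIONS = {
    "n_estimators": [50, 100, 200, 300],
    "max_depth": [5, 10, 15, 20, None],
}


def decode_choices(best, choice_options):
    """Replace hp.choice indices in a result dictionary with the chosen values."""
    decoded = dict(best)
    for name, options in choice_options.items():
        if name in decoded:
            decoded[name] = options[int(decoded[name])]
    return decoded


# Illustrative raw result: indices for the choice parameters, floats for the uniform ones
raw_best = {
    "n_estimators": 2,
    "max_depth": 4,
    "min_samples_split": 0.03,
    "min_samples_leaf": 0.02,
}
decoded_best = decode_choices(raw_best, CHOICE_OPTIONS)
print(decoded_best)
```

Passing the raw dictionary straight to `RandomForestClassifier` would train with `n_estimators=2` and `max_depth=4`, which is why the decoding step matters.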

📊 Feature Store

Create Feature Tables

from databricks.feature_store import FeatureStoreClient

fs = FeatureStoreClient()

# Create feature table
customer_features = spark.table("production.ml.customer_features")

fs.create_table(
    name="production.ml.customer_feature_store",
    primary_keys=["customer_id"],
    df=customer_features,
    description="Customer behavioral features for churn prediction",
    tags={"team": "data-science", "domain": "customer-analytics"}
)

Update Features

from databricks.feature_store import FeatureStoreClient
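# Import Spark's max under an alias so it does not shadow (or get confused with) Python's built-in max
from pyspark.sql.functions import avg, count, current_date, datediff, max as spark_max

fs = FeatureStoreClient()

# Calculate new features
# Note: no date filter is applied here; restrict the input to the last 30 days if the _30d names should hold
new_features = spark.table("production.sales.transactions") \
    .groupBy("customer_id") \
    .agg(
        count("transaction_id").alias("transaction_count_30d"),
        avg("amount").alias("avg_transaction_value_30d"),
        datediff(current_date(), spark_max("transaction_date")).alias("days_since_last_purchase")
    )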

# Write features
fs.write_table(
    name="production.ml.customer_feature_store",
    df=new_features,
    mode="merge"
)
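As a rough mental model of `mode="merge"`, the sketch below performs the same kind of upsert on plain dictionaries: rows whose primary key already exists are updated column by column, and rows with new keys are inserted. This is an illustration under that assumption, not the Feature Store implementation; `merge_features` and the sample rows are hypothetical.

```python
def merge_features(table, updates, primary_key="customer_id"):
    """Upsert `updates` into `table`, a mapping from primary key to row dictionary."""
    # Copy the rows so the input table is left untouched
    merged = {key: dict(row) for key, row in table.items()}
    for row in updates:
        key = row[primary_key]
        # Existing rows keep columns that the update does not mention
        merged.setdefault(key, {}).update(row)
    return merged


existing = {
    1: {"customer_id": 1, "tenure": 24, "transaction_count_30d": 3},
    2: {"customer_id": 2, "tenure": 5, "transaction_count_30d": 1},
}
new_rows = [
    {"customer_id": 2, "transaction_count_30d": 4},  # updates an existing customer
    {"customer_id": 3, "transaction_count_30d": 7},  # inserts a new customer
]
result = merge_features(existing, new_rows)
print(result)
```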

Train with Feature Store

from databricks.feature_store import FeatureStoreClient, FeatureLookup

fs = FeatureStoreClient()

# Define feature lookups
feature_lookups = [
    FeatureLookup(
        table_name="production.ml.customer_feature_store",
        lookup_key="customer_id"
    ),
    FeatureLookup(
        table_name="production.ml.product_features",
        lookup_key="product_id"
    )
]

# Create training dataset
training_set = fs.create_training_set(
    df=spark.table("production.ml.labels"),
    feature_lookups=feature_lookups,
    label="churn",
    exclude_columns=["customer_id"]
)

# Load as pandas
training_df = training_set.load_df().toPandas()

# Train model
import mlflow.sklearn
from sklearn.ensemble import RandomForestClassifier

with mlflow.start_run():
    model = RandomForestClassifier()
    model.fit(training_df.drop("churn", axis=1), training_df["churn"])

    # Log model with feature store metadata
    fs.log_model(
        model=model,
        artifact_path="model",
        flavor=mlflow.sklearn,
        training_set=training_set,
        registered_model_name="churn-predictor-with-features"
    )

Batch Scoring with Features

from databricks.feature_store import FeatureStoreClient

fs = FeatureStoreClient()

# Load model
model_uri = "models:/churn-predictor-with-features/Production"

# Score using feature store
predictions = fs.score_batch(
    model_uri=model_uri,
    df=spark.table("production.ml.customers_to_score"),
    result_type="double"
)

display(predictions)

📈 Model Monitoring

Track Model Performance

import mlflow
from pyspark.sql.functions import col

# Log production metrics
with mlflow.start_run(run_name="production-monitoring"):
    # Calculate metrics on production data
    production_data = spark.table("production.ml.churn_predictions_last_30d")
    total_rows = production_data.count()  # Count once and reuse

    actual_churn = production_data.filter(col("actual_churn") == 1).count()
    predicted_churn = production_data.filter(col("churn_probability") > 0.5).count()

    mlflow.log_metric("actual_churn_rate", actual_churn / total_rows)
    mlflow.log_metric("predicted_churn_rate", predicted_churn / total_rows)

    # Log drift metrics
    from scipy.stats import ks_2samp

    training_data = spark.table("production.ml.training_data_snapshot")

    # numerical_columns: list of numeric feature column names, defined elsewhere
    for column in numerical_columns:
        stat, p_value = ks_2samp(
            training_data.select(column).toPandas()[column],
            production_data.select(column).toPandas()[column]
        )
        mlflow.log_metric(f"drift_{column}_ks_statistic", stat)
        mlflow.log_metric(f"drift_{column}_p_value", p_value)
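The drift statistic logged above is the two-sample Kolmogorov-Smirnov statistic: the largest gap between the empirical cumulative distribution functions of the two samples. A minimal pure-Python sketch of that statistic follows, for intuition only. It does not compute the p-value, and `ks_statistic` is a hypothetical helper rather than a replacement for `scipy.stats.ks_2samp`.

```python
from bisect import bisect_right


def ks_statistic(sample_a, sample_b):
    """Largest absolute difference between the two empirical CDFs."""
    if not sample_a or not sample_b:
        raise ValueError("both samples must be non-empty")
    a = sorted(sample_a)
    b = sorted(sample_b)

    max_gap = 0.0
    # The gap between two step functions can only change at observed values
    for x in a + b:
        cdf_a = bisect_right(a, x) / len(a)  # share of sample_a that is <= x
        cdf_b = bisect_right(b, x) / len(b)  # share of sample_b that is <= x
        max_gap = max(max_gap, abs(cdf_a - cdf_b))
    return max_gap


same = ks_statistic([1, 2, 3], [1, 2, 3])
overlapping = ks_statistic([1, 2, 3, 4], [3, 4, 5, 6])
disjoint = ks_statistic([1, 2, 3], [10, 11, 12])
print(same, overlapping, disjoint)
```

Values near 0 indicate similar distributions and values near 1 indicate almost no overlap, which is why the statistic is a convenient per-column drift signal.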

Data Drift Detection

from pyspark.sql.functions import col, mean, stddev

# Calculate statistics for training data
training_stats = spark.table("production.ml.training_data") \
    .select([
        mean(col(c)).alias(f"{c}_mean") for c in numerical_columns
    ] + [
        stddev(col(c)).alias(f"{c}_stddev") for c in numerical_columns
    ]).collect()[0]

# Calculate statistics for production data
production_stats = spark.table("production.ml.production_data_latest") \
    .select([
        mean(col(c)).alias(f"{c}_mean") for c in numerical_columns
    ] + [
        stddev(col(c)).alias(f"{c}_stddev") for c in numerical_columns
    ]).collect()[0]

# Detect drift
drift_detected = False
for column in numerical_columns:
    train_mean = training_stats[f"{column}_mean"]
    prod_mean = production_stats[f"{column}_mean"]

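    # Simple threshold-based drift detection: relative shift of the mean, in percent
    if not train_mean:
        print(f"{column}: training mean is zero or missing; skipping relative drift check")
        continue
    drift_pct = abs(prod_mean - train_mean) / abs(train_mean) * 100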

    print(f"{column}: {drift_pct:.2f}% drift")

    if drift_pct > 10:  # 10% threshold
        drift_detected = True
        print(f"⚠️ Significant drift detected in {column}")

if drift_detected:
    # Trigger retraining
    dbutils.notebook.run("/path/to/retraining_notebook", timeout_seconds=3600)
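The statistics collected above include standard deviations, but the percentage check uses only the means. One alternative is to express the shift in units of the training standard deviation, which stays meaningful when a mean is close to zero. The sketch below assumes the same `<column>_mean` and `<column>_stddev` keys as above; the helper names, sample values, and the 0.5 threshold are hypothetical.

```python
def standardized_shift(train_mean, train_stddev, prod_mean):
    """Shift of the production mean, measured in training standard deviations."""
    if not train_stddev:
        return None  # Undefined for constant or missing columns
    return abs(prod_mean - train_mean) / train_stddev


def drifted_columns(train_stats, prod_stats, columns, threshold=0.5):
    """Return the columns whose standardized mean shift exceeds the threshold."""
    flagged = []
    for column in columns:
        shift = standardized_shift(
            train_stats[f"{column}_mean"],
            train_stats[f"{column}_stddev"],
            prod_stats[f"{column}_mean"],
        )
        if shift is not None and shift > threshold:
            flagged.append(column)
    return flagged


# Made-up statistics in the same key layout as the collected rows above
train = {"tenure_mean": 24.0, "tenure_stddev": 12.0,
         "monthly_charges_mean": 80.0, "monthly_charges_stddev": 20.0}
prod = {"tenure_mean": 33.0, "monthly_charges_mean": 82.0}

flagged = drifted_columns(train, prod, ["tenure", "monthly_charges"])
print(flagged)
```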

🔧 Best Practices

Experiment Organization

import mlflow

# Use hierarchical experiment naming
mlflow.set_experiment("/teams/data-science/customer-churn/q1-2025")

# Tag runs for easy filtering
with mlflow.start_run():
    mlflow.set_tags({
        "team": "data-science",
        "project": "customer-churn",
        "model_type": "classification",
        "algorithm": "random-forest",
        "environment": "production"
    })

    # ... training code ...

# Search runs by tag
# (the tags above are run tags; client.search_experiments filters on experiment tags, which are set separately)
runs = mlflow.search_runs(
    experiment_names=["/teams/data-science/customer-churn/q1-2025"],
    filter_string="tags.project = 'customer-churn'"
)

Model Versioning Strategy

Version Naming Convention (record as a tag; registry version numbers are auto-incremented integers):
  - Major version: Significant algorithm change
  - Minor version: Hyperparameter tuning
  - Patch version: Bug fixes

Model Stages:
  - None: Initial registration
  - Staging: Under validation
  - Production: Live in production
  - Archived: Deprecated or superseded

Tags to Include:
  - algorithm: random-forest, xgboost, neural-network
  - accuracy: 0.92
  - training_date: 2025-01-28
  - data_version: v1.5
  - team: data-science
  - jira_ticket: ML-1234
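The conventions above can be captured in a small helper that validates the semantic version and converts every value to a string before tagging, since registry tags are stored as strings. This is a sketch only; `build_version_tags` and the `semantic_version` tag key are hypothetical names.

```python
import re
from datetime import date

SEMVER_PATTERN = re.compile(r"^\d+\.\d+\.\d+$")


def build_version_tags(semantic_version, algorithm, accuracy, training_date, data_version):
    """Assemble string-valued tags for a model version."""
    if not SEMVER_PATTERN.match(semantic_version):
        raise ValueError(f"expected MAJOR.MINOR.PATCH, got {semantic_version!r}")
    return {
        "semantic_version": semantic_version,
        "algorithm": algorithm,
        "accuracy": f"{accuracy:.2f}",
        "training_date": training_date.isoformat(),
        "data_version": data_version,
    }


tags = build_version_tags("2.1.0", "random-forest", 0.92, date(2025, 1, 28), "v1.5")
print(tags)
```

Each key and value can then be applied with `client.set_model_version_tag`, as shown in the version management example.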

MLOps Workflow

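import mlflow
from mlflow.tracking import MlflowClient

client = MlflowClient()

# 1. Development
# experiment_id takes the experiment's ID, not its name; dev_experiment_id is a placeholder
with mlflow.start_run(experiment_id=dev_experiment_id) as run:
    # Train and log model
    mlflow.sklearn.log_model(model, "model")
    run_id = run.info.run_id  # Capture here; mlflow.active_run() is None once the block exits

# 2. Register to staging
result = mlflow.register_model(f"runs:/{run_id}/model", "my-model")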

client.transition_model_version_stage(
    name="my-model",
    version=result.version,
    stage="Staging"
)

# 3. Validate in staging
# Run validation tests...

# 4. Promote to production
if validation_passed:
    client.transition_model_version_stage(
        name="my-model",
        version=result.version,
        stage="Production",
        archive_existing_versions=True
    )

# 5. Monitor production performance
# Set up monitoring dashboards and alerts
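The workflow above leaves the validation step open. One possible promotion gate is sketched below: the candidate must clear absolute minimums and must not regress against the current production metrics by more than a tolerance. The function name, the thresholds, and the metric values are all hypothetical.

```python
def passes_validation(candidate, production, minimums, max_regression=0.01):
    """Decide whether a candidate model may be promoted.

    candidate / production: metric name -> value (higher is better)
    minimums: metric name -> lowest acceptable candidate value
    max_regression: allowed drop relative to the production value
    """
    # Absolute quality floor; a missing metric counts as a failure
    for metric, floor in minimums.items():
        if candidate.get(metric, float("-inf")) < floor:
            return False
    # No meaningful regression on metrics that both models report
    for metric, production_value in production.items():
        if metric in candidate and candidate[metric] < production_value - max_regression:
            return False
    return True


validation_passed = passes_validation(
    candidate={"accuracy": 0.92, "recall": 0.81},
    production={"accuracy": 0.90, "recall": 0.815},
    minimums={"accuracy": 0.85},
)
print(validation_passed)
```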

🆘 Troubleshooting

Common Issues

Issue: Model artifacts not loading

# Solution: Check artifact location
import mlflow

run_id = "abc123xyz"
run = mlflow.get_run(run_id)
print(f"Artifact URI: {run.info.artifact_uri}")

# Verify artifacts exist (pass run_id by keyword; the first positional argument is artifact_uri)
artifacts = mlflow.artifacts.list_artifacts(run_id=run_id)
for artifact in artifacts:
    print(f"  - {artifact.path}")

Issue: Model serving endpoint fails

# Solution: Check endpoint logs
from databricks.sdk import WorkspaceClient

w = WorkspaceClient()
endpoint = w.serving_endpoints.get("endpoint-name")

print(f"State: {endpoint.state.ready}")
print(f"Config: {endpoint.state.config_update}")

# Get logs (the response carries the log text as a single string)
logs = w.serving_endpoints.logs(name="endpoint-name", served_model_name="model-name")
print(logs.logs)

Issue: Feature store lookup fails

# Solution: Verify feature table exists and has correct schema
from databricks.feature_store import FeatureStoreClient

fs = FeatureStoreClient()

# Check feature table
feature_table = fs.get_table("catalog.schema.feature_table")
print(f"Primary keys: {feature_table.primary_keys}")
print(f"Features: {feature_table.features}")

# Verify lookup key exists in both tables
lookup_df = spark.table("catalog.schema.lookup_table")
feature_df = spark.table("catalog.schema.feature_table")

print(f"Lookup keys in lookup_df: {lookup_df.select('lookup_key').distinct().count()}")
print(f"Lookup keys in feature_df: {feature_df.select('lookup_key').distinct().count()}")


🎯 Next Steps

  1. Set Up Feature Store - Centralize features
  2. Implement Model Monitoring - Track performance
  3. Deploy Production Pipeline - Automate MLOps

Last Updated: 2025-01-28 | MLflow Version: 2.9+ | Documentation Status: Complete