
🎯 Spark Notebook Interactive Sandbox

Status: Active | Type: Interactive | Difficulty: Intermediate

📋 Overview

A fully interactive Spark notebook environment for experimenting with PySpark, Scala, and Spark SQL. The sandbox is isolated from production resources, so you can learn Spark concepts, test code, and explore data processing patterns safely.

Duration: Self-paced | Format: Live coding environment with pre-loaded datasets | Prerequisites: Basic Python or Scala knowledge

🎯 Learning Objectives

By using this interactive sandbox, you will:

  • Write and execute PySpark code in a real Spark environment
  • Understand DataFrame operations and transformations
  • Practice data manipulation and aggregation techniques
  • Learn Spark optimization strategies
  • Experiment with different data formats (Parquet, Delta, CSV, JSON)
  • Explore Spark SQL and DataFrame API equivalence

🚀 Prerequisites and Setup

Access Requirements

  • Browser-Based: Modern web browser with JavaScript enabled
  • No Installation: Runs entirely in your browser
  • Sample Data: Pre-loaded datasets available
  • Save Progress: Optional account creation for saving work

Environment Specifications

spark_config:
  version: "3.4.0"
  scala_version: "2.12"
  python_version: "3.10"

resources:
  driver_memory: "4g"
  executor_memory: "4g"
  executor_cores: 2
  num_executors: 2

packages:
  - delta-core_2.12:2.4.0
  - azure-storage:8.6.6
  - mssql-jdbc:11.2.0
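
For readers who want to reproduce similar resource settings outside the sandbox, the sketch below maps the `resources` block above onto standard Spark configuration properties and renders them as `spark-submit` arguments. The helper names (`to_spark_conf`, `to_submit_args`) are hypothetical, and whether the sandbox itself accepts overrides is not stated here, so treat this as an illustration rather than the sandbox's actual launch procedure.

```python
# Values copied from the resources block above.
resources = {
    "driver_memory": "4g",
    "executor_memory": "4g",
    "executor_cores": 2,
    "num_executors": 2,
}

# Standard Spark property name for each resource setting.
PROPERTY_NAMES = {
    "driver_memory": "spark.driver.memory",
    "executor_memory": "spark.executor.memory",
    "executor_cores": "spark.executor.cores",
    "num_executors": "spark.executor.instances",
}


def to_spark_conf(spec):
    """Hypothetical helper: return Spark properties with string values."""
    return {PROPERTY_NAMES[key]: str(value) for key, value in spec.items()}


def to_submit_args(conf):
    """Hypothetical helper: render properties as `--conf key=value` pairs."""
    args = []
    for key in sorted(conf):
        args += ["--conf", f"{key}={conf[key]}"]
    return args


conf = to_spark_conf(resources)
total_executor_cores = resources["executor_cores"] * resources["num_executors"]

print(to_submit_args(conf))
print(f"Total executor cores: {total_executor_cores}")
```

With two executors of two cores each, at most four tasks run in parallel, which is worth keeping in mind when choosing partition counts later in the tutorials.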

Quick Start

# Your first Spark code
# This runs immediately in the sandbox

import sys

from pyspark.sql import SparkSession
# Note: the wildcard import shadows Python built-ins such as sum, min, and max
from pyspark.sql.functions import *

# Spark session is already initialized as 'spark'
print(f"Spark Version: {spark.version}")
print(f"Python Version: {sys.version}")

# Load sample data
df = spark.read.parquet("/samples/sales_data.parquet")
df.show(5)

🎮 Interactive Features

Code Editor

// Monaco Editor Configuration
const editorConfig = {
  language: 'python',
  theme: 'vs-dark',
  automaticLayout: true,
  minimap: { enabled: true },
  fontSize: 14,
  wordWrap: 'on',
  scrollBeyondLastLine: false,

  // IntelliSense configuration
  suggest: {
    snippetsPreventQuickSuggestions: false
  },

  // Suggest while typing in code and strings, but not in comments
  quickSuggestions: {
    other: true,
    comments: false,
    strings: true
  }
};

// Code completion for PySpark
const pysparkCompletions = [
  {
    label: 'read.parquet',
    insertText: 'spark.read.parquet("${1:path}")',
    documentation: 'Read Parquet file into DataFrame'
  },
  {
    label: 'groupBy.agg',
    insertText: 'groupBy("${1:column}").agg(${2:aggregation})',
    documentation: 'Group by column and aggregate'
  }
];

Live Data Visualization

# Built-in visualization
from pyspark.sql.functions import col, sum, count

# Sample sales analysis
sales_by_category = df.groupBy("category") \
    .agg(
        sum("revenue").alias("total_revenue"),
        count("*").alias("transaction_count")
    ) \
    .orderBy(col("total_revenue").desc())

# Display with built-in charting
display(sales_by_category)
# Automatically generates: bar chart, line chart, pie chart, table

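# Custom visualization with matplotlib
import matplotlib.pyplot as plt

# Cast the decimal total to double so pandas and matplotlib receive plain floats
data = sales_by_category \
    .withColumn("total_revenue", col("total_revenue").cast("double")) \
    .toPandas()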
plt.figure(figsize=(12, 6))
plt.bar(data['category'], data['total_revenue'])
plt.xlabel('Category')
plt.ylabel('Revenue')
plt.title('Revenue by Category')
plt.xticks(rotation=45)
plt.tight_layout()
plt.show()

📚 Pre-Loaded Sample Datasets

Available Datasets

# Dataset catalog
datasets = {
    "sales": {
        "path": "/samples/sales_data.parquet",
        "format": "parquet",
        "rows": 1000000,
        "schema": """
            transaction_id: string
            timestamp: timestamp
            customer_id: string
            product_id: string
            category: string
            quantity: integer
            price: decimal(10,2)
            revenue: decimal(10,2)
            region: string
        """
    },

    "customers": {
        "path": "/samples/customers.parquet",
        "format": "parquet",
        "rows": 50000,
        "schema": """
            customer_id: string
            name: string
            email: string
            signup_date: date
            region: string
            segment: string
        """
    },

    "products": {
        "path": "/samples/products.parquet",
        "format": "parquet",
        "rows": 10000,
        "schema": """
            product_id: string
            name: string
            category: string
            subcategory: string
            price: decimal(10,2)
            cost: decimal(10,2)
        """
    },

    "clickstream": {
        "path": "/samples/clickstream.json",
        "format": "json",
        "rows": 5000000,
        "schema": """
            session_id: string
            user_id: string
            timestamp: timestamp
            page_url: string
            action: string
            device_type: string
        """
    }
}

# Helper function to load datasets
def load_dataset(name):
    """Load a pre-configured sample dataset"""
    if name not in datasets:
        raise ValueError(f"Dataset '{name}' not found")

    config = datasets[name]
    df = spark.read.format(config["format"]).load(config["path"])
    print(f"Loaded {name}: {df.count()} rows")
    return df

# Usage
sales_df = load_dataset("sales")
customers_df = load_dataset("customers")

🎓 Interactive Tutorials

Tutorial 1: DataFrame Basics

# ============================================
# Tutorial 1: DataFrame Basics
# ============================================

from pyspark.sql.functions import col

# Step 1: Load data
df = spark.read.parquet("/samples/sales_data.parquet")

# Step 2: Explore schema
print("Schema:")
df.printSchema()

# Step 3: Basic operations
print("\nFirst 10 rows:")
df.show(10)

print("\nColumn names:")
print(df.columns)

print("\nDataFrame shape:")
print(f"Rows: {df.count()}, Columns: {len(df.columns)}")

# Step 4: Select columns
selected_df = df.select("transaction_id", "customer_id", "revenue")
selected_df.show(5)

# Step 5: Filter data
high_value_df = df.filter(col("revenue") > 1000)
print(f"\nHigh value transactions: {high_value_df.count()}")

# Step 6: Sort data
top_transactions = df.orderBy(col("revenue").desc()).limit(10)
top_transactions.show()

# ✅ Exercise: Find all transactions in the 'Electronics' category
# where quantity > 2. Sort by revenue descending.
# Write your code below:

Tutorial 2: Aggregations and GroupBy

# ============================================
# Tutorial 2: Aggregations and GroupBy
# ============================================

from pyspark.sql.functions import *

# Load data
df = spark.read.parquet("/samples/sales_data.parquet")

# Basic aggregation
print("Total Revenue:", df.agg(sum("revenue")).collect()[0][0])

# Group by single column
category_sales = df.groupBy("category") \
    .agg(
        sum("revenue").alias("total_revenue"),
        avg("revenue").alias("avg_revenue"),
        count("*").alias("transaction_count"),
        min("revenue").alias("min_revenue"),
        max("revenue").alias("max_revenue")
    ) \
    .orderBy(col("total_revenue").desc())

category_sales.show()

# Group by multiple columns
region_category_sales = df.groupBy("region", "category") \
    .agg(
        sum("revenue").alias("total_revenue"),
        countDistinct("customer_id").alias("unique_customers")
    ) \
    .orderBy("region", col("total_revenue").desc())

region_category_sales.show()

# Advanced: Window functions
from pyspark.sql.window import Window

# Calculate running total by region
window_spec = Window.partitionBy("region").orderBy("timestamp")

df_with_running_total = df.withColumn(
    "running_total",
    sum("revenue").over(window_spec)
)

df_with_running_total.select(
    "timestamp", "region", "revenue", "running_total"
).show(20)

# Ranking
window_rank = Window.partitionBy("category").orderBy(col("revenue").desc())

top_per_category = df.withColumn(
    "rank",
    row_number().over(window_rank)
).filter(col("rank") <= 5)

top_per_category.show()

# ✅ Exercise: Calculate the average revenue per customer
# for each region. Show only regions with avg > $500.
# Write your code below:

Tutorial 3: Joins and Complex Transformations

# ============================================
# Tutorial 3: Joins and Complex Transformations
# ============================================

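from pyspark.sql.functions import col

# Load datasets (load_dataset is defined in the Pre-Loaded Sample Datasets section)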
sales_df = load_dataset("sales")
customers_df = load_dataset("customers")
products_df = load_dataset("products")

# Inner Join
enriched_sales = sales_df \
    .join(customers_df, "customer_id", "inner") \
    .join(products_df, "product_id", "inner") \
    .select(
        "transaction_id",
        "customer_id",  # needed by the lifetime-value grouping further down
        "timestamp",
        customers_df["name"].alias("customer_name"),
        customers_df["segment"],
        products_df["name"].alias("product_name"),
        "quantity",
        "revenue"
    )

enriched_sales.show(10)

# Left Join (to include all sales even without customer data)
all_sales = sales_df \
    .join(customers_df, "customer_id", "left") \
    .select(
        "transaction_id",
        "customer_id",
        customers_df["name"].alias("customer_name"),
        "revenue"
    )

# Count missing customer data
missing_customers = all_sales.filter(col("customer_name").isNull()).count()
print(f"Transactions with missing customer data: {missing_customers}")

# Complex transformation: Customer Lifetime Value
from pyspark.sql.functions import *

customer_ltv = enriched_sales \
    .groupBy("customer_id", "customer_name", "segment") \
    .agg(
        sum("revenue").alias("lifetime_value"),
        count("transaction_id").alias("purchase_count"),
        avg("revenue").alias("avg_order_value"),
        min("timestamp").alias("first_purchase"),
        max("timestamp").alias("last_purchase")
    ) \
    .withColumn(
        "customer_tenure_days",
        datediff(col("last_purchase"), col("first_purchase"))
    ) \
    .withColumn(
        "purchase_frequency",
        col("purchase_count") / (col("customer_tenure_days") + 1)
    ) \
    .orderBy(col("lifetime_value").desc())

customer_ltv.show(20)

# Self-join: Find customers who bought the same product more than once.
# Requiring s1.timestamp < s2.timestamp keeps each pair once, in purchase order.
repeat_purchases = sales_df.alias("s1") \
    .join(
        sales_df.alias("s2"),
        (col("s1.customer_id") == col("s2.customer_id")) &
        (col("s1.product_id") == col("s2.product_id")) &
        (col("s1.timestamp") < col("s2.timestamp"))
    ) \
    .select(
        col("s1.customer_id"),
        col("s1.product_id"),
        col("s1.timestamp").alias("first_purchase"),
        col("s2.timestamp").alias("second_purchase")
    )

print(f"Repeat purchases: {repeat_purchases.count()}")

# ✅ Exercise: Find the top 10 customers by lifetime value
# in the 'Premium' segment. Include their average order value
# and purchase frequency.
# Write your code below:

Tutorial 4: Performance Optimization

# ============================================
# Tutorial 4: Performance Optimization
# ============================================

from pyspark.sql.functions import *
import time

# Load large dataset
df = spark.read.parquet("/samples/sales_data.parquet")

# ===== Optimization 1: Caching =====
print("Without caching:")
start = time.time()
df.count()
print(f"First count: {time.time() - start:.2f} seconds")

start = time.time()
df.count()
print(f"Second count: {time.time() - start:.2f} seconds")

# Enable caching
df.cache()
df.count()  # Trigger caching

print("\nWith caching:")
start = time.time()
df.count()
print(f"Cached count: {time.time() - start:.2f} seconds")

# ===== Optimization 2: Partitioning =====
# Check current partitions
print(f"\nCurrent partitions: {df.rdd.getNumPartitions()}")

# Repartition for better parallelism
df_repartitioned = df.repartition(8, "region")
print(f"After repartitioning: {df_repartitioned.rdd.getNumPartitions()}")

# ===== Optimization 3: Broadcast Joins =====
from pyspark.sql.functions import broadcast

# Small table
small_df = spark.createDataFrame([
    ("Electronics", 0.1),
    ("Clothing", 0.15),
    ("Food", 0.05)
], ["category", "tax_rate"])

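# Regular join (Spark picks the strategy; it may broadcast automatically when
# the table is smaller than spark.sql.autoBroadcastJoinThreshold)
regular_join = df.join(small_df, "category")

# Broadcast join (the hint requests a broadcast explicitly; suited to small tables)
broadcast_join = df.join(broadcast(small_df), "category")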

# ===== Optimization 4: Predicate Pushdown =====
# Filter early to reduce data movement
filtered_df = df \
    .filter(col("timestamp") >= "2024-01-01") \
    .filter(col("region") == "North") \
    .select("transaction_id", "revenue")

# ===== Optimization 5: Column Pruning =====
# Select only needed columns early
pruned_df = df.select("customer_id", "revenue", "timestamp")

# Then perform operations
customer_summary = pruned_df \
    .groupBy("customer_id") \
    .agg(sum("revenue").alias("total_revenue"))

# ===== Monitoring and Explain Plan =====
# View execution plan
df.filter(col("revenue") > 1000).explain(True)

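# ✅ Exercise: Optimize this query
# Original (slow). customers_df comes from Tutorial 3; both tables have a
# 'region' column, so the filter names the sales-side column explicitly.
slow_query = df \
    .join(customers_df, "customer_id") \
    .filter(col("revenue") > 500) \
    .filter(df["region"] == "East") \
    .select("transaction_id", "revenue")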

# Optimize and compare execution plans:
# Write your optimized version below:

Tutorial 5: Working with Different Data Formats

# ============================================
# Tutorial 5: Data Formats and Delta Lake
# ============================================

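from delta.tables import *
from pyspark.sql.functions import col, lit

# 'df' below is the sales DataFrame loaded in the earlier tutorials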

# ===== Reading Different Formats =====

# CSV
csv_df = spark.read \
    .option("header", "true") \
    .option("inferSchema", "true") \
    .csv("/samples/data.csv")

# JSON
json_df = spark.read.json("/samples/data.json")

# Parquet
parquet_df = spark.read.parquet("/samples/data.parquet")

# Delta
delta_df = spark.read.format("delta").load("/samples/delta_table")

# ===== Writing Data =====

# Write as Parquet with compression
df.write \
    .mode("overwrite") \
    .option("compression", "snappy") \
    .parquet("/output/sales_parquet")

# Write as Delta with partitioning
df.write \
    .format("delta") \
    .mode("overwrite") \
    .partitionBy("region", "category") \
    .save("/output/sales_delta")

# ===== Delta Lake Operations =====

# Create Delta table
df.write \
    .format("delta") \
    .mode("overwrite") \
    .save("/delta/sales")

# Read Delta table
delta_sales = spark.read.format("delta").load("/delta/sales")

# Update operation
deltaTable = DeltaTable.forPath(spark, "/delta/sales")

deltaTable.update(
    condition = col("revenue") < 0,
    set = { "revenue": lit(0) }
)

# Delete operation
deltaTable.delete(condition = col("quantity") == 0)

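# Merge (UPSERT) operation
# The source has only two columns, so the insert clause lists them explicitly;
# whenNotMatchedInsertAll() expects the source to provide every target column.
# Target columns that are not listed are set to null.
new_data = spark.createDataFrame([
    ("TXN001", 150.00),
    ("TXN999", 200.00)
], ["transaction_id", "revenue"])

deltaTable.alias("target") \
    .merge(
        new_data.alias("source"),
        "target.transaction_id = source.transaction_id"
    ) \
    .whenMatchedUpdate(set = { "revenue": col("source.revenue") }) \
    .whenNotMatchedInsert(values = {
        "transaction_id": col("source.transaction_id"),
        "revenue": col("source.revenue")
    }) \
    .execute()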

# Time travel
# Read historical version
historical_df = spark.read \
    .format("delta") \
    .option("versionAsOf", 0) \
    .load("/delta/sales")

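# Read as of timestamp (the value must fall within the table's commit history;
# a timestamp earlier than the first commit raises an error)
timestamp_df = spark.read \
    .format("delta") \
    .option("timestampAsOf", "2024-01-01") \
    .load("/delta/sales")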

# View history
deltaTable.history().show()

# ✅ Exercise: Create a Delta table from sales data,
# perform an update to fix negative revenues,
# and query the table before and after the update using time travel.
# Write your code below:

💡 Code Snippets Library

Quick Actions

# ===== Common Operations =====

# 1. Schema operations
df.printSchema()
df.dtypes
df.columns

# 2. Data inspection
df.show(20, truncate=False)
df.describe().show()
df.summary().show()

# 3. Column operations
df.withColumn("new_col", lit("value"))
df.withColumnRenamed("old_name", "new_name")
df.drop("column_name")

# 4. Null handling
df.na.drop()
df.na.fill({"column": "default_value"})
df.filter(col("column").isNotNull())

# 5. String operations
df.filter(col("name").like("%pattern%"))
df.withColumn("upper_name", upper(col("name")))
df.withColumn("length", length(col("name")))

# 6. Date operations
df.withColumn("year", year(col("date_column")))
df.withColumn("month", month(col("date_column")))
df.withColumn("date_diff", datediff(col("end_date"), col("start_date")))

# 7. Type conversions
df.withColumn("price_int", col("price").cast("integer"))
df.withColumn("timestamp", to_timestamp(col("date_string"), "yyyy-MM-dd"))

# 8. UDFs (User Defined Functions)
from pyspark.sql.functions import col, udf
from pyspark.sql.types import StringType

def categorize_revenue(revenue):
    # Null column values arrive as None; return None instead of raising TypeError
    if revenue is None:
        return None
    if revenue < 100:
        return "Low"
    elif revenue < 500:
        return "Medium"
    else:
        return "High"

categorize_udf = udf(categorize_revenue, StringType())
df.withColumn("revenue_category", categorize_udf(col("revenue")))

🔧 Advanced Features

Spark SQL Integration

# Register DataFrame as temp view
df.createOrReplaceTempView("sales")

# Use SQL to query
result = spark.sql("""
    SELECT
        region,
        category,
        SUM(revenue) as total_revenue,
        COUNT(*) as transaction_count
    FROM sales
    WHERE timestamp >= '2024-01-01'
    GROUP BY region, category
    HAVING SUM(revenue) > 10000
    ORDER BY total_revenue DESC
""")

result.show()

# Mix SQL and DataFrame API
spark.sql("SELECT * FROM sales WHERE region = 'North'") \
    .groupBy("category") \
    .agg(avg("revenue")) \
    .show()

Streaming Simulation

# Simulate streaming data
from pyspark.sql.functions import col
from pyspark.sql.types import (
    StructType, StructField, TimestampType, StringType, DoubleType
)

schema = StructType([
    StructField("timestamp", TimestampType(), True),
    StructField("sensor_id", StringType(), True),
    StructField("temperature", DoubleType(), True),
    StructField("humidity", DoubleType(), True)
])

# Create streaming DataFrame (simulated)
streaming_df = spark.readStream \
    .schema(schema) \
    .json("/streaming/input")

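# Perform transformations (temperature is assumed to be in Fahrenheit)
processed_df = streaming_df \
    .withColumn("temp_celsius", (col("temperature") - 32) * 5/9) \
    .filter(col("temperature") > 100)

# Write stream (simulated)
query = processed_df.writeStream \
    .format("memory") \
    .queryName("sensor_data") \
    .start()

# Wait until the files currently in the input directory have been processed,
# then query the in-memory table
query.processAllAvailable()
spark.sql("SELECT * FROM sensor_data").show()

# Stop the stream when finished
query.stop()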

🔧 Troubleshooting

Common Issues

Issue: Out of Memory Errors

Symptoms:

org.apache.spark.SparkException: Job aborted due to stage failure:
java.lang.OutOfMemoryError: Java heap space

Solution:

# 1. Reduce data size
df_sample = df.sample(0.1)  # Use 10% sample

# 2. Increase partitions
df_repartitioned = df.repartition(100)

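# 3. Use iterative processing: pull one partition at a time to the driver
# (process_partition is a placeholder for your own function)
for partition in df.rdd.mapPartitions(lambda rows: [list(rows)]).toLocalIterator():
    process_partition(partition)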

# 4. Clear cache when done
df.unpersist()

Issue: Slow Joins

Symptoms:

  • Join operations taking very long
  • Skewed data distribution

Solution:

# 1. Broadcast small tables
from pyspark.sql.functions import broadcast
result = large_df.join(broadcast(small_df), "key")

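# 2. Salt keys for skewed joins
# The other side of the join must be replicated once per salt value (0-9)
# so that every salted key still finds its match
from pyspark.sql.functions import col, concat, lit, rand

salted_df = df.withColumn("salted_key",
    concat(col("key"), lit("_"), (rand() * 10).cast("int").cast("string")))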

# 3. Increase shuffle partitions (the default is 200)
spark.conf.set("spark.sql.shuffle.partitions", "400")

Issue: Data Type Mismatches

Solution:

# Check schema
df.printSchema()

# Cast columns
df = df.withColumn("price", col("price").cast("decimal(10,2)"))

# Handle null values before operations
df = df.na.fill({"price": 0.0})

Launch Spark Sandbox: https://demos.csa-inabox.com/spark-sandbox

Features:

  • Real-time code execution
  • Pre-loaded sample datasets
  • Save and share notebooks
  • Export results
  • Collaborative editing



Last Updated: January 2025 | Version: 1.0.0