diff --git a/.env.example b/.env.example new file mode 100644 index 0000000..d989b48 --- /dev/null +++ b/.env.example @@ -0,0 +1,43 @@ +# Copy this file to .env and fill in your values. +# Never commit the .env file. + +# Kafka +KAFKA_BOOTSTRAP_SERVERS=localhost:9092 +KAFKA_TELEMETRY_TOPIC=drone-telemetry +KAFKA_PROCESSED_TOPIC=drone-processed +KAFKA_ALERTS_TOPIC=drone-alerts +KAFKA_CONSUMER_GROUP=drone-pipeline + +# Spark +SPARK_APP_NAME=DroneDeliveryPipeline +SPARK_MASTER=local[*] +STREAMING_TRIGGER_INTERVAL=10 seconds +STREAMING_CHECKPOINT_LOCATION=/tmp/drone-pipeline-checkpoints + +# Collision detection thresholds +COLLISION_SAFE_DISTANCE_M=50.0 +COLLISION_SAFE_ALTITUDE_M=10.0 + +# Route optimisation +LOW_BATTERY_THRESHOLD=20.0 +MAX_PAYLOAD_WEIGHT_KG=5.0 + +# AWS +AWS_REGION=us-east-1 +AWS_ACCESS_KEY_ID= +AWS_SECRET_ACCESS_KEY= + +# S3 +S3_BUCKET=drone-telemetry-data +S3_RAW_PREFIX=raw/ +S3_PROCESSED_PREFIX=processed/ +S3_ALERTS_PREFIX=alerts/ + +# Redshift +REDSHIFT_HOST= +REDSHIFT_PORT=5439 +REDSHIFT_DATABASE=drone_analytics +REDSHIFT_USER= +REDSHIFT_PASSWORD= +REDSHIFT_IAM_ROLE= +REDSHIFT_SCHEMA=public diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..aa796ef --- /dev/null +++ b/.gitignore @@ -0,0 +1,41 @@ +# ---- Python ---- +__pycache__/ +*.py[cod] +*.pyo +*.pyd +.Python +*.egg-info/ +dist/ +build/ +.eggs/ +.venv/ +venv/ +env/ + +# ---- Pytest ---- +.pytest_cache/ +.coverage +htmlcov/ +coverage.xml + +# ---- Spark ---- +/tmp/ +derby.log +metastore_db/ +spark-warehouse/ + +# ---- IDE ---- +.idea/ +.vscode/ +*.swp +*.swo + +# ---- Environment ---- +.env +*.env.local + +# ---- Docker ---- +*.log + +# ---- AWS ---- +.aws/credentials diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..776b406 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,16 @@ +FROM python:3.11-slim + +# Install Java (required by PySpark) +RUN apt-get update && \ + apt-get install -y --no-install-recommends default-jdk-headless && \ + rm -rf 
/var/lib/apt/lists/* + +WORKDIR /app + +COPY requirements.txt . +RUN pip install --no-cache-dir -r requirements.txt + +COPY . . + +ENV PYTHONPATH=/app +ENV JAVA_HOME=/usr/lib/jvm/default-java diff --git a/README.md b/README.md index 11e6abd..f0ec420 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,200 @@ -# drone-delivery-optimization -Real-time drone delivery optimization using Apache Kafka, Spark, and AWS for route efficiency and collision avoidance. +# Drone Delivery Optimization – Big Data Pipeline + +Real-time drone delivery optimization using **Apache Kafka**, **Apache Spark Structured Streaming**, **AWS S3**, and **AWS Redshift** for route efficiency and collision avoidance. + +--- + +## Architecture + +``` +Drone Fleet + │ (JSON telemetry, 1 Hz per drone) + ▼ +┌─────────────────────────┐ +│ Kafka Producer │ src/kafka/drone_producer.py +│ Topic: drone-telemetry│ +└──────────┬──────────────┘ + │ + ▼ +┌─────────────────────────────────────────────────────────────┐ +│ Spark Structured Streaming │ +│ │ +│ ┌─────────────────────┐ ┌───────────────────────────┐ │ +│ │ TelemetryProcessor │ │ CollisionDetector │ │ +│ │ (schema, parse) │──▶│ (self-join, safe radius) │ │ +│ └────────┬────────────┘ └─────────────┬─────────────┘ │ +│ │ │ alerts │ +│ ┌────────▼────────────┐ ▼ │ +│ │ RouteOptimizer │ Kafka topic: drone-alerts │ +│ │ (haversine, score) │ │ +│ └────────┬────────────┘ │ +└───────────┼─────────────────────────────────────────────────┘ + │ Parquet (partitioned by date) + ▼ +┌─────────────────────────┐ +│ AWS S3 │ src/aws/s3_handler.py +│ s3://bucket/processed/ │ +└──────────┬──────────────┘ + │ Redshift COPY + ▼ +┌─────────────────────────┐ +│ AWS Redshift │ src/aws/redshift_handler.py +│ drone_analytics DB │ +│ • drone_telemetry │ +│ • drone_routes │ +│ • collision_alerts │ +└─────────────────────────┘ +``` + +--- + +## Project Structure + +``` +drone-delivery-optimization/ +├── config/ +│ ├── kafka_config.py # Kafka broker / topic settings +│ ├── 
spark_config.py # Spark / streaming settings +│ └── aws_config.py # S3 & Redshift settings +├── src/ +│ ├── models/ +│ │ └── drone_telemetry.py # Telemetry dataclass + (de)serialisation +│ ├── kafka/ +│ │ ├── drone_producer.py # Telemetry simulator & Kafka producer +│ │ └── drone_consumer.py # Kafka consumer with pluggable handler +│ ├── spark/ +│ │ ├── telemetry_processor.py # Structured Streaming read/write +│ │ ├── route_optimizer.py # Haversine distance + action scoring +│ │ └── collision_detector.py # Self-join proximity alerts +│ └── aws/ +│ ├── s3_handler.py # S3 upload / download / list helpers +│ └── redshift_handler.py# Redshift DDL, bulk inserts, COPY, queries +├── tests/ +│ ├── test_drone_telemetry.py +│ ├── test_drone_producer.py +│ ├── test_drone_consumer.py +│ ├── test_route_optimizer.py +│ ├── test_collision_detector.py +│ ├── test_s3_handler.py +│ └── test_redshift_handler.py +├── scripts/ +│ ├── run_pipeline.py # Spark pipeline entrypoint +│ └── start_pipeline.sh # Docker Compose helper +├── docker-compose.yml # ZooKeeper, Kafka, Schema Registry, Kafka UI +├── Dockerfile +└── requirements.txt +``` + +--- + +## Quick Start + +### Prerequisites + +| Tool | Version | +|------|---------| +| Python | ≥ 3.11 | +| Docker & Docker Compose | ≥ 24 | +| Java (JDK) | ≥ 11 (required by PySpark) | + +### 1 – Install dependencies + +```bash +python -m venv .venv && source .venv/bin/activate +pip install -r requirements.txt +``` + +### 2 – Configure environment + +Copy `.env.example` to `.env` and fill in your AWS credentials: + +```bash +cp .env.example .env +``` + +Key variables: + +| Variable | Description | Default | +|----------|-------------|---------| +| `KAFKA_BOOTSTRAP_SERVERS` | Kafka broker address | `localhost:9092` | +| `AWS_REGION` | AWS region | `us-east-1` | +| `S3_BUCKET` | Target S3 bucket | `drone-telemetry-data` | +| `REDSHIFT_HOST` | Redshift cluster endpoint | _(required for Redshift)_ | +| `REDSHIFT_USER` / `REDSHIFT_PASSWORD` | Redshift 
credentials | _(required for Redshift)_ | +| `COLLISION_SAFE_DISTANCE_M` | Horizontal safety radius (m) | `50` | +| `COLLISION_SAFE_ALTITUDE_M` | Vertical safety clearance (m) | `10` | + +### 3 – Start infrastructure + +```bash +./scripts/start_pipeline.sh +``` + +This starts ZooKeeper, Kafka, Schema Registry, and the Kafka UI at http://localhost:8080. + +### 4 – Run the producer (simulated drones) + +```bash +python -m src.kafka.drone_producer +``` + +### 5 – Run the Spark pipeline + +```bash +python -m scripts.run_pipeline +``` + +--- + +## Running Tests + +```bash +pytest tests/ -v --cov=src --cov-report=term-missing +``` + +Tests that exercise Spark (route optimiser, collision detector) spin up a local `SparkSession` and do **not** require a running cluster. + +--- + +## Key Components + +### Telemetry Model (`src/models/drone_telemetry.py`) + +A `@dataclass` capturing: `drone_id`, `latitude`, `longitude`, `altitude`, `speed`, `heading`, `battery_level`, `status`, `timestamp`, `destination_lat/lon`, `payload_weight`. Includes JSON serialisation helpers. + +### Kafka Producer (`src/kafka/drone_producer.py`) + +Simulates a configurable number of drones publishing telemetry at a configurable rate. Uses `confluent-kafka` with GZIP compression, idempotent delivery (`acks=all`), and automatic topic creation. + +### Spark Telemetry Processor (`src/spark/telemetry_processor.py`) + +Reads from the `drone-telemetry` Kafka topic, applies the telemetry schema, and writes Parquet files to S3 partitioned by date. Also forwards enriched records to `drone-processed`. 
+ +### Route Optimizer (`src/spark/route_optimizer.py`) + +Adds four columns to the streaming DataFrame: +- `distance_to_dest_m` – great-circle distance to destination via the Haversine formula +- `estimated_flight_time_s` – ETA at current speed +- `recommended_action` – one of `continue`, `return_to_base`, `emergency_land`, `reduce_speed`, `optimal` +- `optimisation_score` – 0–100 efficiency score (battery × 0.5, speed × 0.3, payload × 0.2) + +### Collision Detector (`src/spark/collision_detector.py`) + +Performs a self-join on each micro-batch to find drone pairs whose **horizontal** separation is below `COLLISION_SAFE_DISTANCE_M` **and/or** whose **vertical** separation is below `COLLISION_SAFE_ALTITUDE_M`. Publishes `WARNING` / `CRITICAL` alerts to the `drone-alerts` Kafka topic. + +### AWS S3 Handler (`src/aws/s3_handler.py`) + +Provides `upload_json_records`, `upload_file`, `download_json_records`, `list_objects`, and `delete_object` helpers backed by `boto3`. Supports Hive-style date partitioning (`year=/month=/day=`). + +### AWS Redshift Handler (`src/aws/redshift_handler.py`) + +- Auto-creates `drone_telemetry`, `drone_routes`, and `collision_alerts` tables with `DISTKEY` / `SORTKEY` optimisations. +- Bulk inserts via `psycopg2` `execute_values`. +- `COPY … FROM S3` for high-throughput Parquet loads. +- Analytics helpers: `get_low_battery_drones`, `get_recent_alerts`. 
+ +--- + +## License + +MIT diff --git a/config/__init__.py b/config/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/config/aws_config.py b/config/aws_config.py new file mode 100644 index 0000000..d74c586 --- /dev/null +++ b/config/aws_config.py @@ -0,0 +1,36 @@ +"""AWS configuration.""" + +import os + +# ------------------------------------------------------------------ +# Credentials (prefer IAM roles / environment variables in production) +# ------------------------------------------------------------------ + +AWS_REGION: str = os.environ.get("AWS_REGION", "us-east-1") +AWS_ACCESS_KEY_ID: str = os.environ.get("AWS_ACCESS_KEY_ID", "") +AWS_SECRET_ACCESS_KEY: str = os.environ.get("AWS_SECRET_ACCESS_KEY", "") + +# ------------------------------------------------------------------ +# S3 +# ------------------------------------------------------------------ + +S3_BUCKET: str = os.environ.get("S3_BUCKET", "drone-telemetry-data") +S3_RAW_PREFIX: str = os.environ.get("S3_RAW_PREFIX", "raw/") +S3_PROCESSED_PREFIX: str = os.environ.get("S3_PROCESSED_PREFIX", "processed/") +S3_ALERTS_PREFIX: str = os.environ.get("S3_ALERTS_PREFIX", "alerts/") + +# ------------------------------------------------------------------ +# Redshift +# ------------------------------------------------------------------ + +REDSHIFT_HOST: str = os.environ.get("REDSHIFT_HOST", "") +REDSHIFT_PORT: int = int(os.environ.get("REDSHIFT_PORT", "5439")) +REDSHIFT_DATABASE: str = os.environ.get("REDSHIFT_DATABASE", "drone_analytics") +REDSHIFT_USER: str = os.environ.get("REDSHIFT_USER", "") +REDSHIFT_PASSWORD: str = os.environ.get("REDSHIFT_PASSWORD", "") +REDSHIFT_IAM_ROLE: str = os.environ.get("REDSHIFT_IAM_ROLE", "") + +REDSHIFT_SCHEMA: str = os.environ.get("REDSHIFT_SCHEMA", "public") +REDSHIFT_TELEMETRY_TABLE: str = "drone_telemetry" +REDSHIFT_ROUTES_TABLE: str = "drone_routes" +REDSHIFT_ALERTS_TABLE: str = "collision_alerts" diff --git a/config/kafka_config.py 
"""Kafka configuration.

Broker address, topic names, and baseline producer/consumer settings for
confluent-kafka clients. Everything can be overridden via the environment.
"""

import os

# One bound lookup keeps every env access uniform.
_env = os.environ.get

# --- Broker connection -----------------------------------------------------
KAFKA_BOOTSTRAP_SERVERS: str = _env("KAFKA_BOOTSTRAP_SERVERS", "localhost:9092")

# --- Topics ----------------------------------------------------------------
TELEMETRY_TOPIC: str = _env("KAFKA_TELEMETRY_TOPIC", "drone-telemetry")
PROCESSED_TOPIC: str = _env("KAFKA_PROCESSED_TOPIC", "drone-processed")
ALERTS_TOPIC: str = _env("KAFKA_ALERTS_TOPIC", "drone-alerts")

# --- Producer defaults -----------------------------------------------------
# acks=all with retries gives at-least-once delivery; gzip keeps the JSON
# telemetry payloads small on the wire.
PRODUCER_CONFIG: dict = {
    "bootstrap.servers": KAFKA_BOOTSTRAP_SERVERS,
    "client.id": _env("KAFKA_CLIENT_ID", "drone-producer"),
    "acks": "all",
    "retries": 3,
    "batch.size": 16384,
    "linger.ms": 5,
    "compression.type": "gzip",
}

# --- Consumer defaults -----------------------------------------------------
# Offsets are committed manually (enable.auto.commit=False) so a record is
# acknowledged only after it has been processed.
CONSUMER_CONFIG: dict = {
    "bootstrap.servers": KAFKA_BOOTSTRAP_SERVERS,
    "group.id": _env("KAFKA_CONSUMER_GROUP", "drone-pipeline"),
    "auto.offset.reset": "earliest",
    "enable.auto.commit": False,
    "session.timeout.ms": 30000,
    "max.poll.interval.ms": 300000,
}
"""Spark configuration."""

import os


def _float_env(name: str, default: str) -> float:
    """Read *name* from the environment and coerce it to ``float``."""
    return float(os.environ.get(name, default))


# --- Application -----------------------------------------------------------
SPARK_APP_NAME: str = os.environ.get("SPARK_APP_NAME", "DroneDeliveryPipeline")
SPARK_MASTER: str = os.environ.get("SPARK_MASTER", "local[*]")

# --- Streaming -------------------------------------------------------------
STREAMING_TRIGGER_INTERVAL: str = os.environ.get(
    "STREAMING_TRIGGER_INTERVAL", "10 seconds"
)
STREAMING_CHECKPOINT_LOCATION: str = os.environ.get(
    "STREAMING_CHECKPOINT_LOCATION", "/tmp/drone-pipeline-checkpoints"
)
STREAMING_OUTPUT_MODE: str = os.environ.get("STREAMING_OUTPUT_MODE", "append")

# --- Collision detection ---------------------------------------------------
# Minimum safe horizontal separation in metres.
COLLISION_SAFE_DISTANCE_M: float = _float_env("COLLISION_SAFE_DISTANCE_M", "50.0")
# Minimum safe vertical separation in metres.
COLLISION_SAFE_ALTITUDE_M: float = _float_env("COLLISION_SAFE_ALTITUDE_M", "10.0")

# --- Route optimisation ----------------------------------------------------
LOW_BATTERY_THRESHOLD: float = _float_env("LOW_BATTERY_THRESHOLD", "20.0")
MAX_PAYLOAD_WEIGHT_KG: float = _float_env("MAX_PAYLOAD_WEIGHT_KG", "5.0")

# --- Spark / Kafka integration packages ------------------------------------
# Comma-separated Maven coordinates handed to spark-submit --packages.
SPARK_PACKAGES: str = os.environ.get(
    "SPARK_PACKAGES",
    (
        "org.apache.spark:spark-sql-kafka-0-10_2.12:3.5.0,"
        "com.amazonaws:aws-java-sdk-bundle:1.12.262,"
        "org.apache.hadoop:hadoop-aws:3.3.4"
    ),
)
b/docker-compose.yml @@ -0,0 +1,126 @@ +version: "3.8" + +services: + # ----------------------------------------------------------------------- + # ZooKeeper + # ----------------------------------------------------------------------- + zookeeper: + image: confluentinc/cp-zookeeper:7.6.0 + container_name: zookeeper + environment: + ZOOKEEPER_CLIENT_PORT: 2181 + ZOOKEEPER_TICK_TIME: 2000 + ports: + - "2181:2181" + healthcheck: + test: ["CMD", "bash", "-c", "echo ruok | nc localhost 2181 | grep imok"] + interval: 10s + timeout: 5s + retries: 5 + + # ----------------------------------------------------------------------- + # Kafka Broker + # ----------------------------------------------------------------------- + kafka: + image: confluentinc/cp-kafka:7.6.0 + container_name: kafka + depends_on: + zookeeper: + condition: service_healthy + ports: + - "9092:9092" + - "29092:29092" + environment: + KAFKA_BROKER_ID: 1 + KAFKA_ZOOKEEPER_CONNECT: "zookeeper:2181" + KAFKA_LISTENER_SECURITY_PROTOCOL_MAP: PLAINTEXT:PLAINTEXT,PLAINTEXT_HOST:PLAINTEXT + KAFKA_ADVERTISED_LISTENERS: PLAINTEXT://kafka:29092,PLAINTEXT_HOST://localhost:9092 + KAFKA_OFFSETS_TOPIC_REPLICATION_FACTOR: 1 + KAFKA_AUTO_CREATE_TOPICS_ENABLE: "true" + KAFKA_LOG_RETENTION_HOURS: 24 + KAFKA_NUM_PARTITIONS: 3 + healthcheck: + test: ["CMD", "kafka-broker-api-versions", "--bootstrap-server", "localhost:9092"] + interval: 15s + timeout: 10s + retries: 10 + + # ----------------------------------------------------------------------- + # Schema Registry (optional – useful for Avro schemas in the future) + # ----------------------------------------------------------------------- + schema-registry: + image: confluentinc/cp-schema-registry:7.6.0 + container_name: schema-registry + depends_on: + kafka: + condition: service_healthy + ports: + - "8081:8081" + environment: + SCHEMA_REGISTRY_HOST_NAME: schema-registry + SCHEMA_REGISTRY_KAFKASTORE_BOOTSTRAP_SERVERS: "kafka:29092" + SCHEMA_REGISTRY_LISTENERS: http://0.0.0.0:8081 
+ + # ----------------------------------------------------------------------- + # Kafka UI (Kowl / kafka-ui) + # ----------------------------------------------------------------------- + kafka-ui: + image: provectuslabs/kafka-ui:latest + container_name: kafka-ui + depends_on: + - kafka + ports: + - "8080:8080" + environment: + KAFKA_CLUSTERS_0_NAME: local + KAFKA_CLUSTERS_0_BOOTSTRAPSERVERS: kafka:29092 + KAFKA_CLUSTERS_0_SCHEMAREGISTRY: http://schema-registry:8081 + + # ----------------------------------------------------------------------- + # Drone Telemetry Producer + # ----------------------------------------------------------------------- + drone-producer: + build: + context: . + dockerfile: Dockerfile + container_name: drone-producer + depends_on: + kafka: + condition: service_healthy + environment: + KAFKA_BOOTSTRAP_SERVERS: kafka:29092 + KAFKA_TELEMETRY_TOPIC: drone-telemetry + KAFKA_PROCESSED_TOPIC: drone-processed + KAFKA_ALERTS_TOPIC: drone-alerts + command: python -m src.kafka.drone_producer + restart: on-failure + + # ----------------------------------------------------------------------- + # Spark Streaming Pipeline + # ----------------------------------------------------------------------- + spark-pipeline: + build: + context: . 
+ dockerfile: Dockerfile + container_name: spark-pipeline + depends_on: + kafka: + condition: service_healthy + environment: + KAFKA_BOOTSTRAP_SERVERS: kafka:29092 + SPARK_MASTER: local[2] + STREAMING_CHECKPOINT_LOCATION: /tmp/checkpoints + # Set real AWS credentials via .env file or environment + AWS_REGION: ${AWS_REGION:-us-east-1} + AWS_ACCESS_KEY_ID: ${AWS_ACCESS_KEY_ID:-} + AWS_SECRET_ACCESS_KEY: ${AWS_SECRET_ACCESS_KEY:-} + S3_BUCKET: ${S3_BUCKET:-drone-telemetry-data} + REDSHIFT_HOST: ${REDSHIFT_HOST:-} + REDSHIFT_USER: ${REDSHIFT_USER:-} + REDSHIFT_PASSWORD: ${REDSHIFT_PASSWORD:-} + command: python -m scripts.run_pipeline + restart: on-failure + +networks: + default: + name: drone-pipeline-network diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..8bdfa89 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,16 @@ +# ---- Kafka ---- +confluent-kafka==2.3.0 + +# ---- Spark ---- +pyspark==3.5.0 + +# ---- AWS ---- +boto3==1.34.0 +psycopg2-binary==2.9.9 + +# ---- Testing ---- +pytest==7.4.4 +pytest-cov==4.1.0 + +# ---- Utilities ---- +python-dotenv==1.0.0 diff --git a/scripts/run_pipeline.py b/scripts/run_pipeline.py new file mode 100644 index 0000000..4349588 --- /dev/null +++ b/scripts/run_pipeline.py @@ -0,0 +1,80 @@ +""" +Pipeline entrypoint. + +Starts Spark Structured Streaming to: + 1. Read raw telemetry from Kafka + 2. Optimise routes + 3. Detect collisions and publish alerts to Kafka + 4. 
"""
Pipeline entrypoint.

Starts Spark Structured Streaming to:
  1. Read raw telemetry from Kafka
  2. Optimise routes
  3. Detect collisions and publish alerts to Kafka
  4. Persist enriched records to S3 (Parquet)
"""

import logging
import sys

from config import aws_config, kafka_config, spark_config
from src.spark.collision_detector import CollisionDetector
from src.spark.route_optimizer import RouteOptimizer
from src.spark.telemetry_processor import TelemetryProcessor, create_spark_session

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(name)s: %(message)s",
)
logger = logging.getLogger(__name__)


def main() -> None:
    """Wire up all streaming queries and block until one terminates.

    Starts three concurrent queries (S3 Parquet sink, processed-topic Kafka
    sink, collision-alert query) over the same parsed telemetry stream, then
    waits. On Ctrl-C or query failure, every active query is stopped and the
    SparkSession is torn down.
    """
    logger.info("Starting Drone Delivery Pipeline …")
    spark = create_spark_session()
    # Spark's default INFO logging is very noisy; keep warnings and above.
    spark.sparkContext.setLogLevel("WARN")

    processor = TelemetryProcessor(spark)
    optimizer = RouteOptimizer()
    detector = CollisionDetector()

    # Step 1: parse raw telemetry stream
    telemetry_stream = processor.read_stream()

    # Step 2: add route-optimisation columns
    optimised_stream = optimizer.optimise_stream(telemetry_stream)

    # Step 3: write enriched data to S3 (Parquet).
    # Each sink gets its own checkpoint sub-directory — Structured Streaming
    # checkpoints must never be shared between queries.
    s3_path = f"s3a://{aws_config.S3_BUCKET}/{aws_config.S3_PROCESSED_PREFIX}"
    s3_query = processor.write_to_s3(
        optimised_stream,
        s3_path=s3_path,
        checkpoint_location=spark_config.STREAMING_CHECKPOINT_LOCATION + "/s3-sink",
    )

    # Step 4: forward to processed Kafka topic
    kafka_query = processor.write_to_kafka(
        optimised_stream,
        topic=kafka_config.PROCESSED_TOPIC,
        checkpoint_location=spark_config.STREAMING_CHECKPOINT_LOCATION + "/kafka-sink",
    )

    # Step 5: collision detection (foreachBatch on the *raw* stream, so alert
    # latency does not depend on the optimisation step)
    alert_query = detector.detect_stream(
        telemetry_stream,
        checkpoint_location=spark_config.STREAMING_CHECKPOINT_LOCATION + "/alerts",
        kafka_topic=kafka_config.ALERTS_TOPIC,
    )

    logger.info("All streaming queries started.")
    logger.info("  S3 sink: %s", s3_query.id)
    logger.info("  Kafka sink: %s", kafka_query.id)
    logger.info("  Alert query: %s", alert_query.id)

    try:
        # Blocks until any of the three queries terminates (or fails).
        spark.streams.awaitAnyTermination()
    except KeyboardInterrupt:
        logger.info("Pipeline interrupted by user.")
    finally:
        # Stop every still-active query before tearing down the session.
        for q in spark.streams.active:
            q.stop()
        spark.stop()
        logger.info("Pipeline stopped.")


if __name__ == "__main__":
    main()
+done + +echo "==> Starting all services …" +docker compose -f "${ROOT_DIR}/docker-compose.yml" up -d + +echo "" +echo "Pipeline services are running." +echo " Kafka UI: http://localhost:8080" +echo " Schema Registry: http://localhost:8081" +echo "" +echo "To stop the pipeline: docker compose down" diff --git a/src/__init__.py b/src/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/aws/__init__.py b/src/aws/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/aws/redshift_handler.py b/src/aws/redshift_handler.py new file mode 100644 index 0000000..e7f758c --- /dev/null +++ b/src/aws/redshift_handler.py @@ -0,0 +1,271 @@ +"""AWS Redshift handler for drone analytics.""" + +import logging +from contextlib import contextmanager +from typing import Any, Dict, Generator, List, Optional, Tuple + +import psycopg2 +from psycopg2.extensions import connection as PgConnection +from psycopg2.extras import execute_values + +from config import aws_config + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# DDL statements +# --------------------------------------------------------------------------- + +CREATE_TELEMETRY_TABLE = f""" +CREATE TABLE IF NOT EXISTS {aws_config.REDSHIFT_SCHEMA}.{aws_config.REDSHIFT_TELEMETRY_TABLE} ( + id BIGINT IDENTITY(1,1) PRIMARY KEY, + drone_id VARCHAR(64) NOT NULL, + latitude DOUBLE PRECISION NOT NULL, + longitude DOUBLE PRECISION NOT NULL, + altitude REAL, + speed REAL, + heading REAL, + battery_level REAL, + status VARCHAR(32), + event_time TIMESTAMP, + payload_weight REAL, + ingested_at TIMESTAMP DEFAULT GETDATE() +) +DISTKEY(drone_id) +SORTKEY(event_time); +""" + +CREATE_ALERTS_TABLE = f""" +CREATE TABLE IF NOT EXISTS {aws_config.REDSHIFT_SCHEMA}.{aws_config.REDSHIFT_ALERTS_TABLE} ( + id BIGINT IDENTITY(1,1) PRIMARY KEY, + drone_a VARCHAR(64) NOT NULL, + drone_b VARCHAR(64) NOT NULL, + horizontal_distance_m REAL, + 
class RedshiftHandler:
    """
    Manages a connection to Amazon Redshift and provides helpers for
    inserting drone telemetry, routes, and collision-alert data.

    Connection parameters default to the values in ``config.aws_config``.
    The connection is opened lazily on first use; every ``cursor()`` context
    commits on success and rolls back if the body raises.
    """

    def __init__(
        self,
        host: str = aws_config.REDSHIFT_HOST,
        port: int = aws_config.REDSHIFT_PORT,
        database: str = aws_config.REDSHIFT_DATABASE,
        user: str = aws_config.REDSHIFT_USER,
        password: str = aws_config.REDSHIFT_PASSWORD,
    ) -> None:
        # Keyword arguments passed verbatim to psycopg2.connect().
        self._dsn: Dict[str, Any] = {
            "host": host,
            "port": port,
            "dbname": database,
            "user": user,
            "password": password,
        }
        # Opened lazily by cursor(); None until the first connect().
        self._conn: Optional[PgConnection] = None

    # ------------------------------------------------------------------
    # Connection management
    # ------------------------------------------------------------------

    def connect(self) -> None:
        """Open a connection to Redshift."""
        self._conn = psycopg2.connect(**self._dsn)
        logger.info("Connected to Redshift at %s:%s/%s", self._dsn["host"], self._dsn["port"], self._dsn["dbname"])

    def disconnect(self) -> None:
        """Close the Redshift connection.

        ``self._conn`` is intentionally left assigned after close();
        ``cursor()`` checks ``closed`` and reconnects when needed.
        """
        if self._conn and not self._conn.closed:
            self._conn.close()
            logger.info("Disconnected from Redshift.")

    @contextmanager
    def cursor(self) -> Generator:
        """Context manager that yields a cursor and commits on success.

        Reconnects automatically if the connection was never opened or has
        been closed. If the managed block raises, the transaction is rolled
        back and the exception re-raised; the cursor is always closed.
        """
        if self._conn is None or self._conn.closed:
            self.connect()
        cur = self._conn.cursor()
        try:
            yield cur
            self._conn.commit()
        except Exception:
            self._conn.rollback()
            raise
        finally:
            cur.close()

    # ------------------------------------------------------------------
    # Schema initialisation
    # ------------------------------------------------------------------

    def initialise_schema(self) -> None:
        """Create tables if they do not already exist.

        Runs the three module-level ``CREATE TABLE IF NOT EXISTS`` DDL
        statements in a single transaction (idempotent).
        """
        with self.cursor() as cur:
            for ddl in (CREATE_TELEMETRY_TABLE, CREATE_ALERTS_TABLE, CREATE_ROUTES_TABLE):
                cur.execute(ddl)
        logger.info("Redshift schema initialised.")

    # ------------------------------------------------------------------
    # Bulk inserts
    # ------------------------------------------------------------------

    def insert_telemetry_batch(self, records: List[dict]) -> int:
        """
        Insert a batch of telemetry dictionaries into the telemetry table.
        Returns the number of rows inserted.

        Raises:
            KeyError: if a record lacks ``drone_id``, ``latitude``, or
                ``longitude`` (the required fields); optional fields default
                to NULL via ``dict.get``.
        """
        if not records:
            return 0
        rows: List[Tuple] = [
            (
                r["drone_id"],
                r["latitude"],
                r["longitude"],
                r.get("altitude"),
                r.get("speed"),
                r.get("heading"),
                r.get("battery_level"),
                r.get("status"),
                r.get("event_time"),
                r.get("payload_weight"),
            )
            for r in records
        ]
        sql = f"""
            INSERT INTO {aws_config.REDSHIFT_SCHEMA}.{aws_config.REDSHIFT_TELEMETRY_TABLE}
                (drone_id, latitude, longitude, altitude, speed, heading,
                 battery_level, status, event_time, payload_weight)
            VALUES %s
        """
        with self.cursor() as cur:
            # execute_values expands VALUES %s into one multi-row statement.
            execute_values(cur, sql, rows)
        logger.info("Inserted %d telemetry rows into Redshift.", len(rows))
        return len(rows)

    def insert_alerts_batch(self, alerts: List[dict]) -> int:
        """Insert collision-alert records into the alerts table.

        Requires ``drone_a`` and ``drone_b`` per alert; risk flag and level
        default to ``False`` / ``"WARNING"``. Returns the row count.
        """
        if not alerts:
            return 0
        rows = [
            (
                a["drone_a"],
                a["drone_b"],
                a.get("horizontal_distance_m"),
                a.get("vertical_distance_m"),
                a.get("is_collision_risk", False),
                a.get("alert_level", "WARNING"),
            )
            for a in alerts
        ]
        sql = f"""
            INSERT INTO {aws_config.REDSHIFT_SCHEMA}.{aws_config.REDSHIFT_ALERTS_TABLE}
                (drone_a, drone_b, horizontal_distance_m, vertical_distance_m, is_collision_risk, alert_level)
            VALUES %s
        """
        with self.cursor() as cur:
            execute_values(cur, sql, rows)
        logger.info("Inserted %d alert rows into Redshift.", len(rows))
        return len(rows)

    def insert_routes_batch(self, routes: List[dict]) -> int:
        """Insert route-optimisation records.

        Requires ``drone_id`` per record; the metric columns default to
        NULL when absent. Returns the row count.
        """
        if not routes:
            return 0
        rows = [
            (
                r["drone_id"],
                r.get("distance_to_dest_m"),
                r.get("estimated_flight_time_s"),
                r.get("recommended_action"),
                r.get("optimisation_score"),
            )
            for r in routes
        ]
        sql = f"""
            INSERT INTO {aws_config.REDSHIFT_SCHEMA}.{aws_config.REDSHIFT_ROUTES_TABLE}
                (drone_id, distance_to_dest_m, estimated_flight_time_s, recommended_action, optimisation_score)
            VALUES %s
        """
        with self.cursor() as cur:
            execute_values(cur, sql, rows)
        logger.info("Inserted %d route rows into Redshift.", len(rows))
        return len(rows)

    # ------------------------------------------------------------------
    # Analytics queries
    # ------------------------------------------------------------------

    def get_low_battery_drones(self, threshold: float = 20.0) -> List[dict]:
        """Return drones whose latest battery reading is below *threshold*.

        ROW_NUMBER() partitioned by drone and ordered by event_time DESC
        selects the most recent reading per drone (rn = 1) before the
        battery filter is applied. Rows are returned as dicts keyed by
        column name, lowest battery first.
        """
        sql = f"""
            SELECT drone_id, battery_level, status, event_time
            FROM (
                SELECT *, ROW_NUMBER() OVER (PARTITION BY drone_id ORDER BY event_time DESC) AS rn
                FROM {aws_config.REDSHIFT_SCHEMA}.{aws_config.REDSHIFT_TELEMETRY_TABLE}
            ) sub
            WHERE rn = 1 AND battery_level < %s
            ORDER BY battery_level ASC;
        """
        with self.cursor() as cur:
            cur.execute(sql, (threshold,))
            cols = [desc[0] for desc in cur.description]
            return [dict(zip(cols, row)) for row in cur.fetchall()]

    def get_recent_alerts(self, limit: int = 100) -> List[dict]:
        """Return the most recent *limit* collision alerts as dicts."""
        sql = f"""
            SELECT * FROM {aws_config.REDSHIFT_SCHEMA}.{aws_config.REDSHIFT_ALERTS_TABLE}
            ORDER BY detected_at DESC
            LIMIT %s;
        """
        with self.cursor() as cur:
            cur.execute(sql, (limit,))
            cols = [desc[0] for desc in cur.description]
            return [dict(zip(cols, row)) for row in cur.fetchall()]

    def copy_from_s3(self, s3_path: str, table: str, iam_role: str = aws_config.REDSHIFT_IAM_ROLE) -> None:
        """
        Use the Redshift COPY command to bulk-load Parquet data from S3.

        Args:
            s3_path: Full S3 URI, e.g. ``s3://bucket/prefix/``.
            table: Fully-qualified target table name.
            iam_role: IAM role ARN with read access to S3.

        NOTE(review): COPY cannot take bound parameters, so *s3_path*,
        *table*, and *iam_role* are interpolated directly into the SQL.
        Only ever pass trusted, config-derived values here — a quote in
        any argument would break (or alter) the statement.
        """
        sql = f"""
            COPY {table}
            FROM '{s3_path}'
            IAM_ROLE '{iam_role}'
            FORMAT AS PARQUET;
        """
        with self.cursor() as cur:
            cur.execute(sql)
        logger.info("COPY from %s into %s completed.", s3_path, table)
+ """ + key = self._build_key(prefix, "telemetry.ndjson", partition_by_date) + body = "\n".join(json.dumps(r) for r in records).encode("utf-8") + self._put_object(key, body, content_type="application/x-ndjson") + logger.info("Uploaded %d records to s3://%s/%s", len(records), self._bucket, key) + return key + + def upload_file(self, local_path: str, s3_key: str) -> None: + """Upload a local file to *s3_key* within the configured bucket.""" + try: + self._s3.upload_file(local_path, self._bucket, s3_key) + logger.info("Uploaded %s -> s3://%s/%s", local_path, self._bucket, s3_key) + except (BotoCoreError, ClientError) as exc: + logger.error("Failed to upload %s: %s", local_path, exc) + raise + + def download_json_records(self, s3_key: str) -> List[dict]: + """Download and parse NDJSON from *s3_key*.""" + try: + response = self._s3.get_object(Bucket=self._bucket, Key=s3_key) + body = response["Body"].read().decode("utf-8") + return [json.loads(line) for line in body.splitlines() if line.strip()] + except (BotoCoreError, ClientError) as exc: + logger.error("Failed to download s3://%s/%s: %s", self._bucket, s3_key, exc) + raise + + def list_objects(self, prefix: str) -> List[str]: + """Return a list of S3 keys under *prefix*.""" + keys: List[str] = [] + paginator = self._s3.get_paginator("list_objects_v2") + for page in paginator.paginate(Bucket=self._bucket, Prefix=prefix): + for obj in page.get("Contents", []): + keys.append(obj["Key"]) + return keys + + def delete_object(self, s3_key: str) -> None: + """Delete a single object at *s3_key*.""" + try: + self._s3.delete_object(Bucket=self._bucket, Key=s3_key) + logger.info("Deleted s3://%s/%s", self._bucket, s3_key) + except (BotoCoreError, ClientError) as exc: + logger.error("Failed to delete s3://%s/%s: %s", self._bucket, s3_key, exc) + raise + + # ------------------------------------------------------------------ + # Private helpers + # ------------------------------------------------------------------ + + def 
_put_object(self, key: str, body: bytes, content_type: str = "application/octet-stream") -> None: + try: + self._s3.put_object( + Bucket=self._bucket, + Key=key, + Body=body, + ContentType=content_type, + ) + except (BotoCoreError, ClientError) as exc: + logger.error("put_object failed for key %s: %s", key, exc) + raise + + @staticmethod + def _build_key(prefix: str, filename: str, partition_by_date: bool) -> str: + if partition_by_date: + now = datetime.now(tz=timezone.utc) + partition = f"year={now.year:04d}/month={now.month:02d}/day={now.day:02d}/" + return f"{prefix}{partition}{filename}" + return f"{prefix}{filename}" diff --git a/src/kafka/__init__.py b/src/kafka/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/kafka/drone_consumer.py b/src/kafka/drone_consumer.py new file mode 100644 index 0000000..5b3a5c7 --- /dev/null +++ b/src/kafka/drone_consumer.py @@ -0,0 +1,120 @@ +"""Kafka consumer that reads processed drone telemetry messages.""" + +import logging +from typing import Callable, List, Optional + +from confluent_kafka import Consumer, KafkaError, KafkaException, Message + +from config import kafka_config +from src.models.drone_telemetry import DroneTelemetry + +logger = logging.getLogger(__name__) + + +class DroneConsumer: + """ + Consumes drone telemetry messages from one or more Kafka topics.
+ + Usage:: + + consumer = DroneConsumer(topics=[kafka_config.TELEMETRY_TOPIC]) + consumer.consume(handler=my_handler) + """ + + def __init__( + self, + topics: Optional[List[str]] = None, + config: Optional[dict] = None, + ) -> None: + self._topics = topics or [kafka_config.TELEMETRY_TOPIC] + self._config = config or kafka_config.CONSUMER_CONFIG + self._consumer = Consumer(self._config) + self._running = False + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + + def consume( + self, + handler: Callable[[DroneTelemetry], None], + poll_timeout: float = 1.0, + max_messages: Optional[int] = None, + ) -> None: + """ + Start consuming messages and invoke *handler* for each valid telemetry record. + + Args: + handler: Callable that receives a :class:`DroneTelemetry` instance. + poll_timeout: Seconds to wait in each ``poll()`` call. + max_messages: Stop after processing this many messages (``None`` = run forever). 
+ """ + self._consumer.subscribe(self._topics) + self._running = True + count = 0 + try: + while self._running and (max_messages is None or count < max_messages): + msg: Optional[Message] = self._consumer.poll(poll_timeout) + if msg is None: + continue + if msg.error(): + self._handle_error(msg) + continue + telemetry = self._deserialise(msg) + if telemetry is not None: + handler(telemetry) + self._consumer.commit(message=msg, asynchronous=False) + count += 1 + except KeyboardInterrupt: + logger.info("Consumer stopped by user.") + finally: + self._consumer.close() + + def stop(self) -> None: + """Signal the consume loop to stop.""" + self._running = False + + # ------------------------------------------------------------------ + # Private helpers + # ------------------------------------------------------------------ + + @staticmethod + def _handle_error(msg: Message) -> None: + err = msg.error() + if err.code() == KafkaError._PARTITION_EOF: + logger.debug( + "End of partition reached: %s [%d] @ %d", + msg.topic(), + msg.partition(), + msg.offset(), + ) + else: + raise KafkaException(err) + + @staticmethod + def _deserialise(msg: Message) -> Optional[DroneTelemetry]: + try: + return DroneTelemetry.from_json(msg.value().decode("utf-8")) + except Exception as exc: # pylint: disable=broad-except + logger.warning("Failed to deserialise message: %s", exc) + return None + + +# --------------------------------------------------------------------------- +# Entrypoint +# --------------------------------------------------------------------------- + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO) + + def _print_handler(telemetry: DroneTelemetry) -> None: + logger.info( + "[%s] alt=%.1fm battery=%.1f%% status=%s", + telemetry.drone_id, + telemetry.altitude, + telemetry.battery_level, + telemetry.status, + ) + + consumer = DroneConsumer(topics=[kafka_config.TELEMETRY_TOPIC]) + consumer.consume(handler=_print_handler) diff --git a/src/kafka/drone_producer.py 
b/src/kafka/drone_producer.py new file mode 100644 index 0000000..6842314 --- /dev/null +++ b/src/kafka/drone_producer.py @@ -0,0 +1,182 @@ +"""Kafka producer that simulates real-time drone telemetry data.""" + +import json +import logging +import random +import time +from typing import Callable, List, Optional + +from confluent_kafka import Producer +from confluent_kafka.admin import AdminClient, NewTopic + +from config import kafka_config +from src.models.drone_telemetry import DroneTelemetry + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# Helper: create topics if they do not exist +# --------------------------------------------------------------------------- + +def ensure_topics(topics: List[str], num_partitions: int = 3, replication_factor: int = 1) -> None: + """Create Kafka topics if they do not already exist.""" + admin = AdminClient({"bootstrap.servers": kafka_config.KAFKA_BOOTSTRAP_SERVERS}) + existing = set(admin.list_topics(timeout=10).topics.keys()) + new_topics = [ + NewTopic(t, num_partitions=num_partitions, replication_factor=replication_factor) + for t in topics + if t not in existing + ] + if new_topics: + futures = admin.create_topics(new_topics) + for topic, future in futures.items(): + try: + future.result() + logger.info("Created topic: %s", topic) + except Exception as exc: # pylint: disable=broad-except + logger.warning("Topic %s may already exist: %s", topic, exc) + + +# --------------------------------------------------------------------------- +# Telemetry simulator +# --------------------------------------------------------------------------- + +def _simulate_telemetry(drone_id: str, prev: Optional[DroneTelemetry] = None) -> DroneTelemetry: + """Generate a plausible next telemetry reading for a drone.""" + if prev is None: + return DroneTelemetry( + drone_id=drone_id, + latitude=random.uniform(37.7, 37.8), + longitude=random.uniform(-122.5, -122.4), + 
altitude=random.uniform(50.0, 120.0), + speed=random.uniform(5.0, 15.0), + heading=random.uniform(0.0, 360.0), + battery_level=random.uniform(60.0, 100.0), + status=random.choice(["en_route", "hovering"]), + destination_lat=random.uniform(37.7, 37.8), + destination_lon=random.uniform(-122.5, -122.4), + payload_weight=random.uniform(0.1, 5.0), + ) + + # Drift existing values slightly + lat = prev.latitude + random.uniform(-0.001, 0.001) + lon = prev.longitude + random.uniform(-0.001, 0.001) + alt = max(0.0, prev.altitude + random.uniform(-2.0, 2.0)) + speed = max(0.0, prev.speed + random.uniform(-1.0, 1.0)) + heading = (prev.heading + random.uniform(-5.0, 5.0)) % 360 + battery = max(0.0, prev.battery_level - random.uniform(0.1, 0.5)) + status = prev.status + # Check the more critical (lower) threshold first; the reverse order + # made the "emergency" branch unreachable. + if battery < 5.0: + status = "emergency" + elif battery < 20.0: + status = "returning" + + return DroneTelemetry( + drone_id=drone_id, + latitude=lat, + longitude=lon, + altitude=alt, + speed=speed, + heading=heading, + battery_level=battery, + status=status, + destination_lat=prev.destination_lat, + destination_lon=prev.destination_lon, + payload_weight=prev.payload_weight, + ) + + +# --------------------------------------------------------------------------- +# Delivery callback +# --------------------------------------------------------------------------- + +def _delivery_report(err, msg) -> None: + if err is not None: + logger.error("Message delivery failed for %s: %s", msg.key(), err) + else: + logger.debug( + "Delivered %s [partition %d] @ offset %d", + msg.topic(), + msg.partition(), + msg.offset(), + ) + + +# --------------------------------------------------------------------------- +# Producer class +# --------------------------------------------------------------------------- + +class DroneProducer: + """Publishes drone telemetry messages to a Kafka topic.""" + + def __init__( + self, + config: Optional[dict] = None, + topic: str = kafka_config.TELEMETRY_TOPIC, + delivery_callback: Optional[Callable]
= None, + ) -> None: + self._config = config or kafka_config.PRODUCER_CONFIG + self._topic = topic + self._delivery_callback = delivery_callback or _delivery_report + self._producer = Producer(self._config) + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + + def send(self, telemetry: DroneTelemetry) -> None: + """Serialise *telemetry* and publish it to Kafka.""" + self._producer.produce( + topic=self._topic, + key=telemetry.drone_id.encode("utf-8"), + value=telemetry.to_json().encode("utf-8"), + callback=self._delivery_callback, + ) + self._producer.poll(0) + + def flush(self, timeout: float = 10.0) -> int: + """Flush outstanding messages. Returns remaining message count.""" + return self._producer.flush(timeout) + + def simulate( + self, + drone_ids: List[str], + interval_seconds: float = 1.0, + max_messages: Optional[int] = None, + ) -> None: + """ + Continuously publish simulated telemetry for *drone_ids*. + + Args: + drone_ids: List of drone identifier strings. + interval_seconds: Sleep duration between publishing rounds. + max_messages: Stop after this many total messages (``None`` = run forever). 
+ """ + state: dict[str, DroneTelemetry] = {} + count = 0 + try: + while max_messages is None or count < max_messages: + for drone_id in drone_ids: + telemetry = _simulate_telemetry(drone_id, state.get(drone_id)) + state[drone_id] = telemetry + self.send(telemetry) + count += 1 + logger.info("Published telemetry for %s (battery=%.1f%%)", drone_id, telemetry.battery_level) + time.sleep(interval_seconds) + except KeyboardInterrupt: + logger.info("Simulation stopped by user.") + finally: + self.flush() + + +# --------------------------------------------------------------------------- +# Entrypoint +# --------------------------------------------------------------------------- + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO) + drone_ids = [f"DRONE-{i:03d}" for i in range(1, 6)] + ensure_topics([kafka_config.TELEMETRY_TOPIC, kafka_config.PROCESSED_TOPIC, kafka_config.ALERTS_TOPIC]) + producer = DroneProducer() + producer.simulate(drone_ids, interval_seconds=1.0) diff --git a/src/models/__init__.py b/src/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/models/drone_telemetry.py b/src/models/drone_telemetry.py new file mode 100644 index 0000000..bb522ce --- /dev/null +++ b/src/models/drone_telemetry.py @@ -0,0 +1,52 @@ +"""Drone telemetry data model.""" + +import json +import time +from dataclasses import asdict, dataclass, field +from typing import Optional + + +@dataclass +class DroneTelemetry: + """Represents a single telemetry snapshot from a drone.""" + + drone_id: str + latitude: float + longitude: float + altitude: float # metres above sea level + speed: float # m/s + heading: float # degrees (0–360) + battery_level: float # percentage 0–100 + status: str # 'en_route', 'hovering', 'returning', 'landing', 'emergency' + timestamp: float = field(default_factory=time.time) + destination_lat: Optional[float] = None + destination_lon: Optional[float] = None + payload_weight: float = 0.0 # kg + + # 
------------------------------------------------------------------ + # Serialisation helpers + # ------------------------------------------------------------------ + + def to_dict(self) -> dict: + return asdict(self) + + def to_json(self) -> str: + return json.dumps(self.to_dict()) + + @classmethod + def from_dict(cls, data: dict) -> "DroneTelemetry": + return cls(**{k: v for k, v in data.items() if k in cls.__dataclass_fields__}) + + @classmethod + def from_json(cls, json_str: str) -> "DroneTelemetry": + return cls.from_dict(json.loads(json_str)) + + # ------------------------------------------------------------------ + # Convenience + # ------------------------------------------------------------------ + + def is_low_battery(self, threshold: float = 20.0) -> bool: + return self.battery_level < threshold + + def is_emergency(self) -> bool: + return self.status == "emergency" diff --git a/src/spark/__init__.py b/src/spark/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/spark/collision_detector.py b/src/spark/collision_detector.py new file mode 100644 index 0000000..f5a7201 --- /dev/null +++ b/src/spark/collision_detector.py @@ -0,0 +1,167 @@ +""" +Collision detector for drone swarms using Spark. + +For every micro-batch of telemetry, performs a self-join on the +DataFrame to find pairs of drones whose horizontal (and vertical) +separation falls below configurable safety thresholds. 
+""" + +import logging +import math + +from pyspark.sql import DataFrame +from pyspark.sql import functions as F +from pyspark.sql.types import BooleanType, FloatType, StringType, StructField, StructType + +from config import kafka_config, spark_config + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# UDF helpers +# --------------------------------------------------------------------------- + +def _haversine(lat1: float, lon1: float, lat2: float, lon2: float) -> float: + """Return horizontal distance (metres) between two coordinate pairs.""" + if None in (lat1, lon1, lat2, lon2): + return float("inf") + R = 6_371_000.0 + phi1, phi2 = math.radians(lat1), math.radians(lat2) + dphi = math.radians(lat2 - lat1) + dlambda = math.radians(lon2 - lon1) + a = math.sin(dphi / 2) ** 2 + math.cos(phi1) * math.cos(phi2) * math.sin(dlambda / 2) ** 2 + return 2 * R * math.asin(math.sqrt(a)) + + +_haversine_udf = F.udf(_haversine, FloatType()) + +# Alert schema (used when writing to Kafka) +ALERT_SCHEMA = StructType( + [ + StructField("drone_a", StringType(), False), + StructField("drone_b", StringType(), False), + StructField("horizontal_distance_m", FloatType(), True), + StructField("vertical_distance_m", FloatType(), True), + StructField("is_collision_risk", BooleanType(), False), + StructField("alert_level", StringType(), False), + ] +) + + +# --------------------------------------------------------------------------- +# CollisionDetector +# --------------------------------------------------------------------------- + +class CollisionDetector: + """ + Detects potential collisions between drones in a telemetry DataFrame. + + Two drones are flagged when: + - Horizontal separation < ``safe_distance_m`` + - Vertical separation < ``safe_altitude_m`` + + Alert levels: + - ``WARNING``: one threshold breached. + - ``CRITICAL``: both thresholds breached simultaneously. 
+ """ + + def __init__( + self, + safe_distance_m: float = spark_config.COLLISION_SAFE_DISTANCE_M, + safe_altitude_m: float = spark_config.COLLISION_SAFE_ALTITUDE_M, + ) -> None: + self._safe_distance_m = safe_distance_m + self._safe_altitude_m = safe_altitude_m + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + + def detect(self, df: DataFrame) -> DataFrame: + """ + Return a DataFrame of collision-risk pairs from a *batch* DataFrame. + + Each row describes a pair (drone_a, drone_b) with separation metrics. + Only pairs where drone_a < drone_b (lexicographic) are emitted to + avoid duplicates. + """ + a = df.alias("a") + b = df.alias("b") + + pairs = a.join(b, F.col("a.drone_id") < F.col("b.drone_id")) + + pairs = pairs.withColumn( + "horizontal_distance_m", + _haversine_udf( + F.col("a.latitude"), + F.col("a.longitude"), + F.col("b.latitude"), + F.col("b.longitude"), + ), + ).withColumn( + "vertical_distance_m", + F.abs(F.col("a.altitude") - F.col("b.altitude")).cast(FloatType()), + ) + + too_close_h = F.col("horizontal_distance_m") < self._safe_distance_m + too_close_v = F.col("vertical_distance_m") < self._safe_altitude_m + + pairs = pairs.withColumn( + "is_collision_risk", + too_close_h & too_close_v, + ).withColumn( + "alert_level", + F.when(too_close_h & too_close_v, F.lit("CRITICAL")) + .when(too_close_h | too_close_v, F.lit("WARNING")) + .otherwise(F.lit("OK")), + ) + + # Only raise an alert when drones are horizontally close. + # Vertical proximity alone does not indicate a collision risk. 
+ alerts = pairs.filter(too_close_h).select( + F.col("a.drone_id").alias("drone_a"), + F.col("b.drone_id").alias("drone_b"), + "horizontal_distance_m", + "vertical_distance_m", + "is_collision_risk", + "alert_level", + ) + return alerts + + def detect_stream( + self, + df: DataFrame, + checkpoint_location: str = spark_config.STREAMING_CHECKPOINT_LOCATION + "/alerts", + kafka_topic: str = kafka_config.ALERTS_TOPIC, + ): + """ + Apply collision detection to a *streaming* DataFrame using foreachBatch. + + Detected alerts are published to the Kafka alerts topic. + """ + + def _process_batch(batch_df: DataFrame, _epoch_id: int) -> None: + if batch_df.rdd.isEmpty(): + return + alerts = self.detect(batch_df) + if alerts.rdd.isEmpty(): + return + alert_count = alerts.count() + logger.warning("Epoch %d: %d collision alert(s) detected.", _epoch_id, alert_count) + + # Write to Kafka alerts topic + alerts.select( + F.col("drone_a").alias("key"), + F.to_json(F.struct("*")).alias("value"), + ).write.format("kafka").option( + "kafka.bootstrap.servers", kafka_config.KAFKA_BOOTSTRAP_SERVERS + ).option( + "topic", kafka_topic + ).save() + + return ( + df.writeStream.foreachBatch(_process_batch) + .option("checkpointLocation", checkpoint_location) + .outputMode("update") + .start() + ) diff --git a/src/spark/route_optimizer.py b/src/spark/route_optimizer.py new file mode 100644 index 0000000..e957cc4 --- /dev/null +++ b/src/spark/route_optimizer.py @@ -0,0 +1,158 @@ +""" +Spark-based route optimiser for drone delivery. + +Analyses a streaming (or batch) DataFrame of drone telemetry and +recommends route adjustments based on battery level, payload weight, +distance to destination, and current heading. 
+""" + +import logging +import math + +from pyspark.sql import DataFrame +from pyspark.sql import functions as F +from pyspark.sql.types import FloatType, StringType + +from config import spark_config + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Pure-Python helper (used as a UDF) +# --------------------------------------------------------------------------- + +def _haversine_distance(lat1: float, lon1: float, lat2: float, lon2: float) -> float: + """Return great-circle distance (metres) between two WGS-84 coordinates.""" + if None in (lat1, lon1, lat2, lon2): + return -1.0 + R = 6_371_000.0 # Earth radius in metres + phi1, phi2 = math.radians(lat1), math.radians(lat2) + dphi = math.radians(lat2 - lat1) + dlambda = math.radians(lon2 - lon1) + a = math.sin(dphi / 2) ** 2 + math.cos(phi1) * math.cos(phi2) * math.sin(dlambda / 2) ** 2 + return 2 * R * math.asin(math.sqrt(a)) + + +def _recommend_action( + battery_level: float, + distance_to_dest: float, + status: str, + payload_weight: float, +) -> str: + """ + Derive a recommended action string from current drone state. + + Returns one of: 'continue', 'return_to_base', 'emergency_land', + 'reduce_speed', 'optimal'. 
+ """ + if battery_level is None or status is None: + return "unknown" + if battery_level < 5.0 or status == "emergency": + return "emergency_land" + if battery_level < spark_config.LOW_BATTERY_THRESHOLD: + # Check whether we can still reach the destination + if distance_to_dest > 0 and battery_level / 100.0 * 8000 < distance_to_dest: + return "return_to_base" + return "return_to_base" + if payload_weight is not None and payload_weight > spark_config.MAX_PAYLOAD_WEIGHT_KG: + return "reduce_speed" + if distance_to_dest > 0: + return "continue" + return "optimal" + + +# --------------------------------------------------------------------------- +# UDF registrations +# --------------------------------------------------------------------------- + +_haversine_udf = F.udf(_haversine_distance, FloatType()) +_action_udf = F.udf(_recommend_action, StringType()) + + +# --------------------------------------------------------------------------- +# RouteOptimizer +# --------------------------------------------------------------------------- + +class RouteOptimizer: + """ + Adds route-optimisation columns to a telemetry DataFrame. + + Columns added: + - ``distance_to_dest_m``: metres remaining to destination. + - ``estimated_flight_time_s``: seconds to destination at current speed. + - ``recommended_action``: suggested action string. + - ``optimisation_score``: 0–100 efficiency score (higher is better). 
+ """ + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + + def optimise(self, df: DataFrame) -> DataFrame: + """Return *df* enriched with route-optimisation columns.""" + df = self._add_distance(df) + df = self._add_flight_time(df) + df = self._add_recommended_action(df) + df = self._add_score(df) + return df + + def optimise_stream(self, df: DataFrame) -> DataFrame: + """Same as :meth:`optimise` but intended for streaming DataFrames.""" + return self.optimise(df) + + # ------------------------------------------------------------------ + # Private helpers + # ------------------------------------------------------------------ + + @staticmethod + def _add_distance(df: DataFrame) -> DataFrame: + return df.withColumn( + "distance_to_dest_m", + _haversine_udf( + F.col("latitude"), + F.col("longitude"), + F.col("destination_lat"), + F.col("destination_lon"), + ), + ) + + @staticmethod + def _add_flight_time(df: DataFrame) -> DataFrame: + return df.withColumn( + "estimated_flight_time_s", + F.when( + (F.col("speed") > 0) & (F.col("distance_to_dest_m") >= 0), + F.col("distance_to_dest_m") / F.col("speed"), + ).otherwise(F.lit(-1.0).cast(FloatType())), + ) + + @staticmethod + def _add_recommended_action(df: DataFrame) -> DataFrame: + return df.withColumn( + "recommended_action", + _action_udf( + F.col("battery_level"), + F.col("distance_to_dest_m"), + F.col("status"), + F.col("payload_weight"), + ), + ) + + @staticmethod + def _add_score(df: DataFrame) -> DataFrame: + """ + Compute a simple 0–100 efficiency score: + score = 0.5 * battery_level + 0.3 * speed_factor + 0.2 * payload_factor + where speed_factor and payload_factor are normalised 0–1 values. 
+ """ + max_speed = 20.0 # m/s + speed_factor = F.least(F.col("speed") / max_speed, F.lit(1.0)) + payload_factor = F.lit(1.0) - ( + F.col("payload_weight") / spark_config.MAX_PAYLOAD_WEIGHT_KG + ).cast(FloatType()) + score = ( + F.lit(0.5) * F.col("battery_level") + + F.lit(0.3) * speed_factor * F.lit(100.0) + + F.lit(0.2) * F.greatest(payload_factor, F.lit(0.0)) * F.lit(100.0) + ) + return df.withColumn("optimisation_score", score.cast(FloatType())) diff --git a/src/spark/telemetry_processor.py b/src/spark/telemetry_processor.py new file mode 100644 index 0000000..95e8353 --- /dev/null +++ b/src/spark/telemetry_processor.py @@ -0,0 +1,170 @@ +""" +Spark Structured Streaming processor for drone telemetry. + +Reads raw JSON messages from Kafka, applies a schema, and writes the +structured stream to S3 (Parquet) and a Kafka processed topic. +""" + +import logging +from typing import Optional + +from pyspark.sql import DataFrame, SparkSession +from pyspark.sql import functions as F +from pyspark.sql.types import ( + DoubleType, + FloatType, + StringType, + StructField, + StructType, + TimestampType, +) + +from config import kafka_config, spark_config + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Schema +# --------------------------------------------------------------------------- + +TELEMETRY_SCHEMA = StructType( + [ + StructField("drone_id", StringType(), False), + StructField("latitude", DoubleType(), False), + StructField("longitude", DoubleType(), False), + StructField("altitude", FloatType(), False), + StructField("speed", FloatType(), False), + StructField("heading", FloatType(), False), + StructField("battery_level", FloatType(), False), + StructField("status", StringType(), False), + StructField("timestamp", DoubleType(), True), + StructField("destination_lat", DoubleType(), True), + StructField("destination_lon", DoubleType(), True), + StructField("payload_weight", FloatType(), 
True), + ] +) + + +# --------------------------------------------------------------------------- +# Session factory +# --------------------------------------------------------------------------- + +def create_spark_session(app_name: Optional[str] = None) -> SparkSession: + """Build and return a :class:`SparkSession` configured for the pipeline.""" + name = app_name or spark_config.SPARK_APP_NAME + builder = ( + SparkSession.builder.appName(name) + .master(spark_config.SPARK_MASTER) + .config("spark.jars.packages", spark_config.SPARK_PACKAGES) + # Kafka source + .config("spark.sql.streaming.schemaInference", "true") + # S3a settings populated from environment + .config( + "spark.hadoop.fs.s3a.aws.credentials.provider", + "com.amazonaws.auth.EnvironmentVariableCredentialsProvider", + ) + ) + return builder.getOrCreate() + + +# --------------------------------------------------------------------------- +# Core processor +# --------------------------------------------------------------------------- + +class TelemetryProcessor: + """ + Reads drone telemetry from Kafka, parses JSON, and writes to S3. 
+
+    Typical usage::
+
+        spark = create_spark_session()
+        processor = TelemetryProcessor(spark)
+        query = processor.start()
+        query.awaitTermination()
+    """
+
+    def __init__(self, spark: SparkSession) -> None:
+        self._spark = spark
+
+    # ------------------------------------------------------------------
+    # Public API
+    # ------------------------------------------------------------------
+
+    def read_stream(self) -> DataFrame:
+        """Return a streaming DataFrame of parsed telemetry records."""
+        raw = (
+            self._spark.readStream.format("kafka")
+            .option("kafka.bootstrap.servers", kafka_config.KAFKA_BOOTSTRAP_SERVERS)
+            .option("subscribe", kafka_config.TELEMETRY_TOPIC)
+            .option("startingOffsets", "earliest")
+            .option("failOnDataLoss", "false")
+            .load()
+        )
+
+        parsed = raw.select(
+            F.col("key").cast(StringType()).alias("message_key"),
+            F.from_json(F.col("value").cast(StringType()), TELEMETRY_SCHEMA).alias("data"),
+            F.col("timestamp").alias("kafka_timestamp"),
+        ).select(
+            "message_key",
+            "kafka_timestamp",
+            "data.*",
+        )
+
+        # Enrich with an event-time column parsed from the epoch-seconds "timestamp" field
+        return parsed.withColumn(
+            "event_time",
+            F.to_timestamp(F.col("timestamp").cast(DoubleType())),
+        )
+
+    def write_to_s3(
+        self,
+        df: DataFrame,
+        s3_path: str,
+        checkpoint_location: Optional[str] = None,
+        trigger_interval: Optional[str] = None,
+    ):
+        """Write streaming DataFrame to S3 in Parquet format."""
+        checkpoint = checkpoint_location or (
+            spark_config.STREAMING_CHECKPOINT_LOCATION + "/telemetry"
+        )
+        trigger = trigger_interval or spark_config.STREAMING_TRIGGER_INTERVAL
+        return (
+            df.writeStream.format("parquet")
+            .option("path", s3_path)
+            .option("checkpointLocation", checkpoint)
+            .trigger(processingTime=trigger)
+            .outputMode(spark_config.STREAMING_OUTPUT_MODE)
+            .start()
+        )
+
+    def write_to_kafka(
+        self,
+        df: DataFrame,
+        topic: str = kafka_config.PROCESSED_TOPIC,
+        checkpoint_location: Optional[str] = None,
+    ):
+        """Forward enriched records to a downstream Kafka topic."""
+        checkpoint = checkpoint_location or (
+            spark_config.STREAMING_CHECKPOINT_LOCATION + "/processed-kafka"
+        )
+        value_df = df.select(
+            F.col("drone_id").alias("key"),
+            F.to_json(F.struct("*")).alias("value"),
+        )
+        return (
+            value_df.writeStream.format("kafka")
+            .option("kafka.bootstrap.servers", kafka_config.KAFKA_BOOTSTRAP_SERVERS)
+            .option("topic", topic)
+            .option("checkpointLocation", checkpoint)
+            .outputMode(spark_config.STREAMING_OUTPUT_MODE)
+            .start()
+        )
+
+    def start(
+        self,
+        s3_path: str = "s3a://drone-telemetry-data/processed/",
+    ):
+        """Convenience method: start the full processing pipeline."""
+        stream_df = self.read_stream()
+        return self.write_to_s3(stream_df, s3_path)
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/test_collision_detector.py b/tests/test_collision_detector.py
new file mode 100644
index 0000000..ac32719
--- /dev/null
+++ b/tests/test_collision_detector.py
@@ -0,0 +1,88 @@
+"""Unit tests for the Spark CollisionDetector."""
+
+import pytest
+
+
+class TestCollisionDetectorWithSpark:
+    """Integration-style tests using a local SparkSession."""
+
+    @pytest.fixture(scope="class")
+    def spark(self):
+        from pyspark.sql import SparkSession
+        session = (
+            SparkSession.builder.master("local[1]")
+            .appName("test-collision-detector")
+            .config("spark.ui.enabled", "false")
+            .getOrCreate()
+        )
+        session.sparkContext.setLogLevel("ERROR")
+        yield session
+        session.stop()
+
+    def _make_df(self, spark, rows):
+        cols = ["drone_id", "latitude", "longitude", "altitude"]
+        return spark.createDataFrame(rows, cols)
+
+    def test_no_alerts_when_drones_far_apart(self, spark):
+        from src.spark.collision_detector import CollisionDetector
+        # Drones ~12 km apart (SF -> Oakland), far beyond the 50 m safe radius
+        rows = [
+            ("DRONE-A", 37.7749, -122.4194, 100.0),
+            ("DRONE-B", 37.8044, -122.2711, 100.0),
+        ]
+        df = self._make_df(spark, rows)
+        detector = CollisionDetector(safe_distance_m=50.0, safe_altitude_m=10.0)
+        alerts = detector.detect(df)
+        assert alerts.count() == 0
+
+    def test_critical_alert_when_very_close(self, spark):
+        from src.spark.collision_detector import CollisionDetector
+        # Drones at nearly identical position and altitude
+        rows = [
+            ("DRONE-A", 37.7749, -122.4194, 80.0),
+            ("DRONE-B", 37.7749, -122.4195, 80.5),  # ~8 m away, 0.5 m altitude diff
+        ]
+        df = self._make_df(spark, rows)
+        detector = CollisionDetector(safe_distance_m=50.0, safe_altitude_m=10.0)
+        alerts = detector.detect(df)
+        rows_collected = alerts.collect()
+        assert len(rows_collected) == 1
+        assert rows_collected[0]["alert_level"] == "CRITICAL"
+        assert rows_collected[0]["is_collision_risk"] is True
+
+    def test_no_duplicate_pairs(self, spark):
+        from src.spark.collision_detector import CollisionDetector
+        rows = [
+            ("DRONE-A", 37.7749, -122.4194, 80.0),
+            ("DRONE-B", 37.7749, -122.4195, 80.0),
+            ("DRONE-C", 37.7749, -122.4196, 80.0),
+        ]
+        df = self._make_df(spark, rows)
+        detector = CollisionDetector(safe_distance_m=50.0, safe_altitude_m=10.0)
+        alerts = detector.detect(df)
+        # Should have C(3,2)=3 pairs, all critical
+        assert alerts.count() == 3
+        for row in alerts.collect():
+            assert row["drone_a"] < row["drone_b"]  # no duplicates
+
+    def test_warning_alert_when_only_horizontal_close(self, spark):
+        from src.spark.collision_detector import CollisionDetector
+        rows = [
+            ("DRONE-A", 37.7749, -122.4194, 80.0),
+            ("DRONE-B", 37.7749, -122.4195, 200.0),  # close horizontally, far vertically
+        ]
+        df = self._make_df(spark, rows)
+        detector = CollisionDetector(safe_distance_m=50.0, safe_altitude_m=10.0)
+        alerts = detector.detect(df)
+        rows_collected = alerts.collect()
+        assert len(rows_collected) == 1
+        assert rows_collected[0]["alert_level"] == "WARNING"
+        assert rows_collected[0]["is_collision_risk"] is False
+
+    def test_single_drone_no_alerts(self, spark):
+        from src.spark.collision_detector import CollisionDetector
+        rows = [("DRONE-A", 37.7749, -122.4194, 80.0)]
+        df = self._make_df(spark, rows)
+        detector = CollisionDetector()
+        alerts = detector.detect(df)
+        assert alerts.count() == 0
diff --git a/tests/test_drone_consumer.py b/tests/test_drone_consumer.py
new file mode 100644
index 0000000..1ae2319
--- /dev/null
+++ b/tests/test_drone_consumer.py
@@ -0,0 +1,116 @@
+"""Unit tests for the Kafka DroneConsumer (mocked confluent_kafka)."""
+
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from src.models.drone_telemetry import DroneTelemetry
+
+
+# ---------------------------------------------------------------------------
+# Helpers
+# ---------------------------------------------------------------------------
+
+def _make_message(telemetry: DroneTelemetry, error=None):
+    msg = MagicMock()
+    msg.error.return_value = error
+    msg.value.return_value = telemetry.to_json().encode("utf-8")
+    msg.key.return_value = telemetry.drone_id.encode("utf-8")
+    return msg
+
+
+def _sentinel_message():
+    """Simulate an empty poll: returns None, as Consumer.poll() does on timeout."""
+    return None
+
+
+# ---------------------------------------------------------------------------
+# Fixtures
+# ---------------------------------------------------------------------------
+
+@pytest.fixture()
+def mock_consumer_cls():
+    with patch("src.kafka.drone_consumer.Consumer") as mock_cls:
+        yield mock_cls
+
+
+@pytest.fixture()
+def consumer(mock_consumer_cls):
+    from src.kafka.drone_consumer import DroneConsumer
+    return DroneConsumer(
+        topics=["test-topic"],
+        config={"bootstrap.servers": "localhost:9092", "group.id": "test-group", "auto.offset.reset": "earliest", "enable.auto.commit": False},
+    )
+
+
+# ---------------------------------------------------------------------------
+# Tests
+# ---------------------------------------------------------------------------
+
+class TestDroneConsumer:
+    def test_subscribe_called_on_consume(self, consumer, mock_consumer_cls):
+        telemetry = DroneTelemetry(
+            drone_id="DRONE-001", latitude=37.77, longitude=-122.42,
+            altitude=80.0, speed=10.0, heading=90.0, battery_level=75.0, status="en_route",
+        )
+        mock_instance = mock_consumer_cls.return_value
+        mock_instance.poll.side_effect = [_make_message(telemetry), None, None]
+
+        received = []
+        consumer.consume(handler=received.append, max_messages=1)
+
+        mock_instance.subscribe.assert_called_once_with(["test-topic"])
+
+    def test_handler_called_with_deserialized_telemetry(self, consumer, mock_consumer_cls):
+        telemetry = DroneTelemetry(
+            drone_id="DRONE-002", latitude=37.77, longitude=-122.42,
+            altitude=80.0, speed=10.0, heading=90.0, battery_level=60.0, status="hovering",
+        )
+        mock_instance = mock_consumer_cls.return_value
+        mock_instance.poll.side_effect = [_make_message(telemetry), None]
+
+        received = []
+        consumer.consume(handler=received.append, max_messages=1)
+
+        assert len(received) == 1
+        assert received[0].drone_id == "DRONE-002"
+        assert received[0].battery_level == 60.0
+
+    def test_malformed_message_is_skipped(self, consumer, mock_consumer_cls):
+        bad_msg = MagicMock()
+        bad_msg.error.return_value = None
+        bad_msg.value.return_value = b"not valid json"
+
+        good_telemetry = DroneTelemetry(
+            drone_id="DRONE-003", latitude=37.77, longitude=-122.42,
+            altitude=80.0, speed=10.0, heading=90.0, battery_level=50.0, status="en_route",
+        )
+        mock_instance = mock_consumer_cls.return_value
+        mock_instance.poll.side_effect = [bad_msg, _make_message(good_telemetry)]
+
+        received = []
+        consumer.consume(handler=received.append, max_messages=1)
+
+        assert len(received) == 1
+        assert received[0].drone_id == "DRONE-003"
+
+    def test_commit_called_after_successful_message(self, consumer, mock_consumer_cls):
+        telemetry = DroneTelemetry(
+            drone_id="DRONE-004", latitude=37.77, longitude=-122.42,
+            altitude=80.0, speed=10.0, heading=90.0, battery_level=80.0, status="en_route",
+        )
+        mock_instance = mock_consumer_cls.return_value
+        msg = _make_message(telemetry)
+        mock_instance.poll.side_effect = [msg]
+
+        consumer.consume(handler=lambda t: None, max_messages=1)
+
+        mock_instance.commit.assert_called_once_with(message=msg, asynchronous=False)
+
+    def test_close_called_on_exit(self, consumer, mock_consumer_cls):
+        mock_instance = mock_consumer_cls.return_value
+        mock_instance.poll.return_value = None
+
+        consumer.consume(handler=lambda t: None, max_messages=0)
+
+        mock_instance.close.assert_called_once()
diff --git a/tests/test_drone_producer.py b/tests/test_drone_producer.py
new file mode 100644
index 0000000..fa44805
--- /dev/null
+++ b/tests/test_drone_producer.py
@@ -0,0 +1,95 @@
+"""Unit tests for the Kafka DroneProducer (mocked confluent_kafka)."""
+
+import json
+from unittest.mock import MagicMock, call, patch
+
+import pytest
+
+from src.models.drone_telemetry import DroneTelemetry
+
+
+# ---------------------------------------------------------------------------
+# Fixtures
+# ---------------------------------------------------------------------------
+
+@pytest.fixture()
+def mock_producer_cls():
+    with patch("src.kafka.drone_producer.Producer") as mock_cls:
+        yield mock_cls
+
+
+@pytest.fixture()
+def mock_admin_client():
+    with patch("src.kafka.drone_producer.AdminClient") as mock_cls:
+        topics_mock = MagicMock()
+        topics_mock.topics = {}
+        mock_cls.return_value.list_topics.return_value = topics_mock
+        mock_cls.return_value.create_topics.return_value = {}
+        yield mock_cls
+
+
+@pytest.fixture()
+def producer(mock_producer_cls):
+    from src.kafka.drone_producer import DroneProducer
+    return DroneProducer(config={"bootstrap.servers": "localhost:9092"}, topic="test-topic")
+
+
+# ---------------------------------------------------------------------------
+# Tests
+# ---------------------------------------------------------------------------
+
+class TestDroneProducer:
+    def test_send_calls_produce(self, producer, mock_producer_cls):
+        telemetry = DroneTelemetry(
+            drone_id="DRONE-001",
+            latitude=37.77,
+            longitude=-122.42,
+            altitude=80.0,
+            speed=10.0,
+            heading=90.0,
+            battery_level=75.0,
+            status="en_route",
+        )
+        producer.send(telemetry)
+        mock_instance = mock_producer_cls.return_value
+        mock_instance.produce.assert_called_once()
+        call_kwargs = mock_instance.produce.call_args.kwargs
+        assert call_kwargs["topic"] == "test-topic"
+        assert call_kwargs["key"] == b"DRONE-001"
+        payload = json.loads(call_kwargs["value"].decode("utf-8"))
+        assert payload["drone_id"] == "DRONE-001"
+        assert payload["battery_level"] == 75.0
+
+    def test_flush_delegates_to_confluent(self, producer, mock_producer_cls):
+        mock_producer_cls.return_value.flush.return_value = 0
+        result = producer.flush(timeout=5.0)
+        mock_producer_cls.return_value.flush.assert_called_once_with(5.0)
+        assert result == 0
+
+    def test_simulate_publishes_correct_number_of_messages(self, producer, mock_producer_cls):
+        drone_ids = ["DRONE-001", "DRONE-002"]
+        max_messages = 4  # 2 drones × 2 rounds
+        producer.simulate(drone_ids=drone_ids, interval_seconds=0, max_messages=max_messages)
+        mock_instance = mock_producer_cls.return_value
+        assert mock_instance.produce.call_count == max_messages
+
+    def test_simulate_uses_all_drone_ids(self, producer, mock_producer_cls):
+        drone_ids = ["DRONE-A", "DRONE-B", "DRONE-C"]
+        producer.simulate(drone_ids=drone_ids, interval_seconds=0, max_messages=3)
+        produced_keys = {
+            c.kwargs["key"] for c in mock_producer_cls.return_value.produce.call_args_list
+        }
+        assert produced_keys == {b"DRONE-A", b"DRONE-B", b"DRONE-C"}
+
+
+class TestEnsureTopics:
+    def test_creates_missing_topics(self, mock_admin_client):
+        from src.kafka.drone_producer import ensure_topics
+        ensure_topics(["new-topic"])
+        mock_admin_client.return_value.create_topics.assert_called_once()
+
+    def test_skips_existing_topics(self, mock_admin_client):
+        mock_admin_client.return_value.list_topics.return_value.topics = {"existing": MagicMock()}
+        from src.kafka.drone_producer import ensure_topics
+        ensure_topics(["existing"])
+        mock_admin_client.return_value.create_topics.assert_not_called()
diff --git a/tests/test_drone_telemetry.py b/tests/test_drone_telemetry.py
new file mode 100644
index 0000000..36017f5
--- /dev/null
+++ b/tests/test_drone_telemetry.py
@@ -0,0 +1,84 @@
+"""Unit tests for DroneTelemetry data model."""
+
+import json
+import time
+
+import pytest
+
+from src.models.drone_telemetry import DroneTelemetry
+
+
+@pytest.fixture()
+def sample_telemetry() -> DroneTelemetry:
+    return DroneTelemetry(
+        drone_id="DRONE-001",
+        latitude=37.7749,
+        longitude=-122.4194,
+        altitude=80.0,
+        speed=10.0,
+        heading=90.0,
+        battery_level=75.0,
+        status="en_route",
+        destination_lat=37.7800,
+        destination_lon=-122.4100,
+        payload_weight=1.5,
+    )
+
+
+class TestDroneTelemetrySerialization:
+    def test_to_dict_contains_all_fields(self, sample_telemetry):
+        d = sample_telemetry.to_dict()
+        assert d["drone_id"] == "DRONE-001"
+        assert d["latitude"] == 37.7749
+        assert d["battery_level"] == 75.0
+        assert "timestamp" in d
+
+    def test_to_json_is_valid_json(self, sample_telemetry):
+        raw = sample_telemetry.to_json()
+        parsed = json.loads(raw)
+        assert parsed["drone_id"] == "DRONE-001"
+
+    def test_from_dict_roundtrip(self, sample_telemetry):
+        d = sample_telemetry.to_dict()
+        restored = DroneTelemetry.from_dict(d)
+        assert restored.drone_id == sample_telemetry.drone_id
+        assert restored.latitude == sample_telemetry.latitude
+        assert restored.status == sample_telemetry.status
+
+    def test_from_json_roundtrip(self, sample_telemetry):
+        restored = DroneTelemetry.from_json(sample_telemetry.to_json())
+        assert restored.drone_id == sample_telemetry.drone_id
+        assert restored.heading == sample_telemetry.heading
+
+
+class TestDroneTelemetryHelpers:
+    def test_low_battery_false_when_above_threshold(self, sample_telemetry):
+        assert sample_telemetry.is_low_battery(threshold=20.0) is False
+
+    def test_low_battery_true_when_below_threshold(self):
+        t = DroneTelemetry(
+            drone_id="X", latitude=0, longitude=0, altitude=0,
+            speed=0, heading=0, battery_level=15.0, status="returning",
+        )
+        assert t.is_low_battery(threshold=20.0) is True
+
+    def test_is_emergency_true(self):
+        t = DroneTelemetry(
+            drone_id="X", latitude=0, longitude=0, altitude=0,
+            speed=0, heading=0, battery_level=3.0, status="emergency",
+        )
+        assert t.is_emergency() is True
+
+    def test_is_emergency_false(self, sample_telemetry):
+        assert sample_telemetry.is_emergency() is False
+
+    def test_default_timestamp_is_recent(self, sample_telemetry):
+        assert sample_telemetry.timestamp == pytest.approx(time.time(), abs=5.0)
+
+    def test_optional_destination_none_by_default(self):
+        t = DroneTelemetry(
+            drone_id="X", latitude=0, longitude=0, altitude=0,
+            speed=0, heading=0, battery_level=50.0, status="hovering",
+        )
+        assert t.destination_lat is None
+        assert t.destination_lon is None
diff --git a/tests/test_redshift_handler.py b/tests/test_redshift_handler.py
new file mode 100644
index 0000000..8517e08
--- /dev/null
+++ b/tests/test_redshift_handler.py
@@ -0,0 +1,116 @@
+"""Unit tests for RedshiftHandler (mocked psycopg2)."""
+
+from unittest.mock import MagicMock, call, patch
+
+import pytest
+
+from src.aws.redshift_handler import RedshiftHandler
+
+
+@pytest.fixture()
+def mock_psycopg2():
+    with patch("src.aws.redshift_handler.psycopg2") as mock_pg:
+        mock_conn = MagicMock()
+        mock_conn.closed = False
+        mock_pg.connect.return_value = mock_conn
+        yield mock_pg, mock_conn
+
+
+@pytest.fixture()
+def handler(mock_psycopg2):
+    _, _ = mock_psycopg2
+    return RedshiftHandler(
+        host="localhost",
+        port=5439,
+        database="drone_analytics",
+        user="admin",
+        password="secret",
+    )
+
+
+class TestRedshiftHandlerConnection:
+    def test_connect_calls_psycopg2_connect(self, handler, mock_psycopg2):
+        mock_pg, mock_conn = mock_psycopg2
+        handler.connect()
+        mock_pg.connect.assert_called_once()
+        call_kwargs = mock_pg.connect.call_args.kwargs
+        assert call_kwargs["host"] == "localhost"
+        assert call_kwargs["dbname"] == "drone_analytics"
+
+    def test_disconnect_closes_connection(self, handler, mock_psycopg2):
+        mock_pg, mock_conn = mock_psycopg2
+        handler.connect()
+        handler.disconnect()
+        mock_conn.close.assert_called_once()
+
+
+class TestRedshiftHandlerInserts:
+    def _setup_cursor(self, mock_psycopg2):
+        mock_pg, mock_conn = mock_psycopg2
+        mock_cursor = MagicMock()
+        mock_conn.cursor.return_value = mock_cursor
+        return mock_cursor
+
+    def test_insert_telemetry_batch_returns_count(self, handler, mock_psycopg2):
+        with patch("src.aws.redshift_handler.execute_values") as mock_ev:
+            self._setup_cursor(mock_psycopg2)
+            records = [
+                {
+                    "drone_id": "DRONE-001",
+                    "latitude": 37.77,
+                    "longitude": -122.42,
+                    "altitude": 80.0,
+                    "speed": 10.0,
+                    "heading": 90.0,
+                    "battery_level": 75.0,
+                    "status": "en_route",
+                    "event_time": None,
+                    "payload_weight": 1.5,
+                }
+            ]
+            count = handler.insert_telemetry_batch(records)
+            assert count == 1
+            mock_ev.assert_called_once()
+
+    def test_insert_telemetry_batch_empty_returns_zero(self, handler, mock_psycopg2):
+        count = handler.insert_telemetry_batch([])
+        assert count == 0
+
+    def test_insert_alerts_batch_returns_count(self, handler, mock_psycopg2):
+        with patch("src.aws.redshift_handler.execute_values") as mock_ev:
+            self._setup_cursor(mock_psycopg2)
+            alerts = [
+                {
+                    "drone_a": "DRONE-001",
+                    "drone_b": "DRONE-002",
+                    "horizontal_distance_m": 30.0,
+                    "vertical_distance_m": 5.0,
+                    "is_collision_risk": True,
+                    "alert_level": "CRITICAL",
+                }
+            ]
+            count = handler.insert_alerts_batch(alerts)
+            assert count == 1
+
+    def test_insert_routes_batch_empty_returns_zero(self, handler, mock_psycopg2):
+        count = handler.insert_routes_batch([])
+        assert count == 0
+
+    def test_initialise_schema_executes_ddl(self, handler, mock_psycopg2):
+        mock_cursor = self._setup_cursor(mock_psycopg2)
+        handler.initialise_schema()
+        # Three CREATE TABLE statements
+        assert mock_cursor.execute.call_count == 3
+
+
+class TestRedshiftHandlerQueries:
+    def test_get_low_battery_drones(self, handler, mock_psycopg2):
+        mock_pg, mock_conn = mock_psycopg2
+        mock_cursor = MagicMock()
+        mock_cursor.description = [("drone_id",), ("battery_level",), ("status",), ("event_time",)]
+        mock_cursor.fetchall.return_value = [("DRONE-005", 12.0, "returning", None)]
+        mock_conn.cursor.return_value = mock_cursor
+        results = handler.get_low_battery_drones(threshold=20.0)
+        assert len(results) == 1
+        assert results[0]["drone_id"] == "DRONE-005"
+        assert results[0]["battery_level"] == 12.0
diff --git a/tests/test_route_optimizer.py b/tests/test_route_optimizer.py
new file mode 100644
index 0000000..5e9fbdf
--- /dev/null
+++ b/tests/test_route_optimizer.py
@@ -0,0 +1,110 @@
+"""Unit tests for the Spark route optimiser (no live Spark cluster required)."""
+
+import math
+
+import pytest
+
+from src.spark.route_optimizer import RouteOptimizer, _haversine_distance, _recommend_action
+
+
+class TestHaversineDistance:
+    def test_same_point_is_zero(self):
+        assert _haversine_distance(37.77, -122.42, 37.77, -122.42) == pytest.approx(0.0, abs=1e-3)
+
+    def test_known_distance(self):
+        # San Francisco -> Oakland (~12 km across the bay)
+        dist = _haversine_distance(37.7749, -122.4194, 37.8044, -122.2711)
+        assert 11_000 < dist < 14_000
+
+    def test_none_coordinates_returns_negative(self):
+        assert _haversine_distance(None, -122.42, 37.77, -122.42) == -1.0
+
+    def test_symmetry(self):
+        d1 = _haversine_distance(37.77, -122.42, 37.80, -122.40)
+        d2 = _haversine_distance(37.80, -122.40, 37.77, -122.42)
+        assert d1 == pytest.approx(d2, rel=1e-6)
+
+
+class TestRecommendAction:
+    def test_emergency_land_on_critical_battery(self):
+        assert _recommend_action(3.0, 1000.0, "en_route", 1.0) == "emergency_land"
+
+    def test_emergency_land_on_emergency_status(self):
+        assert _recommend_action(50.0, 1000.0, "emergency", 1.0) == "emergency_land"
+
+    def test_return_to_base_on_low_battery(self):
+        assert _recommend_action(15.0, 5000.0, "en_route", 1.0) == "return_to_base"
+
+    def test_reduce_speed_on_heavy_payload(self):
+        # 10 kg > MAX_PAYLOAD_WEIGHT_KG (5 kg default)
+        result = _recommend_action(80.0, 500.0, "en_route", 10.0)
+        assert result == "reduce_speed"
+
+    def test_continue_when_normal(self):
+        assert _recommend_action(80.0, 500.0, "en_route", 1.0) == "continue"
+
+    def test_optimal_when_distance_zero(self):
+        assert _recommend_action(80.0, 0.0, "hovering", 0.0) == "optimal"
+
+    def test_unknown_on_none_inputs(self):
+        assert _recommend_action(None, 0.0, None, 0.0) == "unknown"
+
+
+class TestRouteOptimizerWithSpark:
+    """Integration-style tests using a local SparkSession."""
+
+    @pytest.fixture(scope="class")
+    def spark(self):
+        from pyspark.sql import SparkSession
+        session = (
+            SparkSession.builder.master("local[1]")
+            .appName("test-route-optimizer")
+            .config("spark.ui.enabled", "false")
+            .getOrCreate()
+        )
+        session.sparkContext.setLogLevel("ERROR")
+        yield session
+        session.stop()
+
+    @pytest.fixture()
+    def sample_df(self, spark):
+        data = [
+            ("DRONE-001", 37.7749, -122.4194, 80.0, 10.0, 90.0, 75.0, "en_route", 1.5, 37.780, -122.410),
+            ("DRONE-002", 37.7800, -122.4100, 60.0, 5.0, 180.0, 15.0, "returning", 1.0, 37.790, -122.420),
+            ("DRONE-003", 37.7760, -122.4180, 70.0, 0.0, 0.0, 4.0, "emergency", 0.0, 37.780, -122.415),
+        ]
+        cols = ["drone_id", "latitude", "longitude", "altitude", "speed", "heading",
+                "battery_level", "status", "payload_weight", "destination_lat", "destination_lon"]
+        return spark.createDataFrame(data, cols)
+
+    def test_optimise_adds_required_columns(self, sample_df):
+        optimizer = RouteOptimizer()
+        result = optimizer.optimise(sample_df)
+        expected_cols = {"distance_to_dest_m", "estimated_flight_time_s",
+                         "recommended_action", "optimisation_score"}
+        assert expected_cols.issubset(set(result.columns))
+
+    def test_distance_positive_for_valid_destinations(self, sample_df):
+        optimizer = RouteOptimizer()
+        result = optimizer.optimise(sample_df)
+        rows = {r["drone_id"]: r["distance_to_dest_m"] for r in result.collect()}
+        assert rows["DRONE-001"] > 0
+        assert rows["DRONE-002"] > 0
+
+    def test_emergency_drone_gets_emergency_land_action(self, sample_df):
+        optimizer = RouteOptimizer()
+        result = optimizer.optimise(sample_df)
+        rows = {r["drone_id"]: r["recommended_action"] for r in result.collect()}
+        assert rows["DRONE-003"] == "emergency_land"
+
+    def test_low_battery_drone_gets_return_action(self, sample_df):
+        optimizer = RouteOptimizer()
+        result = optimizer.optimise(sample_df)
+        rows = {r["drone_id"]: r["recommended_action"] for r in result.collect()}
+        assert rows["DRONE-002"] == "return_to_base"
+
+    def test_optimisation_score_in_range(self, sample_df):
+        optimizer = RouteOptimizer()
+        result = optimizer.optimise(sample_df)
+        for row in result.collect():
+            assert 0.0 <= row["optimisation_score"] <= 100.0
diff --git a/tests/test_s3_handler.py b/tests/test_s3_handler.py
new file mode 100644
index 0000000..f4e62f6
--- /dev/null
+++ b/tests/test_s3_handler.py
@@ -0,0 +1,82 @@
+"""Unit tests for S3Handler (mocked boto3)."""
+
+import json
+from unittest.mock import MagicMock, call, patch
+
+import pytest
+
+from src.aws.s3_handler import S3Handler
+
+
+@pytest.fixture()
+def mock_boto3_client():
+    with patch("src.aws.s3_handler.boto3") as mock_boto3:
+        yield mock_boto3
+
+
+@pytest.fixture()
+def handler(mock_boto3_client):
+    return S3Handler(bucket="test-bucket", region="us-east-1")
+
+
+class TestS3HandlerUpload:
+    def test_upload_json_records_calls_put_object(self, handler, mock_boto3_client):
+        records = [{"drone_id": "DRONE-001", "battery_level": 75.0}]
+        key = handler.upload_json_records(records, prefix="raw/", partition_by_date=False)
+        mock_client = mock_boto3_client.client.return_value
+        mock_client.put_object.assert_called_once()
+        call_kwargs = mock_client.put_object.call_args.kwargs
+        assert call_kwargs["Bucket"] == "test-bucket"
+        assert "raw/" in call_kwargs["Key"]
+        body_lines = call_kwargs["Body"].decode("utf-8").strip().split("\n")
+        assert len(body_lines) == 1
+        assert json.loads(body_lines[0])["drone_id"] == "DRONE-001"
+
+    def test_upload_json_records_returns_key(self, handler, mock_boto3_client):
+        key = handler.upload_json_records([{"x": 1}], prefix="raw/", partition_by_date=False)
+        assert key.startswith("raw/")
+
+    def test_upload_file_calls_upload_file(self, handler, mock_boto3_client):
+        handler.upload_file("/tmp/test.parquet", "processed/test.parquet")
+        mock_client = mock_boto3_client.client.return_value
+        mock_client.upload_file.assert_called_once_with(
+            "/tmp/test.parquet", "test-bucket", "processed/test.parquet"
+        )
+
+    def test_partition_by_date_adds_year_month_day(self, handler, mock_boto3_client):
+        handler.upload_json_records([{"a": 1}], prefix="raw/", partition_by_date=True)
+        mock_client = mock_boto3_client.client.return_value
+        key = mock_client.put_object.call_args.kwargs["Key"]
+        assert "year=" in key
+        assert "month=" in key
+        assert "day=" in key
+
+
+class TestS3HandlerDownload:
+    def test_download_returns_list_of_dicts(self, handler, mock_boto3_client):
+        body_content = '{"drone_id": "DRONE-001"}\n{"drone_id": "DRONE-002"}'
+        mock_client = mock_boto3_client.client.return_value
+        mock_client.get_object.return_value = {
+            "Body": MagicMock(read=MagicMock(return_value=body_content.encode("utf-8")))
+        }
+        result = handler.download_json_records("raw/data.ndjson")
+        assert len(result) == 2
+        assert result[0]["drone_id"] == "DRONE-001"
+        assert result[1]["drone_id"] == "DRONE-002"
+
+
+class TestS3HandlerListDelete:
+    def test_list_objects_returns_keys(self, handler, mock_boto3_client):
+        mock_client = mock_boto3_client.client.return_value
+        mock_client.get_paginator.return_value.paginate.return_value = [
+            {"Contents": [{"Key": "raw/file1.json"}, {"Key": "raw/file2.json"}]}
+        ]
+        keys = handler.list_objects("raw/")
+        assert keys == ["raw/file1.json", "raw/file2.json"]
+
+    def test_delete_object_calls_delete(self, handler, mock_boto3_client):
+        handler.delete_object("raw/old.json")
+        mock_client = mock_boto3_client.client.return_value
+        mock_client.delete_object.assert_called_once_with(
+            Bucket="test-bucket", Key="raw/old.json"
+        )