Ray

Ray Learnings

Production-ready Ray distributed computing labs in Python, covering remote functions, stateful actors, parallel execution patterns, and RabbitMQ-driven task ingestion.

What You Will Learn

  • How Ray distributes work across local CPU cores (and remote nodes)
  • Remote functions and object references: Ray's core abstraction
  • Stateful actors: long-lived worker objects that serialise concurrent access
  • Parallel execution: ray.get(), ray.wait(), and non-blocking patterns
  • Integrating RabbitMQ as a task source: fan-out to Ray remote workers

Prerequisites

  • Python 3.13+
  • RabbitMQ running via the shared message_brokers/ Docker Compose stack (for ray_worker.py and ray_producer.py)

Setup

cd ray_learnings
python3 -m venv venv
source venv/bin/activate       # Windows: venv\Scripts\activate
pip install -r requirements.txt

Start RabbitMQ (for RabbitMQ integration labs)

# From the message_brokers/ directory
docker compose up -d rabbitmq

Running the Labs

All scripts live in python/. Run from the python/ directory:

cd ray_learnings/python

Standalone Ray demo (no RabbitMQ needed)

python ray_basics.py

RabbitMQ + Ray integration

# Terminal 1: start the Ray worker (consumes RabbitMQ, fans out to Ray)
python ray_worker.py

# Terminal 2: publish tasks to RabbitMQ
python ray_producer.py

Lab Implementation & Engineering Deep Dives

1. Core Ray Concepts (python/ray_basics.py)

  • Why: Demonstrates all core Ray abstractions in a single standalone script.
  • What: Remote functions, parallel fan-out, stateful actors, ray.wait() for non-blocking progress, and dependent pipelines.
  • How: ray.init(num_cpus=4) spins up a local cluster. @ray.remote converts functions and classes. ray.get() collects futures.
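For readers without a Ray install handy, the fan-out-then-collect shape can be sketched with the standard library's concurrent.futures, where submit()/result() play the role of .remote()/ray.get() and as_completed() mirrors the ray.wait() loop. This is an analogy, not Ray's API:

```python
from concurrent.futures import ThreadPoolExecutor, as_completed


def square(n: int) -> int:
    # stands in for a @ray.remote function
    return n * n


with ThreadPoolExecutor(max_workers=4) as pool:
    # submit() returns futures immediately, like .remote() returns ObjectRefs
    futures = [pool.submit(square, i) for i in range(5)]
    # as_completed() yields futures as they finish, like a ray.wait() loop
    results = sorted(f.result() for f in as_completed(futures))

print(results)  # [0, 1, 4, 9, 16]
```

The key difference: Ray runs tasks in separate worker processes (and on remote nodes), so CPU-bound work actually parallelises, whereas threads here only overlap I/O.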

2. RabbitMQ Consumer + Ray Worker (python/ray_worker.py)

  • Why: Shows the pattern of using a message broker as a task source for Ray.
  • What: Connects to RabbitMQ, consumes tasks from the ray_tasks queue, dispatches each to a Ray remote function, and publishes results to the ray_results queue.
  • How: A pika blocking consumer blocks on ray.get() per message before acking, so tasks are processed one at a time; basic_qos(prefetch_count=4) bounds unacknowledged deliveries rather than adding parallelism.
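Stripped of Ray, the dispatch step is plain routing logic: a type-to-handler table keyed on the message's type field. A minimal local sketch mirroring dispatch() in ray_worker.py, with ordinary functions standing in for the remote ones (dispatch_local and HANDLERS are hypothetical names, not part of the scripts):

```python
import json
import time


def compute_task(n: int) -> int:
    # same computation as the worker's remote function
    return sum(i * i for i in range(n))


def slow_task(job_id: str, duration: float) -> dict:
    time.sleep(duration)
    return {"job_id": job_id, "status": "done", "duration": duration}


# type -> handler table; lambdas unpack the task-specific payload fields
HANDLERS = {
    "compute": lambda p: compute_task(p["n"]),
    "slow": lambda p: slow_task(p["job_id"], p["duration"]),
}


def dispatch_local(raw: bytes):
    """Parse an AMQP body and route it to the matching handler."""
    message = json.loads(raw)
    handler = HANDLERS.get(message.get("type"))
    if handler is None:
        raise ValueError(f"Unknown task type: {message.get('type')!r}")
    return handler(message.get("payload", {}))


print(dispatch_local(b'{"id": "abc", "type": "compute", "payload": {"n": 5}}'))  # 30
```

In the real worker, each handler call becomes a .remote() invocation returning an ObjectRef instead of a plain value.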

3. Task Producer (python/ray_producer.py)

  • Why: Publishes a batch of typed tasks to RabbitMQ and collects results.
  • What: Publishes compute and slow task messages, then polls ray_results until all results arrive.
  • How: pika basic_publish with delivery_mode=2 (persistent). basic_get polling loop for result collection.
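Producer and worker agree on a small JSON envelope: an 8-character id, a type, and a type-specific payload. A sketch of building and parsing one (field names are taken from the scripts below; make_task is a hypothetical helper mirroring publish_task()):

```python
import json
import uuid


def make_task(task_type: str, payload: dict) -> bytes:
    # same envelope publish_task() sends: short id + type + payload
    task_id = str(uuid.uuid4())[:8]
    return json.dumps({"id": task_id, "type": task_type, "payload": payload}).encode()


body = make_task("compute", {"n": 200_000})

# what on_message() in ray_worker.py sees after json.loads(body)
message = json.loads(body)
assert message["type"] == "compute"
assert message["payload"]["n"] == 200_000
assert len(message["id"]) == 8
```

Because the envelope is versionless JSON, adding a new task type only requires a new branch in the worker's dispatch().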

πŸ“ Lab Implementation & Scripts

python/ray_basics.py

"""
ray_basics.py: Core Ray concepts on a local cluster.

Demonstrates:
    1. Remote functions and object references
    2. Parallel fan-out and speedup
    3. Actors: stateful remote objects
    4. ray.wait() for non-blocking progress tracking
    5. Dependent task pipelines

Does NOT require RabbitMQ. Runs standalone.

Run (from ray_learnings/python/):
    python ray_basics.py
"""

import time
import ray

ray.init(num_cpus=4, ignore_reinit_error=True)

SEP = "─" * 60


def section(title: str) -> None:
    print(f"\n{SEP}\n  {title}\n{SEP}")


# ── Remote Functions ───────────────────────────────────────────────────────────

@ray.remote
def add(x: int, y: int) -> int:
    """Remote addition β€” executes in a Ray worker process."""
    return x + y


@ray.remote
def cpu_bound(n: int) -> int:
    """Simulates CPU-intensive work: sum of squares."""
    return sum(i * i for i in range(n))


@ray.remote
def slow_task(job_id: str, duration: float = 0.5) -> dict:
    """Simulates a slow external call."""
    time.sleep(duration)
    return {"job_id": job_id, "status": "done", "duration": duration}


# ── Stateful Actor ─────────────────────────────────────────────────────────────

@ray.remote
class Counter:
    """
    A stateful Ray actor.
    Actors run in their own process and serialise concurrent method calls.
    Use actors for shared mutable state across distributed workers.
    """

    def __init__(self, name: str) -> None:
        self.name = name
        self.count = 0

    def increment(self, by: int = 1) -> int:
        self.count += by
        return self.count

    def reset(self) -> None:
        self.count = 0

    def value(self) -> int:
        return self.count


# ── Demos ──────────────────────────────────────────────────────────────────────

def demo_basic_remote() -> None:
    section("1. Basic Remote Function")
    # .remote() dispatches immediately, returns an ObjectRef (future)
    ref = add.remote(10, 20)
    result = ray.get(ref)           # blocks until result is available
    print(f"add.remote(10, 20) = {result}")

    # Multiple concurrent calls
    refs = [add.remote(i, i) for i in range(5)]
    results = ray.get(refs)         # waits for all
    print(f"add.remote(i, i) for i in 0..4 = {results}")


def demo_parallel_speedup() -> None:
    section("2. Parallel Speedup")
    n = 2_000_000
    t0 = time.perf_counter()
    refs = [cpu_bound.remote(n) for _ in range(4)]
    ray.get(refs)
    elapsed = time.perf_counter() - t0
    print(f"4 Γ— cpu_bound({n:,}) in parallel: {elapsed:.2f}s")


def demo_actor() -> None:
    section("3. Actor β€” Stateful Remote Object")
    counter = Counter.remote("demo-counter")

    # All calls return ObjectRefs immediately (non-blocking)
    increment_refs = [counter.increment.remote(i) for i in range(1, 6)]
    values = ray.get(increment_refs)
    print(f"Increments: {values}")          # [1, 3, 6, 10, 15]

    total = ray.get(counter.value.remote())
    print(f"Total: {total}")                # 15

    ray.get(counter.reset.remote())
    print(f"After reset: {ray.get(counter.value.remote())}")  # 0


def demo_ray_wait() -> None:
    section("4. ray.wait() β€” Process Results as They Arrive")
    durations = [0.1, 0.4, 0.2, 0.6, 0.3]
    refs = [slow_task.remote(f"job-{i}", d) for i, d in enumerate(durations)]

    remaining = refs[:]
    completed = 0
    while remaining:
        done, remaining = ray.wait(remaining, num_returns=1, timeout=2.0)
        for ref in done:
            result = ray.get(ref)
            completed += 1
            print(f"  [{completed}/5] {result}")


def demo_pipeline() -> None:
    section("5. Dependent Pipeline")
    # add(3, 4) → cpu_bound(result) → slow_task(str(result))
    # ObjectRefs can be passed straight into downstream tasks; Ray resolves
    # them automatically, so no intermediate ray.get() is needed for ref1.
    ref1 = add.remote(3, 4)                    # 7
    ref2 = cpu_bound.remote(ref1)              # Ray awaits ref1, then runs
    val2 = ray.get(ref2)                       # needed to build the job_id
    ref3 = slow_task.remote(f"pipe-{val2}", duration=0.1)
    final = ray.get(ref3)
    print(f"Pipeline: add(3,4)→cpu_bound→slow_task = {final}")


if __name__ == "__main__":
    print("Ray Basics β€” local cluster (4 CPUs)")
    print("Dashboard: http://127.0.0.1:8265 (if enabled)\n")

    demo_basic_remote()
    demo_parallel_speedup()
    demo_actor()
    demo_ray_wait()
    demo_pipeline()

    print(f"\n{SEP}\n  All demos complete.\n{SEP}")
    ray.shutdown()

python/ray_worker.py

"""
ray_worker.py: RabbitMQ consumer that fans tasks out to Ray remote functions.

Architecture:
    RabbitMQ (AMQP) ──► pika consumer ──► dispatch() ──► Ray remote functions
                                                      ──► result ──► AMQP result queue

The pika consumer handles queue I/O on the main thread.
Each message is dispatched to a Ray remote function, and ray.get() blocks
on the result before the AMQP message is acked, so this simple design
processes one task at a time; prefetch_count bounds the local backlog.

Prerequisites:
    RabbitMQ running:  docker compose up -d rabbitmq  (in message_brokers/)
    pip install -r requirements.txt

Run (Terminal 1, from ray_learnings/python/):
    python ray_worker.py

Then (Terminal 2):
    python ray_producer.py
"""

import json
import time
import pika
import ray

BROKER_HOST = "localhost"
TASK_QUEUE = "ray_tasks"
RESULT_QUEUE = "ray_results"


# ── Ray Remote Functions ───────────────────────────────────────────────────────

@ray.remote
def compute_task(n: int) -> int:
    """CPU-bound computation: sum of squares up to n."""
    return sum(i * i for i in range(n))


@ray.remote
def slow_task(job_id: str, duration: float) -> dict:
    """Simulates a slow I/O-bound job."""
    time.sleep(duration)
    return {"job_id": job_id, "status": "done", "duration": duration}


# ── AMQP Helpers ───────────────────────────────────────────────────────────────

def connect() -> tuple[pika.BlockingConnection, pika.adapters.blocking_connection.BlockingChannel]:
    params = pika.ConnectionParameters(
        host=BROKER_HOST,
        credentials=pika.PlainCredentials("guest", "guest"),
        heartbeat=60,
    )
    conn = pika.BlockingConnection(params)
    ch = conn.channel()
    ch.queue_declare(queue=TASK_QUEUE, durable=True)
    ch.queue_declare(queue=RESULT_QUEUE, durable=True)
    ch.basic_qos(prefetch_count=4)  # up to 4 unacknowledged messages at once
    return conn, ch


def publish_result(
    ch: pika.adapters.blocking_connection.BlockingChannel,
    task_id: str,
    payload: dict,
) -> None:
    ch.basic_publish(
        exchange="",
        routing_key=RESULT_QUEUE,
        body=json.dumps({"task_id": task_id, **payload}),
        properties=pika.BasicProperties(
            delivery_mode=2,  # persistent
            content_type="application/json",
        ),
    )


# ── Task Dispatch ──────────────────────────────────────────────────────────────

def dispatch(message: dict) -> ray.ObjectRef:
    """Route an AMQP message to the appropriate Ray remote function."""
    task_type = message.get("type")
    payload = message.get("payload", {})

    if task_type == "compute":
        return compute_task.remote(payload["n"])
    elif task_type == "slow":
        return slow_task.remote(payload["job_id"], payload["duration"])
    else:
        raise ValueError(f"Unknown task type: {task_type!r}")


# ── Message Handler ────────────────────────────────────────────────────────────

def on_message(
    ch: pika.adapters.blocking_connection.BlockingChannel,
    method: pika.spec.Basic.Deliver,
    _props,
    body: bytes,
) -> None:
    message = json.loads(body)
    task_id = message.get("id", "unknown")
    print(f"[worker] Received {task_id}: type={message.get('type')}")

    try:
        ref = dispatch(message)
        result = ray.get(ref, timeout=30)
        print(f"[worker] Completed {task_id}: {result}")
        publish_result(ch, task_id, {"result": result})
    except Exception as exc:
        print(f"[worker] Failed {task_id}: {exc}")
        publish_result(ch, task_id, {"error": str(exc)})

    ch.basic_ack(delivery_tag=method.delivery_tag)


# ── Main ───────────────────────────────────────────────────────────────────────

if __name__ == "__main__":
    print("Initialising Ray (local cluster, 4 CPUs)...")
    ray.init(num_cpus=4, ignore_reinit_error=True)

    print(f"Connecting to RabbitMQ at {BROKER_HOST}:5672...")
    conn, ch = connect()
    ch.basic_consume(queue=TASK_QUEUE, on_message_callback=on_message)

    print(f"[worker] Listening on '{TASK_QUEUE}'. Press Ctrl+C to stop.\n")
    try:
        ch.start_consuming()
    except KeyboardInterrupt:
        print("\n[worker] Shutting down...")
        ch.stop_consuming()
        conn.close()
        ray.shutdown()

python/ray_producer.py

"""
ray_producer.py: Publishes tasks to RabbitMQ for ray_worker.py to process.

Demonstrates the producer side of a Ray + RabbitMQ integration:
    - Publishes typed task messages (compute, slow) to RabbitMQ
    - Polls the result queue until all results are collected
    - Shows task IDs mapped to results

Prerequisites:
    RabbitMQ running:         docker compose up -d rabbitmq  (in message_brokers/)
    Ray worker listening:     python ray_worker.py  (Terminal 1)

Run (Terminal 2, from ray_learnings/python/):
    python ray_producer.py
"""

import json
import time
import uuid
import pika

BROKER_HOST = "localhost"
TASK_QUEUE = "ray_tasks"
RESULT_QUEUE = "ray_results"

SEP = "─" * 60


def section(title: str) -> None:
    print(f"\n{SEP}\n  {title}\n{SEP}")


# ── AMQP Helpers ───────────────────────────────────────────────────────────────

def connect() -> tuple[pika.BlockingConnection, pika.adapters.blocking_connection.BlockingChannel]:
    params = pika.ConnectionParameters(
        host=BROKER_HOST,
        credentials=pika.PlainCredentials("guest", "guest"),
        heartbeat=60,
    )
    conn = pika.BlockingConnection(params)
    ch = conn.channel()
    ch.queue_declare(queue=TASK_QUEUE, durable=True)
    ch.queue_declare(queue=RESULT_QUEUE, durable=True)
    return conn, ch


def publish_task(
    ch: pika.adapters.blocking_connection.BlockingChannel,
    task_type: str,
    payload: dict,
) -> str:
    task_id = str(uuid.uuid4())[:8]
    message = json.dumps({"id": task_id, "type": task_type, "payload": payload})
    ch.basic_publish(
        exchange="",
        routing_key=TASK_QUEUE,
        body=message,
        properties=pika.BasicProperties(
            delivery_mode=2,  # persistent: survives broker restart
            content_type="application/json",
        ),
    )
    return task_id


def collect_results(
    ch: pika.adapters.blocking_connection.BlockingChannel,
    expected: int,
    timeout: float = 30.0,
) -> list[dict]:
    results = []
    deadline = time.time() + timeout
    while len(results) < expected and time.time() < deadline:
        method, _, body = ch.basic_get(queue=RESULT_QUEUE, auto_ack=True)
        if method:
            results.append(json.loads(body))
        else:
            time.sleep(0.1)
    return results


# ── Main ───────────────────────────────────────────────────────────────────────

if __name__ == "__main__":
    print("Ray Producer β€” publishing tasks to RabbitMQ")
    print("Ensure ray_worker.py is running in another terminal.\n")

    conn, ch = connect()
    all_task_ids: list[str] = []

    # 1. Publish compute tasks
    section("1. Publish 8 Compute Tasks")
    for i in range(1, 9):
        n = i * 200_000
        tid = publish_task(ch, "compute", {"n": n})
        all_task_ids.append(tid)
        print(f"  Published {tid}: compute(n={n:,})")

    # 2. Publish slow tasks
    section("2. Publish 4 Slow Tasks")
    for i in range(4):
        tid = publish_task(ch, "slow", {"job_id": f"job-{i}", "duration": 0.3})
        all_task_ids.append(tid)
        print(f"  Published {tid}: slow_task(job_id=job-{i}, duration=0.3s)")

    total = len(all_task_ids)
    print(f"\nPublished {total} tasks total. Waiting for results...")

    # 3. Collect results
    section(f"3. Collecting {total} Results (timeout=30s)")
    time.sleep(0.5)  # brief pause to let worker start consuming
    results = collect_results(ch, expected=total, timeout=30.0)

    print(f"Received {len(results)}/{total} results:")
    for r in sorted(results, key=lambda x: x.get("task_id", "")):
        print(f"  {r}")

    conn.close()
    print(f"\n{SEP}\n  Producer done.\n{SEP}")