Ray
Ray Learnings
Production-ready Ray distributed computing implementations covering remote functions, stateful actors, parallel execution patterns, and RabbitMQ-driven task ingestion with Python.
What You Will Learn
- How Ray distributes work across local CPU cores (and remote nodes)
- Remote functions and object references — Ray's core abstraction
- Stateful actors: long-lived worker objects that serialise concurrent access
- Parallel execution: ray.get(), ray.wait(), and non-blocking patterns
- Integrating RabbitMQ as a task source: fan-out to Ray remote workers
Prerequisites
- Python 3.13+
- RabbitMQ running via the shared message_brokers/ Docker Compose stack (for ray_worker.py and ray_producer.py)
Setup
cd ray_learnings
python3 -m venv venv
source venv/bin/activate # Windows: venv\Scripts\activate
pip install -r requirements.txt
Start RabbitMQ (for RabbitMQ integration labs)
# From the message_brokers/ directory
docker compose up -d rabbitmq
Running the Labs
All scripts live in python/. Run from the python/ directory:
cd ray_learnings/python
Standalone Ray demo (no RabbitMQ needed)
python ray_basics.py
RabbitMQ + Ray integration
# Terminal 1 — start the Ray worker (consumes RabbitMQ, fans out to Ray)
python ray_worker.py
# Terminal 2 — publish tasks to RabbitMQ
python ray_producer.py
Lab Implementation & Engineering Deep Dives
1. Core Ray Concepts (python/ray_basics.py)
- Why: Demonstrates all core Ray abstractions in a single standalone script.
- What: Remote functions, parallel fan-out, stateful actors, ray.wait() for non-blocking progress, and dependent pipelines.
- How:
`ray.init(num_cpus=4)` spins up a local cluster. `@ray.remote` converts functions and classes. `ray.get()` collects futures.
2. RabbitMQ Consumer + Ray Worker (python/ray_worker.py)
- Why: Shows the pattern of using a message broker as a task source for Ray.
- What: Connects to RabbitMQ, consumes tasks from the `ray_tasks` queue, dispatches each to a Ray remote function, and publishes results to the `ray_results` queue.
- How: pika blocking consumer + `ray.get()` per message. `basic_qos(prefetch_count=4)` limits in-flight messages.
3. Task Producer (python/ray_producer.py)
- Why: Publishes a batch of typed tasks to RabbitMQ and collects results.
- What: Publishes `compute` and `slow` task messages, then polls `ray_results` until all results arrive.
- How: pika `basic_publish` with `delivery_mode=2` (persistent). `basic_get` polling loop for result collection.
Lab Implementation & Scripts
python/ray_basics.py
"""
ray_basics.py β Core Ray concepts on a local cluster.
Demonstrates:
1. Remote functions and object references
2. Parallel fan-out and speedup
3. Actors β stateful remote objects
4. ray.wait() for non-blocking progress tracking
5. Dependent task pipelines
Does NOT require RabbitMQ. Runs standalone.
Run (from ray_learnings/python/):
python ray_basics.py
"""
import time
import ray
# Start a local Ray cluster pinned to 4 CPUs. ignore_reinit_error lets the
# script be re-run in a session where Ray is already initialised without raising.
ray.init(num_cpus=4, ignore_reinit_error=True)
SEP = "β" * 60


def section(title: str) -> None:
    """Print *title* framed above and below by a separator rule."""
    banner = "\n".join(["", SEP, f" {title}", SEP])
    print(banner)
# ββ Remote Functions βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
@ray.remote
def add(x: int, y: int) -> int:
    """Add two integers inside a Ray worker process."""
    total = x + y
    return total
@ray.remote
def cpu_bound(n: int) -> int:
    """Simulate CPU-intensive work: the sum of squares 0..n-1."""
    total = 0
    for value in range(n):
        total += value * value
    return total
@ray.remote
def slow_task(job_id: str, duration: float = 0.5) -> dict:
    """Stand in for a slow external call by sleeping for *duration* seconds."""
    time.sleep(duration)
    report = {"job_id": job_id, "status": "done", "duration": duration}
    return report
# ββ Stateful Actor βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
@ray.remote
class Counter:
    """
    A stateful Ray actor holding a single running count.

    Each actor lives in its own process, and Ray serialises concurrent
    method calls — actors are the tool for shared mutable state across
    distributed workers.
    """

    def __init__(self, name: str) -> None:
        self.name = name
        self.count = 0

    def increment(self, by: int = 1) -> int:
        """Add *by* to the count and return the new total."""
        self.count = self.count + by
        return self.count

    def reset(self) -> None:
        """Zero the count."""
        self.count = 0

    def value(self) -> int:
        """Return the current count."""
        return self.count
# ββ Demos ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
def demo_basic_remote() -> None:
    """Show one remote call, then a small fan-out of five concurrent calls."""
    section("1. Basic Remote Function")
    # .remote() dispatches immediately and hands back an ObjectRef future;
    # ray.get() blocks until the value is available.
    single = ray.get(add.remote(10, 20))
    print(f"add.remote(10, 20) = {single}")
    # Fan out five calls at once, then gather them all together.
    futures = [add.remote(i, i) for i in range(5)]
    gathered = ray.get(futures)
    print(f"add.remote(i, i) for i in 0..4 = {gathered}")
def demo_parallel_speedup() -> None:
    """Time four identical CPU-bound tasks running in parallel."""
    section("2. Parallel Speedup")
    n = 2_000_000
    start = time.perf_counter()
    ray.get([cpu_bound.remote(n) for _ in range(4)])
    elapsed = time.perf_counter() - start
    print(f"4 Γ cpu_bound({n:,}) in parallel: {elapsed:.2f}s")
def demo_actor() -> None:
    """Drive a Counter actor through increments, a read, and a reset."""
    section("3. Actor β Stateful Remote Object")
    counter = Counter.remote("demo-counter")
    # Actor method calls return ObjectRefs immediately (non-blocking).
    refs = [counter.increment.remote(step) for step in range(1, 6)]
    print(f"Increments: {ray.get(refs)}")  # [1, 3, 6, 10, 15]
    print(f"Total: {ray.get(counter.value.remote())}")  # 15
    ray.get(counter.reset.remote())
    print(f"After reset: {ray.get(counter.value.remote())}")  # 0
def demo_ray_wait() -> None:
    """Consume task results one at a time, in completion order.

    ray.wait() returns as soon as num_returns tasks finish, so results can
    be processed as they arrive instead of waiting for the whole batch.
    """
    section("4. ray.wait() β Process Results as They Arrive")
    durations = [0.1, 0.4, 0.2, 0.6, 0.3]
    refs = [slow_task.remote(f"job-{i}", d) for i, d in enumerate(durations)]
    total = len(refs)  # was hard-coded as 5; derive it so durations can change
    remaining = refs[:]
    completed = 0
    while remaining:
        # Block for at most 2s waiting for the next single completion;
        # on timeout, done is empty and we simply poll again.
        done, remaining = ray.wait(remaining, num_returns=1, timeout=2.0)
        for ref in done:
            result = ray.get(ref)
            completed += 1
            print(f" [{completed}/{total}] {result}")
def demo_pipeline() -> None:
    """Chain dependent tasks by passing ObjectRefs directly.

    Ray resolves an ObjectRef argument to its value before the downstream
    task runs, so stages can be wired together without a blocking
    ray.get() between every stage (the original fetched each intermediate
    result to the driver first, serialising the pipeline on the driver).
    """
    section("5. Dependent Pipeline")
    # add(3, 4) β cpu_bound(result) β slow_task(str(result))
    ref1 = add.remote(3, 4)        # 7
    ref2 = cpu_bound.remote(ref1)  # Ray feeds add's result straight in
    val2 = ray.get(ref2)           # needed locally only to build the job id
    ref3 = slow_task.remote(f"pipe-{val2}", duration=0.1)
    final = ray.get(ref3)
    print(f"Pipeline: add(3,4)βcpu_boundβslow_task = {final}")
if __name__ == "__main__":
    print("Ray Basics β local cluster (4 CPUs)")
    print("Dashboard: http://127.0.0.1:8265 (if enabled)\n")
    # Run every demo in order.
    for demo in (
        demo_basic_remote,
        demo_parallel_speedup,
        demo_actor,
        demo_ray_wait,
        demo_pipeline,
    ):
        demo()
    print(f"\n{SEP}\n All demos complete.\n{SEP}")
    ray.shutdown()
python/ray_worker.py
"""
ray_worker.py β RabbitMQ consumer that fans tasks out to Ray remote functions.
Architecture:
RabbitMQ (AMQP) βββΊ pika consumer βββΊ dispatch() βββΊ Ray remote functions
βββΊ result β AMQP result queue
The pika consumer handles queue I/O on the main thread.
Ray remote functions execute in parallel across available CPU cores.
ray.get() collects results synchronously before acking the AMQP message.
Prerequisites:
RabbitMQ running: docker compose up -d rabbitmq (in message_brokers/)
pip install -r requirements.txt
Run (Terminal 1, from ray_learnings/python/):
python ray_worker.py
Then (Terminal 2):
python ray_producer.py
"""
import json
import time
import pika
import ray
# RabbitMQ connection target and queue names (must match ray_producer.py).
BROKER_HOST = "localhost"
TASK_QUEUE = "ray_tasks"      # inbound: tasks published by the producer
RESULT_QUEUE = "ray_results"  # outbound: results published back for collection
# ββ Ray Remote Functions βββββββββββββββββββββββββββββββββββββββββββββββββββββββ
@ray.remote
def compute_task(n: int) -> int:
    """CPU-bound computation: sum of squares below n."""
    squares = (value * value for value in range(n))
    return sum(squares)
@ray.remote
def slow_task(job_id: str, duration: float) -> dict:
    """Stand in for a slow I/O-bound job: sleep, then report completion."""
    time.sleep(duration)
    return {
        "job_id": job_id,
        "status": "done",
        "duration": duration,
    }
# ββ AMQP Helpers βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
def connect() -> tuple[pika.BlockingConnection, pika.channel.Channel]:
    """Open a blocking AMQP connection and return (connection, channel).

    Both queues are declared durable so they survive a broker restart,
    and prefetch caps unacknowledged deliveries at 4.
    """
    credentials = pika.PlainCredentials("guest", "guest")
    params = pika.ConnectionParameters(
        host=BROKER_HOST,
        credentials=credentials,
        heartbeat=60,
    )
    connection = pika.BlockingConnection(params)
    channel = connection.channel()
    for queue_name in (TASK_QUEUE, RESULT_QUEUE):
        channel.queue_declare(queue=queue_name, durable=True)
    channel.basic_qos(prefetch_count=4)  # up to 4 unacknowledged messages at once
    return connection, channel
def publish_result(ch: pika.channel.Channel, task_id: str, payload: dict) -> None:
    """Publish a persistent JSON result message onto the result queue."""
    body = json.dumps({"task_id": task_id, **payload})
    properties = pika.BasicProperties(
        delivery_mode=2,  # persistent
        content_type="application/json",
    )
    ch.basic_publish(
        exchange="",
        routing_key=RESULT_QUEUE,
        body=body,
        properties=properties,
    )
# ββ Task Dispatch ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
def dispatch(message: dict) -> ray.ObjectRef:
"""Route an AMQP message to the appropriate Ray remote function."""
task_type = message.get("type")
payload = message.get("payload", {})
if task_type == "compute":
return compute_task.remote(payload["n"])
elif task_type == "slow":
return slow_task.remote(payload["job_id"], payload["duration"])
else:
raise ValueError(f"Unknown task type: {task_type!r}")
# ββ Message Handler ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
def on_message(
    ch: pika.channel.Channel,
    method: pika.spec.Basic.Deliver,
    _props,
    body: bytes,
) -> None:
    """Handle one AMQP delivery: dispatch to Ray, publish the outcome, ack.

    Parsing now happens inside the try block: previously a malformed JSON
    body raised before the ack, leaving a poison message that RabbitMQ
    would redeliver forever. Any failure (bad JSON, unknown task type,
    Ray error or timeout) is reported on the result queue, and the
    message is acked either way so the queue keeps draining.
    """
    task_id = "unknown"
    try:
        message = json.loads(body)
        task_id = message.get("id", "unknown")
        print(f"[worker] Received {task_id}: type={message.get('type')}")
        ref = dispatch(message)
        result = ray.get(ref, timeout=30)  # block until the Ray task finishes
        print(f"[worker] Completed {task_id}: {result}")
        publish_result(ch, task_id, {"result": result})
    except Exception as exc:  # boundary handler: report, never crash the consumer
        print(f"[worker] Failed {task_id}: {exc}")
        publish_result(ch, task_id, {"error": str(exc)})
    ch.basic_ack(delivery_tag=method.delivery_tag)
# ββ Main βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
if __name__ == "__main__":
    print("Initialising Ray (local cluster, 4 CPUs)...")
    ray.init(num_cpus=4, ignore_reinit_error=True)

    print(f"Connecting to RabbitMQ at {BROKER_HOST}:5672...")
    connection, channel = connect()
    channel.basic_consume(queue=TASK_QUEUE, on_message_callback=on_message)
    print(f"[worker] Listening on '{TASK_QUEUE}'. Press Ctrl+C to stop.\n")

    try:
        channel.start_consuming()
    except KeyboardInterrupt:
        print("\n[worker] Shutting down...")
        channel.stop_consuming()
    connection.close()
    ray.shutdown()
python/ray_producer.py
"""
ray_producer.py β Publishes tasks to RabbitMQ for ray_worker.py to process.
Demonstrates the producer side of a Ray + RabbitMQ integration:
- Publishes typed task messages (compute, slow) to RabbitMQ
- Polls the result queue until all results are collected
- Shows task IDs mapped to results
Prerequisites:
RabbitMQ running: docker compose up -d rabbitmq (in message_brokers/)
Ray worker listening: python ray_worker.py (Terminal 1)
Run (Terminal 2, from ray_learnings/python/):
python ray_producer.py
"""
import json
import time
import uuid
import pika
# RabbitMQ connection target and queue names (must match ray_worker.py).
BROKER_HOST = "localhost"
TASK_QUEUE = "ray_tasks"      # outbound: tasks for the Ray worker to consume
RESULT_QUEUE = "ray_results"  # inbound: results published back by the worker
SEP = "β" * 60


def section(title: str) -> None:
    """Print *title* between two separator rules."""
    print("", SEP, f" {title}", SEP, sep="\n")
# ββ AMQP Helpers βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
def connect() -> tuple[pika.BlockingConnection, pika.channel.Channel]:
    """Open a blocking AMQP connection; declare both durable queues."""
    connection = pika.BlockingConnection(
        pika.ConnectionParameters(
            host=BROKER_HOST,
            credentials=pika.PlainCredentials("guest", "guest"),
            heartbeat=60,
        )
    )
    channel = connection.channel()
    for queue_name in (TASK_QUEUE, RESULT_QUEUE):
        channel.queue_declare(queue=queue_name, durable=True)
    return connection, channel
def publish_task(ch: pika.channel.Channel, task_type: str, payload: dict) -> str:
    """Publish a persistent JSON task message and return its 8-char id."""
    task_id = str(uuid.uuid4())[:8]
    body = json.dumps({"id": task_id, "type": task_type, "payload": payload})
    properties = pika.BasicProperties(
        delivery_mode=2,  # persistent: survives a broker restart
        content_type="application/json",
    )
    ch.basic_publish(
        exchange="",
        routing_key=TASK_QUEUE,
        body=body,
        properties=properties,
    )
    return task_id
def collect_results(
    ch: pika.channel.Channel, expected: int, timeout: float = 30.0
) -> list[dict]:
    """Poll RESULT_QUEUE until *expected* messages arrive or *timeout* elapses.

    Args:
        ch: Open AMQP channel used for basic_get polling.
        expected: Number of result messages to wait for.
        timeout: Overall time budget in seconds.

    Returns:
        Decoded JSON result payloads received in time — may be fewer than
        *expected* if the deadline expires first.
    """
    results: list[dict] = []
    # time.monotonic() instead of time.time(): the deadline is then immune
    # to wall-clock jumps (NTP corrections, DST) that could stretch or cut
    # the timeout window.
    deadline = time.monotonic() + timeout
    while len(results) < expected and time.monotonic() < deadline:
        method, _props, body = ch.basic_get(queue=RESULT_QUEUE, auto_ack=True)
        if method:
            results.append(json.loads(body))
        else:
            time.sleep(0.1)  # queue empty: back off briefly before re-polling
    return results
# ββ Main βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
if __name__ == "__main__":
    print("Ray Producer β publishing tasks to RabbitMQ")
    print("Ensure ray_worker.py is running in another terminal.\n")
    connection, channel = connect()

    task_ids: list[str] = []

    # 1. Publish compute tasks.
    section("1. Publish 8 Compute Tasks")
    for i in range(1, 9):
        n = i * 200_000
        task_id = publish_task(channel, "compute", {"n": n})
        task_ids.append(task_id)
        print(f" Published {task_id}: compute(n={n:,})")

    # 2. Publish slow tasks.
    section("2. Publish 4 Slow Tasks")
    for i in range(4):
        task_id = publish_task(channel, "slow", {"job_id": f"job-{i}", "duration": 0.3})
        task_ids.append(task_id)
        print(f" Published {task_id}: slow_task(job_id=job-{i}, duration=0.3s)")

    total = len(task_ids)
    print(f"\nPublished {total} tasks total. Waiting for results...")

    # 3. Collect results.
    section(f"3. Collecting {total} Results (timeout=30s)")
    time.sleep(0.5)  # brief pause to let worker start consuming
    results = collect_results(channel, expected=total, timeout=30.0)
    print(f"Received {len(results)}/{total} results:")
    for entry in sorted(results, key=lambda item: item.get("task_id", "")):
        print(f" {entry}")

    connection.close()
    print(f"\n{SEP}\n Producer done.\n{SEP}")