
Commit

[RLlib] Cleanup examples folder vol 32: Enable RLlib + Serve example in CI and translate to new API stack. (ray-project#48687)
sven1977 authored Nov 12, 2024
1 parent 932919e commit c47bd45
Showing 3 changed files with 141 additions and 77 deletions.
18 changes: 9 additions & 9 deletions rllib/BUILD
@@ -2973,15 +2973,15 @@ py_test(

# subdirectory: ray_serve/
# ....................................
# TODO (sven): Uncomment once the problem with the path on BAZEL is solved.
# py_test(
# name = "examples/ray_serve/ray_serve_with_rllib",
# main = "examples/ray_serve/ray_serve_with_rllib.py",
# tags = ["team:rllib", "exclusive", "examples"],
# size = "medium",
# srcs = ["examples/ray_serve/ray_serve_with_rllib.py"],
# args = ["--train-iters=2", "--serve-episodes=2", "--no-render"]
# )
py_test(
name = "examples/ray_serve/ray_serve_with_rllib",
main = "examples/ray_serve/ray_serve_with_rllib.py",
tags = ["team:rllib", "exclusive", "examples"],
size = "medium",
srcs = ["examples/ray_serve/ray_serve_with_rllib.py"],
data = glob(["examples/ray_serve/classes/**"]),
args = ["--stop-iters=2", "--num-episodes-served=2", "--no-render", "--port=12345"]
)

# subdirectory: ray_tune/
# ....................................
30 changes: 19 additions & 11 deletions rllib/examples/ray_serve/classes/cartpole_deployment.py
@@ -1,17 +1,17 @@
import json
from typing import Dict

import numpy as np
from starlette.requests import Request
import torch

from ray import serve
from ray.rllib.core import Columns
from ray.rllib.core.rl_module.rl_module import RLModule
from ray.serve.schema import LoggingConfig
from ray.rllib.algorithms.algorithm import Algorithm


@serve.deployment(
route_prefix="/rllib-rlmodule",
logging_config=LoggingConfig(log_level="WARN"),
)
@serve.deployment(logging_config=LoggingConfig(log_level="WARN"))
class ServeRLlibRLModule:
"""Callable class used by Ray Serve to handle async requests.
@@ -21,22 +21,30 @@ class ServeRLlibRLModule:
(with a current observation).
"""

def __init__(self, checkpoint):
self.algo = Algorithm.from_checkpoint(checkpoint)
def __init__(self, rl_module_checkpoint):
self.rl_module = RLModule.from_checkpoint(rl_module_checkpoint)

async def __call__(self, starlette_request: Request) -> Dict:
request = await starlette_request.body()
request = request.decode("utf-8")
request = json.loads(request)
obs = request["observation"]

# Compute and return the action for the given observation.
action = self.algo.compute_single_action(obs)
# Compute and return the action for the given observation (create a batch
# with B=1 and convert to torch).
output = self.rl_module.forward_inference(
batch={"obs": torch.from_numpy(np.array([obs], np.float32))}
)
# Extract action logits and unbatch.
logits = output[Columns.ACTION_DIST_INPUTS][0]
# Act greedily (argmax).
action = int(np.argmax(logits))

return {"action": int(action)}
return {"action": action}


# Defining the builder function. This is so we can start our deployment via:
# `serve run [this py module]:rl_module checkpoint=[some algo checkpoint path]`
def rl_module(args: Dict[str, str]):
return ServeRLlibRLModule.bind(args["checkpoint"])
serve.start(http_options={"host": "0.0.0.0", "port": args.get("port", 12345)})
return ServeRLlibRLModule.bind(args["rl_module_checkpoint"])
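
For orientation, here is a minimal client-side sketch of how this deployment gets queried once it is running. It mirrors the request loop in ray_serve_with_rllib.py below; the port (12345), the /rllib-rlmodule route prefix, and the deployment's JSON contract are the values that driver script wires up, while the observation values themselves are made up:

import requests

# A CartPole-v1 observation: [cart position, cart velocity, pole angle, pole angular velocity].
obs = [0.01, -0.02, 0.03, 0.04]

# The deployment reads request["observation"] from the JSON body and answers with {"action": <int>}.
resp = requests.get(
    "http://localhost:12345/rllib-rlmodule",
    json={"observation": obs},
)
print(resp.json())  # e.g. {"action": 1}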
170 changes: 113 additions & 57 deletions rllib/examples/ray_serve/ray_serve_with_rllib.py
@@ -1,55 +1,105 @@
"""This example script shows how one can use Ray Serve in combination with RLlib.
"""Example on how to run RLlib in combination with Ray Serve.
Here, we serve an already trained PyTorch RLModule to provide action computations
to a Ray Serve client.
"""
import argparse
import atexit
import os
This example trains an agent with PPO on the CartPole environment, then creates
an RLModule checkpoint and returns its location. After that, it sends the checkpoint
to the Serve deployment for serving the trained RLModule (policy).
import requests
import subprocess
import time
This example:
- shows how to set up a Ray Serve deployment for serving an already trained
RLModule (policy network).
- shows how to request new actions from the Ray Serve deployment while actually
running through episodes in an environment (on which the RLModule that's served
was trained).
import gymnasium as gym
from pathlib import Path
import ray
from ray.rllib.algorithms.algorithm import AlgorithmConfig
from ray.rllib.algorithms.ppo import PPOConfig
How to run this script
----------------------
`python [script file name].py --enable-new-api-stack --stop-reward=200.0`
parser = argparse.ArgumentParser()
parser.add_argument("--train-iters", type=int, default=3)
parser.add_argument("--serve-episodes", type=int, default=2)
parser.add_argument("--no-render", action="store_true")
Use the `--stop-iters`, `--stop-reward`, and/or `--stop-timesteps` options to
determine how long to train the policy for. Use the `--num-episodes-served` option to
set the number of episodes to serve (after training) and the `--no-render` option
to NOT render the environment during the serving phase.
For debugging, use the following additional command line options
`--no-tune --num-env-runners=0`
which should allow you to set breakpoints anywhere in the RLlib code and
have the execution stop there for inspection and debugging.
def train_rllib_rl_module(config: AlgorithmConfig, train_iters: int = 1):
"""Trains a PPO (RLModule) on ALE/MsPacman-v5 for n iterations.
For logging to your WandB account, use:
`--wandb-key=[your WandB API key] --wandb-project=[some project name]
--wandb-run-name=[optional: WandB run name (within the defined project)]`
Saves the trained Algorithm to disk and returns the checkpoint path.
You can visualize experiment results in ~/ray_results using TensorBoard.
Args:
config: The algo config object for the Algorithm.
train_iters: For how many iterations to train the Algorithm.
Returns:
str: The saved checkpoint to restore the RLModule from.
"""
# Create algorithm from config.
algo = config.build()
Results to expect
-----------------
# Train for n iterations, then save, stop, and return the checkpoint path.
for _ in range(train_iters):
print(algo.train())
You should see something similar to the following on the command line when using the
options: `--stop-reward=250.0`, `--num-episodes-served=2`, and `--port=12345`:
# TODO (sven): Change this example to only storing the RLModule checkpoint, NOT
# the entire Algorithm.
checkpoint_result = algo.save()
[First, the RLModule is trained through PPO]
algo.stop()
+-----------------------------+------------+-----------------+--------+
| Trial name                  | status     | loc             |   iter |
|                             |            |                 |        |
|-----------------------------+------------+-----------------+--------+
| PPO_CartPole-v1_84778_00000 | TERMINATED | 127.0.0.1:40411 |      1 |
+-----------------------------+------------+-----------------+--------+
+------------------+---------------------+------------------------+
|   total time (s) | episode_return_mean |   num_env_steps_sample |
|                  |                     |             d_lifetime |
|------------------+---------------------|------------------------|
|          2.87052 |               253.2 |                  12000 |
+------------------+---------------------+------------------------+
return checkpoint_result.checkpoint
[The RLModule is deployed through Ray Serve on port 12345]
Started Ray Serve with PID: 40458
[A few episodes are played through using the policy service (w/ greedy, non-exploratory
actions)]
Episode R=500.0
Episode R=500.0
"""

import atexit
import os

import requests
import subprocess
import time

import gymnasium as gym
from pathlib import Path

from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.core import (
COMPONENT_LEARNER_GROUP,
COMPONENT_LEARNER,
COMPONENT_RL_MODULE,
DEFAULT_MODULE_ID,
)
from ray.rllib.utils.metrics import (
ENV_RUNNER_RESULTS,
EPISODE_RETURN_MEAN,
)
from ray.rllib.utils.test_utils import (
add_rllib_example_script_args,
run_rllib_example_script_experiment,
)

parser = add_rllib_example_script_args()
parser.set_defaults(
enable_new_api_stack=True,
checkpoint_freq=1,
checkpoint_at_end=True,
)
parser.add_argument("--num-episodes-served", type=int, default=2)
parser.add_argument("--no-render", action="store_true")
parser.add_argument("--port", type=int, default=12345)


def kill_proc(proc):
@@ -64,18 +114,23 @@ def kill_proc(proc):
if __name__ == "__main__":
args = parser.parse_args()

ray.init(num_cpus=8)

# Config for the served RLlib RLModule/Algorithm.
config = (
PPOConfig()
.api_stack(enable_rl_module_and_learner=True)
.environment("CartPole-v1")
base_config = PPOConfig().environment("CartPole-v1")

results = run_rllib_example_script_experiment(base_config, args)
algo_checkpoint = results.get_best_result(
f"{ENV_RUNNER_RESULTS}/{EPISODE_RETURN_MEAN}"
).checkpoint.path
# We only need the RLModule component from the algorithm checkpoint. It's located
# under "[algo checkpoint dir]/learner_group/learner/rl_module/[default policy ID]
rl_module_checkpoint = (
Path(algo_checkpoint)
/ COMPONENT_LEARNER_GROUP
/ COMPONENT_LEARNER
/ COMPONENT_RL_MODULE
/ DEFAULT_MODULE_ID
)

# Train the Algorithm for some time, then save it and get the checkpoint path.
checkpoint = train_rllib_rl_module(config, train_iters=args.train_iters)

path_of_this_file = Path(__file__).parent
os.chdir(path_of_this_file)
# Start the serve app with the trained checkpoint.
@@ -84,7 +139,9 @@ def kill_proc(proc):
"serve",
"run",
"classes.cartpole_deployment:rl_module",
f"checkpoint={checkpoint.path}",
f"rl_module_checkpoint={rl_module_checkpoint}",
f"port={args.port}",
"route_prefix=/rllib-rlmodule",
]
)
# Register our `kill_proc` function to be called on exit to stop Ray Serve again.
@@ -97,35 +154,34 @@ def kill_proc(proc):
# Create the environment that we would like to receive
# served actions for.
env = gym.make("CartPole-v1", render_mode="human")
obs, info = env.reset()
obs, _ = env.reset()

num_episodes = 0
episode_return = 0.0

while num_episodes < args.serve_episodes:
while num_episodes < args.num_episodes_served:
# Render env if necessary.
if not args.no_render:
env.render()

# print("-> Requesting action for obs ...")
# print(f"-> Requesting action for obs={obs} ...", end="")
# Send a request to serve.
resp = requests.get(
"http://localhost:8000/rllib-rlmodule",
f"http://localhost:{args.port}/rllib-rlmodule",
json={"observation": obs.tolist()},
# timeout=5.0,
)
response = resp.json()
# print("<- Received response {}".format(response))
# print(f" received: action={response['action']}")

# Apply the action in the env.
action = response["action"]
obs, reward, done, _, _ = env.step(action)
obs, reward, terminated, truncated, _ = env.step(action)
episode_return += reward

# If episode done -> reset to get initial observation of new episode.
if done:
if terminated or truncated:
print(f"Episode R={episode_return}")
obs, info = env.reset()
obs, _ = env.reset()
num_episodes += 1
episode_return = 0.0

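To close the loop, here is a hedged sanity-check sketch (not part of the commit) that composes the same RLModule checkpoint sub-path as the script above, loads the module, and computes one greedy action locally, mirroring what ServeRLlibRLModule does inside the deployment. The algo_checkpoint value is an assumed, illustrative path; in the script it comes from the Tune result of the training run.

from pathlib import Path

import numpy as np
import torch

from ray.rllib.core import (
    COMPONENT_LEARNER_GROUP,
    COMPONENT_LEARNER,
    COMPONENT_RL_MODULE,
    DEFAULT_MODULE_ID,
    Columns,
)
from ray.rllib.core.rl_module.rl_module import RLModule

# Illustrative path only; normally this is the checkpoint dir produced by the PPO run.
algo_checkpoint = "/tmp/ppo_cartpole_checkpoint"

# Same sub-path composition as in ray_serve_with_rllib.py.
rl_module_checkpoint = (
    Path(algo_checkpoint)
    / COMPONENT_LEARNER_GROUP
    / COMPONENT_LEARNER
    / COMPONENT_RL_MODULE
    / DEFAULT_MODULE_ID
)
rl_module = RLModule.from_checkpoint(rl_module_checkpoint)

# One CartPole observation, batched with B=1 and converted to torch,
# exactly as the deployment does it.
obs = np.array([[0.01, -0.02, 0.03, 0.04]], np.float32)
output = rl_module.forward_inference(batch={"obs": torch.from_numpy(obs)})

# Greedy (argmax) action from the action-distribution logits.
action = int(np.argmax(output[Columns.ACTION_DIST_INPUTS][0]))
print(f"greedy action={action}")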
