add multislice support in ray #771

Open · wants to merge 7 commits into base: main
32 changes: 29 additions & 3 deletions infra/cluster/job-cluster.yaml
@@ -14,8 +14,8 @@ cluster_name: levanter-cluster
# Configure GCP
provider:
type: gcp
region: us-central2
availability_zone: us-central2-b
region: us-west4
availability_zone: us-west4-a
project_id: hai-gcp-models

# Maximum Workers (excluding Head Node)
@@ -126,6 +126,32 @@ available_node_types:
schedulingConfig:
preemptible: true

tpu_slice_v5e_16:
min_workers: 0
max_workers: 1024
resources: { "CPU": 120, "TPU": 4 }

node_config:
acceleratorType: v5litepod-16
runtimeVersion: tpu-ubuntu2204-base

# [IMPORTANT] Configure all TPU Workers to be Preemptible!
schedulingConfig:
preemptible: true

tpu_slice_v5e_256:
min_workers: 0
max_workers: 1024
resources: { "CPU": 120, "TPU": 4 }

node_config:
acceleratorType: v5litepod-256
runtimeVersion: tpu-ubuntu2204-base

# [IMPORTANT] Configure all TPU Workers to be Preemptible!
schedulingConfig:
preemptible: true

docker:
image: "ghcr.io/stanford-crfm/levanter-cluster:latest"
container_name: "ray_docker"
@@ -140,7 +166,7 @@ docker:
- -v "/var/run/docker.sock:/var/run/docker.sock"

initialization_commands:
- yes | gcloud auth configure-docker us-central2-docker.pkg.dev
- yes | gcloud auth configure-docker us-west4-docker.pkg.dev
- "export TPU_WORKER_ID=$(curl -H 'Metadata-Flavor: Google' http://metadata.google.internal/computeMetadata/v1/instance/attributes/agent-worker-number) || true"
- which docker || (curl -fsSL https://get.docker.com -o get-docker.sh; sudo sh get-docker.sh; sudo usermod -aG docker $USER; sudo systemctl restart docker -f)
# always run this because ray doesn't run with sudo
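As an aside on the two node types added above: Ray's TPU integration exposes each slice through a custom "head" resource, and the rest of this PR relies on that by requesting `f"TPU-{tpu_type}-head"`. A minimal sketch (resource name assumed to follow that convention; not part of this diff) of pinning a task to the head host of one of the new v5e slices:

```python
import ray

# Minimal sketch: requesting 1 unit of the per-slice "head" resource schedules
# the task on the head worker of a v5litepod-16 slice started by the autoscaler.
# The resource name is assumed from the f"TPU-{tpu_type}-head" pattern used in
# ray_tpu.py later in this PR.
@ray.remote(resources={"TPU-v5litepod-16-head": 1})
def hello_from_slice_head() -> str:
    return "running on the head host of a v5litepod-16 slice"

# On a cluster launched from job-cluster.yaml:
# print(ray.get(hello_from_slice_head.remote()))
```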
3 changes: 2 additions & 1 deletion infra/launch_on_ray.py
@@ -27,7 +27,7 @@ def main():
cli.add_arg(parser, config, ["--project"], default=cli.gcloud_config()["project"])
cli.add_arg(parser, config, ["--tpu_type"], required=True)
# TODO: bring node_count to Ray
# cli.add_arg(parser, config, ["--node_count"], default=1, type=int)
cli.add_arg(parser, config, ["--node_count"], default=1, type=int)
cli.add_arg(parser, config, ["--foreground"], default=False, action="store_true")
cli.add_arg(parser, config, ["--retries"], default=10, type=int)
cli.add_arg(parser, config, ["--run_id"], default=cli.default_run_id(), type=str)
@@ -122,6 +122,7 @@ def main():
env=env,
name="levanter",
retries=retries,
node_count=args.node_count,
)

address = args.address or os.getenv("RAY_ADDRESS")
5 changes: 5 additions & 0 deletions src/levanter/infra/cli_helpers.py
@@ -76,6 +76,11 @@ def make_docker_run_command(image_id, command, *, foreground, env, name="levante
"/tmp:/tmp",
]

    # Forward the multislice (MXLA) env vars into the container if the Ray runtime env set them on the host.
for v in ["MEGASCALE_COORDINATOR_ADDRESS", "MEGASCALE_NUM_SLICES", "MEGASCALE_PORT", "MEGASCALE_SLICE_ID"]:
v = shlex.quote(str(v))
docker_command.extend(["-e", v])

for k, v in env.items():
v = shlex.quote(str(v))
k = shlex.quote(str(k))
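For context on the loop added above: `docker run -e NAME` with no `=value` tells Docker to take the variable's value from the launching environment, so the MEGASCALE settings only appear inside the container when Ray's runtime env actually set them on the host. A rough sketch (surrounding flags elided, values hypothetical) of what this appends to the command:

```python
# Rough sketch of the flags produced by the loop above, followed by one explicit
# KEY=VALUE entry from `env` (the explicit form is assumed for illustration).
docker_command_tail = [
    "-e", "MEGASCALE_COORDINATOR_ADDRESS",  # value pulled from the host env, if set
    "-e", "MEGASCALE_NUM_SLICES",
    "-e", "MEGASCALE_PORT",
    "-e", "MEGASCALE_SLICE_ID",
    "-e", "WANDB_PROJECT=levanter",         # example explicit env entry (assumed)
]
```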
188 changes: 182 additions & 6 deletions src/levanter/infra/ray_tpu.py
@@ -3,6 +3,7 @@
import logging
import multiprocessing
import os
import socket
import subprocess
import tempfile
import time
@@ -104,7 +105,83 @@ def do_run(remote_fn) -> _TpuRunResult:
return do_run.remote(remote_fn)


def _redecorate_remote_fn_for_tpu(remote_fn, num_hosts):
def run_on_pod_multislice(remote_fn: RemoteFunction | Callable, tpu_type: str, num_slices: int) -> list[ray.ObjectRef]:
"""
Run a remote function on multiple TPU slices.

Args:
remote_fn: A remote function that takes no arguments
tpu_type: The type of TPU to run on, e.g. "v4-32"
num_slices: The number of slices to run

    Returns:
        A list of Ray ObjectRefs, one per slice, each representing the result of the function on that slice
"""

@ray.remote(resources={f"TPU-{tpu_type}-head": 1})
class MultisliceActor:
def __init__(self):
self.pod_name = ray.util.accelerators.tpu.get_current_pod_name()
self.num_hosts = ray.util.accelerators.tpu.get_current_pod_worker_count()
self.ip = socket.gethostbyname(socket.gethostname())

def get_slice_info(self):
return self.pod_name, self.num_hosts, self.ip

def do_run(self, remote_fn, coordinator_ip, slice_id, num_slices) -> _TpuRunResult:
port = 8081
mxla_env = {
"MEGASCALE_COORDINATOR_ADDRESS": f"{coordinator_ip}:{port}",
"MEGASCALE_NUM_SLICES": str(num_slices),
"MEGASCALE_PORT": f"{port}",
"MEGASCALE_SLICE_ID": str(slice_id),
}

remote_fn, tpu_name = _redecorate_remote_fn_for_tpu(remote_fn, self.num_hosts, env_vars=mxla_env)

info = _TpuInfo(tpu_name, "ACTIVE", "TPU")
futures = [remote_fn.remote() for _ in range(self.num_hosts)]
try:
out = ray.get(futures)
logger.info("TPU job finished")
return TpuSuccess(info, out)
except RayError as e:
for f in futures:
try:
ray.cancel(f)
except Exception:
logger.exception("Failed to kill job after primary failure")
return _handle_ray_error(info, e)
except Exception as e:
for f in futures:
try:
ray.cancel(f)
except Exception:
logger.exception("Failed to kill job after primary failure")
return TpuFailed(info, e)

actors = [MultisliceActor.remote() for _ in range(num_slices)] # type: ignore
info = _TpuInfo("get_slice_info", "ACTIVE", "TPU")
futures = [actor.get_slice_info.remote() for actor in actors]
try:
logger.info("Getting slice infos...")
# also act as a sync step
slice_infos = ray.get(futures)
logger.info(f"TPU slice infos {slice_infos}")
except RayError as e:
for actor in actors:
try:
                ray.kill(actor)
except Exception:
logger.exception("Failed to kill actor after primary failure")
return [_handle_ray_error(info, e)]

coordinator_ip = slice_infos[0][2]

return [actor.do_run.remote(remote_fn, coordinator_ip, i, num_slices) for i, actor in enumerate(actors)]


def _redecorate_remote_fn_for_tpu(remote_fn, num_hosts, **runtime_env):
"""
Redecorate a remote function to run on a TPU pod.

@@ -120,7 +197,10 @@ def _redecorate_remote_fn_for_tpu(remote_fn, num_hosts):

tpu_name = ray.util.accelerators.tpu.get_current_pod_name() # -> my-tpu
num_tpus_per_host = TPUAcceleratorManager.get_current_node_num_accelerators() # -> 8
remote_fn = remote_fn.options(resources={tpu_name: 1, "TPU": num_tpus_per_host})
remote_fn = remote_fn.options(
runtime_env=runtime_env,
resources={tpu_name: 1, "TPU": num_tpus_per_host},
)
logger.info(f"Running on TPU {tpu_name} with {num_hosts} hosts and {num_tpus_per_host} TPUs per host")
return remote_fn, tpu_name
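To make the MXLA wiring above concrete, here is a small illustration (IP address and slice count hypothetical, not part of this diff) of the environment each slice ends up with in a two-slice job:

```python
# Hypothetical two-slice example: the IP of the first slice's actor is used as the
# coordinator for every slice, and only MEGASCALE_SLICE_ID differs between slices.
coordinator_ip = "10.128.0.5"  # assumed IP reported by slice 0
port = 8081

mxla_envs = [
    {
        "MEGASCALE_COORDINATOR_ADDRESS": f"{coordinator_ip}:{port}",
        "MEGASCALE_NUM_SLICES": "2",
        "MEGASCALE_PORT": str(port),
        "MEGASCALE_SLICE_ID": str(slice_id),
    }
    for slice_id in range(2)
]
# Every host within a slice receives its slice's dict via the Ray runtime_env
# (env_vars=...) applied in _redecorate_remote_fn_for_tpu.
```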

@@ -193,11 +273,96 @@ def run_on_pod_resumable(remote_fn, tpu_type, max_retries_preemption=1e6, max_re
raise RuntimeError("Failed too many times") from problem


def run_on_pod_multislice_resumable(
remote_fn, tpu_type, num_slices, max_retries_preemption=1e6, max_retries_failure=10
):
"""
    Repeatedly run a function on multiple TPU slices until it succeeds or a maximum number of retries is reached.

Args:
remote_fn: A remote function that takes no arguments
tpu_type: The type of TPU to run on, e.g. "v4-32"
num_slices: The number of slices to run
max_retries_preemption: The maximum number of times to retry if the job is preempted
max_retries_failure: The maximum number of times to retry if the job fails

Returns:
The result of the function (not an ObjectRef)

"""
num_failures = 0
num_preemptions = 0
attempt = 0
problem: Exception | None = None

while num_failures < max_retries_failure and num_preemptions < max_retries_preemption:
logger.info(f"Running on TPU {tpu_type}. Attempt {attempt}")
attempt += 1
problem = None
try:
outs = ray.get(run_on_pod_multislice(remote_fn, tpu_type, num_slices))
except ray.exceptions.RayTaskError as e:
problem = e
if "preempted" in str(e).lower():
num_preemptions += 1
logger.warning(f"Preempted {num_preemptions} times, {e}")
else:
num_failures += 1
logger.warning(f"Failed {num_failures} times", exc_info=e)
continue
except Exception as e:
problem = e
num_failures += 1
if num_failures >= max_retries_failure:
logger.exception("Failed too many times", exc_info=e)
raise e
else:
logger.warning(f"Failed {num_failures} times", exc_info=e)
continue

if all(isinstance(out, TpuSuccess) for out in outs):
results = [out.result for out in outs]
logger.info("Success")
return results
elif any(isinstance(out, TpuPreempted) for out in outs):
out = None
for o in outs:
if isinstance(o, TpuPreempted):
out = o
assert out is not None
problem = out.error
num_preemptions += 1
logger.warning(f"Preempted {num_preemptions} times. {problem}", exc_info=problem)
elif any(isinstance(out, TpuFailed) for out in outs):
num_preemptions += 1
logger.warning(f"TPU node failure. Treating as preempted: {num_preemptions} times")
elif any(isinstance(out, TpuRunError) for out in outs):
out = None
for o in outs:
if isinstance(o, TpuRunError):
out = o
assert out is not None
problem = out.error
num_failures += 1
logger.warning(f"Failed {num_failures} times", exc_info=problem)
else:
raise RuntimeError(f"Unexpected result: {out}")

if num_preemptions >= max_retries_preemption:
raise RuntimeError("Preempted too many times") from problem
elif num_failures >= max_retries_failure:
raise RuntimeError("Failed too many times") from problem


def _run_command(*args, **kwargs):
return subprocess.check_call(args, **kwargs)


def run_docker_on_pod(image_id: str, command: Sequence[str], *, tpu_type: str, env: dict, name="levanter", retries=10):
def run_docker_on_pod(
image_id: str, command: Sequence[str], *, tpu_type: str, num_slices: int, env: dict, name="levanter", retries=10
):
env = _massage_env(env)

docker_cmd = make_docker_run_command(image_id, command, env=env, foreground=True, name=name)
@@ -210,9 +375,18 @@ def run_docker():
logger.exception("Failed to run docker command")
raise e

run_on_pod_resumable(
ray.remote(run_docker), tpu_type=tpu_type, max_retries_failure=retries, max_retries_preemption=10000
)
if num_slices == 1:
run_on_pod_resumable(
ray.remote(run_docker), tpu_type=tpu_type, max_retries_failure=retries, max_retries_preemption=10000
)
else:
run_on_pod_multislice_resumable(
ray.remote(run_docker),
tpu_type=tpu_type,
num_slices=num_slices,
max_retries_failure=retries,
max_retries_preemption=10000,
)


def _kill_old_container(name):
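Putting the pieces together, a minimal usage sketch of the resumable multislice runner defined above (the payload function and all values are assumed placeholders; in this PR the real payload is the docker-run closure built by run_docker_on_pod):

```python
import ray

# Assumed toy payload: any no-argument remote function works; _redecorate_remote_fn_for_tpu
# re-targets it onto each slice's hosts with the MXLA env vars set.
@ray.remote
def toy_payload():
    return "slice done"

# Runs the payload on every host of two v5litepod-16 slices, retrying through
# preemptions; on success returns one per-host result list per slice.
# results = run_on_pod_multislice_resumable(
#     toy_payload, tpu_type="v5litepod-16", num_slices=2, max_retries_failure=3
# )
```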
@@ -351,6 +525,7 @@ class RunDockerOnPodConfig:
env: dict = dataclasses.field(default_factory=dict)
name: str = "levanter"
retries: int = 10
node_count: int = 1


def submit_tpu_job_on_ray(config: RunDockerOnPodConfig, ray_address: str, run_id: Optional[str] = None):
@@ -419,6 +594,7 @@ def main(args: RunDockerOnPodConfig):
tpu_type=args.tpu_type,
env=args.env,
name=args.name,
num_slices=args.node_count,
)

