diff --git a/.vale/styles/config/vocabularies/General/accept.txt b/.vale/styles/config/vocabularies/General/accept.txt index df704e70a2721..9ffc6965e08cd 100644 --- a/.vale/styles/config/vocabularies/General/accept.txt +++ b/.vale/styles/config/vocabularies/General/accept.txt @@ -17,3 +17,5 @@ GKE namespace ARM breakpoint +deduplicate[s] +deduplication diff --git a/BUILD.bazel b/BUILD.bazel index f023efc65e810..7db58a17a2a79 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -93,8 +93,10 @@ ray_cc_library( "src/ray/rpc/common.cc", "src/ray/rpc/grpc_server.cc", "src/ray/rpc/server_call.cc", + "src/ray/rpc/rpc_chaos.cc", ], hdrs = glob([ + "src/ray/rpc/rpc_chaos.h", "src/ray/rpc/client_call.h", "src/ray/rpc/common.h", "src/ray/rpc/grpc_client.h", @@ -514,6 +516,7 @@ ray_cc_library( "@boost//:bimap", "@com_github_grpc_grpc//src/proto/grpc/health/v1:health_proto", "@com_google_absl//absl/container:btree", + "//src/ray/util:thread_checker", ], ) @@ -1551,6 +1554,19 @@ ray_cc_test( ], ) +ray_cc_test( + name = "rpc_chaos_test", + size = "small", + srcs = [ + "src/ray/rpc/test/rpc_chaos_test.cc", + ], + tags = ["team:core"], + deps = [ + ":grpc_common_lib", + "@com_google_googletest//:gtest_main", + ], +) + ray_cc_test( name = "core_worker_client_pool_test", size = "small", diff --git a/doc/source/cluster/kubernetes/images/yunikorn-dashboard-apps-pending.png b/doc/source/cluster/kubernetes/images/yunikorn-dashboard-apps-pending.png new file mode 100644 index 0000000000000..e5b56b483af0f Binary files /dev/null and b/doc/source/cluster/kubernetes/images/yunikorn-dashboard-apps-pending.png differ diff --git a/doc/source/cluster/kubernetes/images/yunikorn-dashboard-apps-running.png b/doc/source/cluster/kubernetes/images/yunikorn-dashboard-apps-running.png new file mode 100644 index 0000000000000..11d626f2347c7 Binary files /dev/null and b/doc/source/cluster/kubernetes/images/yunikorn-dashboard-apps-running.png differ diff --git a/doc/source/cluster/kubernetes/k8s-ecosystem.md b/doc/source/cluster/kubernetes/k8s-ecosystem.md index 418deaa167fb0..7e7eb221b9512 100644 --- a/doc/source/cluster/kubernetes/k8s-ecosystem.md +++ b/doc/source/cluster/kubernetes/k8s-ecosystem.md @@ -9,6 +9,7 @@ k8s-ecosystem/ingress k8s-ecosystem/prometheus-grafana k8s-ecosystem/pyspy k8s-ecosystem/volcano +k8s-ecosystem/yunikorn k8s-ecosystem/kubeflow k8s-ecosystem/kueue k8s-ecosystem/istio @@ -18,6 +19,7 @@ k8s-ecosystem/istio * {ref}`kuberay-prometheus-grafana` * {ref}`kuberay-pyspy-integration` * {ref}`kuberay-volcano` +* {ref}`kuberay-yunikorn` * {ref}`kuberay-kubeflow-integration` * {ref}`kuberay-kueue` * {ref}`kuberay-istio` diff --git a/doc/source/cluster/kubernetes/k8s-ecosystem/yunikorn.md b/doc/source/cluster/kubernetes/k8s-ecosystem/yunikorn.md new file mode 100644 index 0000000000000..ad48ecc5d4759 --- /dev/null +++ b/doc/source/cluster/kubernetes/k8s-ecosystem/yunikorn.md @@ -0,0 +1,190 @@ +(kuberay-yunikorn)= + +# KubeRay integration with Apache YuniKorn + +[Apache YuniKorn](https://yunikorn.apache.org/) is a light-weight, universal resource scheduler for container orchestrator systems. It performs fine-grained resource sharing for various workloads efficiently on a large scale, multi-tenant, and cloud-native environment. YuniKorn brings a unified, cross-platform, scheduling experience for mixed workloads that consist of stateless batch workloads and stateful services. + +KubeRay's Apache YuniKorn integration enables more efficient scheduling of Ray Pods in multi-tenant Kubernetes environments. + +:::{note} + +This feature requires KubeRay version 1.2.2 or newer, and it's in alpha testing. + +::: + +## Step 1: Create a Kubernetes cluster with KinD +Run the following command in a terminal: + +```shell +kind create cluster +``` + +## Step 2: Install Apache YuniKorn + +You need to successfully install Apache YuniKorn on your Kubernetes cluster before enabling Apache YuniKorn integration with KubeRay. +See [Get Started](https://yunikorn.apache.org/docs/) for Apache YuniKorn installation instructions. + +## Step 3: Install the KubeRay operator with Apache YuniKorn support + +When installing KubeRay operator using Helm, pass the `--set batchScheduler.name=yunikorn` flag at the command line: + +```shell +helm install kuberay-operator kuberay/kuberay-operator --version 1.2.2 --set batchScheduler.name=yunikorn +``` + +## Step 4: Use Apache YuniKorn for gang scheduling + +This example uses gang scheduling with Apache YuniKorn and KubeRay. + +First, create a queue with a capacity of 4 CPUs and 6Gi of RAM by editing the ConfigMap: + +Run `kubectl edit configmap -n yunikorn yunikorn-defaults` + +Helm creates this ConfigMap during the installation of the Apache YuniKorn Helm chart. + +Add a `queues.yaml` config under the `data` key. The `ConfigMap` should look like the following: + +```yaml +apiVersion: v1 +kind: ConfigMap +metadata: + # Metadata for the ConfigMap, skip for brevity. +data: + queues.yaml: | + partitions: + - name: default + queues: + - name: root + queues: + - name: test + submitacl: "*" + parent: false + resources: + guaranteed: + memory: 6G + vcore: 4 + max: + memory: 6G + vcore: 4 +``` + +Save the changes and exit the editor. This configuration creates a queue named `root.test` with a capacity of 4 CPUs and 6Gi of RAM. + +Next, create a RayCluster with a head node with 1 CPU and 2GiB of RAM, and two workers with 1 CPU and 1GiB of RAM each, for a total of 3 CPU and 4GiB of RAM: + +```shell +# Path: kuberay/ray-operator/config/samples +# Configure the necessary labels on the RayCluster custom resource for Apache YuniKorn scheduler's gang scheduling: +# - `ray.io/gang-scheduling-enabled`: Set to `true` to enable gang scheduling. +# - `yunikorn.apache.org/app-id`: Set to a unique identifier for the application in Kubernetes, even across different namespaces. +# - `yunikorn.apache.org/queue`: Set to the name of one of the queues in Apache YuniKorn. +wget https://raw.githubusercontent.com/ray-project/kuberay/master/ray-operator/config/samples/ray-cluster.yunikorn-scheduler.yaml +kubectl apply -f ray-cluster.yunikorn-scheduler.yaml +``` + +Check the RayCluster that the KubeRay operator created: + +```shell +$ kubectl describe raycluster test-yunikorn-0 + +Name: test-yunikorn-0 +Namespace: default +Labels: ray.io/gang-scheduling-enabled=true + yunikorn.apache.org/app-id=test-yunikorn-0 + yunikorn.apache.org/queue=root.test +Annotations: +API Version: ray.io/v1 +Kind: RayCluster +Metadata: + Creation Timestamp: 2024-09-29T09:52:30Z + Generation: 1 + Resource Version: 951 + UID: cae1dbc9-5a67-4b43-b0d9-be595f21ab85 +# Other fields are skipped for brevity +```` + +Note the labels on the RayCluster: `ray.io/gang-scheduling-enabled=true`, `yunikorn.apache.org/app-id=test-yunikorn-0`, and `yunikorn.apache.org/queue=root.test`. + +:::{note} + +You only need the `ray.io/gang-scheduling-enabled` label when you require gang scheduling. If you don't set this label, YuniKorn schedules the Ray cluster without enforcing gang scheduling. + +::: + +Because the queue has a capacity of 4 CPU and 6GiB of RAM, this resource should schedule successfully without any issues. + +```shell +$ kubectl get pods + +NAME READY STATUS RESTARTS AGE +test-yunikorn-0-head-98fmp 1/1 Running 0 67s +test-yunikorn-0-worker-worker-42tgg 1/1 Running 0 67s +test-yunikorn-0-worker-worker-467mn 1/1 Running 0 67s +``` + +Verify the scheduling by checking the [Apache YuniKorn dashboard](https://yunikorn.apache.org/docs/#access-the-web-ui). + +```shell +kubectl port-forward svc/yunikorn-service 9889:9889 -n yunikorn +``` + +Go to `http://localhost:9889/#/applications` to see the running apps. + +![Apache YuniKorn dashboard](../images/yunikorn-dashboard-apps-running.png) + +Next, add an additional RayCluster with the same configuration of head and worker nodes, but with a different name: + +```shell +# Replace the name with `test-yunikorn-1`. +sed 's/test-yunikorn-0/test-yunikorn-1/' ray-cluster.yunikorn-scheduler.yaml | kubectl apply -f- +``` + +Now all the Pods for `test-yunikorn-1` are in the `Pending` state: + +```shell +$ kubectl get pods + +NAME READY STATUS RESTARTS AGE +test-yunikorn-0-head-98fmp 1/1 Running 0 4m22s +test-yunikorn-0-worker-worker-42tgg 1/1 Running 0 4m22s +test-yunikorn-0-worker-worker-467mn 1/1 Running 0 4m22s +test-yunikorn-1-head-xl2r5 0/1 Pending 0 71s +test-yunikorn-1-worker-worker-l6ttz 0/1 Pending 0 71s +test-yunikorn-1-worker-worker-vjsts 0/1 Pending 0 71s +tg-test-yunikorn-1-headgroup-vgzvoot0dh 0/1 Pending 0 69s +tg-test-yunikorn-1-worker-eyti2bn2jv 1/1 Running 0 69s +tg-test-yunikorn-1-worker-k8it0x6s73 0/1 Pending 0 69s +``` + +Apache YuniKorn creates the Pods with the `tg-` prefix for gang scheduling purpose. + +Go to `http://localhost:9889/#/applications` and to see `test-yunikorn-1` in the `Accepted` state but not running yet: + +![Apache YuniKorn dashboard](../images/yunikorn-dashboard-apps-pending.png) + +Because the new cluster requires more CPU and RAM than the queue allows, even though one of the Pods would fit in the remaining 1 CPU and 2GiB of RAM, Apache YuniKorn doesn't place the cluster's Pods until there's enough room for all of the Pods. Without using Apache YuniKorn for gang scheduling in this way, KubeRay would place one of the Pods, and only partially allocating the cluster. + +Delete the first RayCluster to free up resources in the queue: + +```shell +kubectl delete raycluster test-yunikorn-0 +``` + +Now all the Pods for the second cluster change to the `Running` state, because enough resources are now available to schedule the entire set of Pods: + +Check the Pods again to see that the second cluster is now up and running: + +```shell +$ kubectl get pods + +NAME READY STATUS RESTARTS AGE +test-yunikorn-1-head-xl2r5 1/1 Running 0 3m34s +test-yunikorn-1-worker-worker-l6ttz 1/1 Running 0 3m34s +test-yunikorn-1-worker-worker-vjsts 1/1 Running 0 3m34s +``` + +Clean up the resources: + +```shell +kubectl delete raycluster test-yunikorn-1 +``` diff --git a/doc/source/ray-contribute/development.rst b/doc/source/ray-contribute/development.rst index fafa770675f69..1e699d2133cb4 100644 --- a/doc/source/ray-contribute/development.rst +++ b/doc/source/ray-contribute/development.rst @@ -92,8 +92,8 @@ RLlib, Tune, Autoscaler, and most Python files do not require you to build and c .. code-block:: shell - # For example, for Python 3.8: - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-3.0.0.dev0-cp38-cp38-manylinux2014_x86_64.whl + # For example, for Python 3.9: + pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-3.0.0.dev0-cp39-cp39-manylinux2014_x86_64.whl 4. Replace Python files in the installed package with your local editable copy. We provide a simple script to help you do this: ``python python/ray/setup-dev.py``. Running the script will remove the ``ray/tune``, ``ray/rllib``, ``ray/autoscaler`` dir (among other directories) bundled with the ``ray`` pip package, and replace them with links to your local code. This way, changing files in your git clone will directly affect the behavior of your installed Ray. diff --git a/doc/source/ray-core/handling-dependencies.rst b/doc/source/ray-core/handling-dependencies.rst index 8a2ac81edfc2b..8dfc883232fdf 100644 --- a/doc/source/ray-core/handling-dependencies.rst +++ b/doc/source/ray-core/handling-dependencies.rst @@ -37,7 +37,7 @@ Concepts Preparing an environment using the Ray Cluster launcher ------------------------------------------------------- -The first way to set up dependencies is to is to prepare a single environment across the cluster before starting the Ray runtime. +The first way to set up dependencies is to prepare a single environment across the cluster before starting the Ray runtime. - You can build all your files and dependencies into a container image and specify this in your your :ref:`Cluster YAML Configuration `. @@ -327,9 +327,7 @@ To ensure your local changes show up across all Ray workers and can be imported # No need to import my_module inside this function. my_module.test() - ray.get(f.remote()) - -Note: This feature is currently limited to modules that are packages with a single directory containing an ``__init__.py`` file. For single-file modules, you may use ``working_dir``. + ray.get(test_my_module.remote()) .. _runtime-environments-api-ref: @@ -358,13 +356,15 @@ The ``runtime_env`` is a Python dictionary or a Python class :class:`ray.runtime Note: If the local directory contains symbolic links, Ray follows the links and the files they point to are uploaded to the cluster. - ``py_modules`` (List[str|module]): Specifies Python modules to be available for import in the Ray workers. (For more ways to specify packages, see also the ``pip`` and ``conda`` fields below.) - Each entry must be either (1) a path to a local directory, (2) a URI to a remote zip or wheel file (see :ref:`remote-uris` for details), (3) a Python module object, or (4) a path to a local `.whl` file. + Each entry must be either (1) a path to a local file or directory, (2) a URI to a remote zip or wheel file (see :ref:`remote-uris` for details), (3) a Python module object, or (4) a path to a local `.whl` file. - Examples of entries in the list: - ``"."`` - - ``"/local_dependency/my_module"`` + - ``"/local_dependency/my_dir_module"`` + + - ``"/local_dependency/my_file_module.py"`` - ``"s3://bucket/my_module.zip"`` @@ -380,8 +380,6 @@ The ``runtime_env`` is a Python dictionary or a Python class :class:`ray.runtime Note: For option (1), if the local directory contains a ``.gitignore`` file, the files and paths specified there are not uploaded to the cluster. You can disable this by setting the environment variable `RAY_RUNTIME_ENV_IGNORE_GITIGNORE=1` on the machine doing the uploading. - Note: This feature is currently limited to modules that are packages with a single directory containing an ``__init__.py`` file. For single-file modules, you may use ``working_dir``. - - ``excludes`` (List[str]): When used with ``working_dir`` or ``py_modules``, specifies a list of files or paths to exclude from being uploaded to the cluster. This field uses the pattern-matching syntax used by ``.gitignore`` files: see ``_ for details. Note: In accordance with ``.gitignore`` syntax, if there is a separator (``/``) at the beginning or middle (or both) of the pattern, then the pattern is interpreted relative to the level of the ``working_dir``. diff --git a/doc/source/ray-observability/user-guides/configure-logging.md b/doc/source/ray-observability/user-guides/configure-logging.md index 73c6102ab46ad..3be1af34cbffb 100644 --- a/doc/source/ray-observability/user-guides/configure-logging.md +++ b/doc/source/ray-observability/user-guides/configure-logging.md @@ -62,7 +62,7 @@ System logs may include information about your applications. For example, ``runt This is the log file of the agent containing logs of create or delete requests and cache hits and misses. For the logs of the actual installations (for example, ``pip install`` logs), see the ``runtime_env_setup-[job_id].log`` file (see below). - ``runtime_env_setup-ray_client_server_[port].log``: Logs from installing {ref}`Runtime Environments ` for a job when connecting with {ref}`Ray Client `. -- ``runtime_env_setup-[job_id].log``: Logs from installing {ref}`Runtime Environments ` for a Task, Actor or Job. This file is only present if a Runtime Environment is installed. +- ``runtime_env_setup-[job_id].log``: Logs from installing {ref}`runtime environments ` for a Task, Actor, or Job. This file is only present if you install a runtime environment. (log-redirection-to-driver)= @@ -136,13 +136,58 @@ The output is as follows: (task pid=534174) Hello there, I am a task 0.17536720316370757 [repeated 99x across cluster] (Ray deduplicates logs by default. Set RAY_DEDUP_LOGS=0 to disable log deduplication) ``` -This feature is especially useful when importing libraries such as `tensorflow` or `numpy`, which may emit many verbose warning messages when imported. Configure this feature as follows: - -1. Set ``RAY_DEDUP_LOGS=0`` to disable this feature entirely. -2. Set ``RAY_DEDUP_LOGS_AGG_WINDOW_S=`` to change the agggregation window. -3. Set ``RAY_DEDUP_LOGS_ALLOW_REGEX=`` to specify log messages to never deduplicate. -4. Set ``RAY_DEDUP_LOGS_SKIP_REGEX=`` to specify log messages to skip printing. - +This feature is useful when importing libraries such as `tensorflow` or `numpy`, which may emit many verbose warning messages when you import them. + +Configure the following environment variables on the driver process **before importing Ray** to customize log deduplication: + +* Set ``RAY_DEDUP_LOGS=0`` to turn off this feature entirely. +* Set ``RAY_DEDUP_LOGS_AGG_WINDOW_S=`` to change the aggregation window. +* Set ``RAY_DEDUP_LOGS_ALLOW_REGEX=`` to specify log messages to never deduplicate. + * Example: + ```python + import os + os.environ["RAY_DEDUP_LOGS_ALLOW_REGEX"] = "ABC" + + import ray + + @ray.remote + def f(): + print("ABC") + print("DEF") + + ray.init() + ray.get([f.remote() for _ in range(5)]) + + # 2024-10-10 17:54:19,095 INFO worker.py:1614 -- Connecting to existing Ray cluster at address: 172.31.13.10:6379... + # 2024-10-10 17:54:19,102 INFO worker.py:1790 -- Connected to Ray cluster. View the dashboard at 127.0.0.1:8265 + # (f pid=1574323) ABC + # (f pid=1574323) DEF + # (f pid=1574321) ABC + # (f pid=1574318) ABC + # (f pid=1574320) ABC + # (f pid=1574322) ABC + # (f pid=1574322) DEF [repeated 4x across cluster] (Ray deduplicates logs by default. Set RAY_DEDUP_LOGS=0 to disable log deduplication, or see https://docs.ray.io/en/master/ray-observability/user-guides/configure-logging.html#log-deduplication for more options.) + ``` +* Set ``RAY_DEDUP_LOGS_SKIP_REGEX=`` to specify log messages to skip printing. + * Example: + ```python + import os + os.environ["RAY_DEDUP_LOGS_SKIP_REGEX"] = "ABC" + + import ray + + @ray.remote + def f(): + print("ABC") + print("DEF") + + ray.init() + ray.get([f.remote() for _ in range(5)]) + # 2024-10-10 17:55:05,308 INFO worker.py:1614 -- Connecting to existing Ray cluster at address: 172.31.13.10:6379... + # 2024-10-10 17:55:05,314 INFO worker.py:1790 -- Connected to Ray cluster. View the dashboard at 127.0.0.1:8265 + # (f pid=1574317) DEF + # (f pid=1575229) DEF [repeated 4x across cluster] (Ray deduplicates logs by default. Set RAY_DEDUP_LOGS=0 to disable log deduplication, or see https://docs.ray.io/en/master/ray-observability/user-guides/configure-logging.html#log-deduplication for more options.) + ``` ## Distributed progress bars (tqdm) diff --git a/doc/source/rllib/rllib-examples.rst b/doc/source/rllib/rllib-examples.rst index 6457ebd171871..5a2c4dca69f63 100644 --- a/doc/source/rllib/rllib-examples.rst +++ b/doc/source/rllib/rllib-examples.rst @@ -254,12 +254,8 @@ RLModules - |old_stack| `How to using the "Repeated" space of RLlib for variable lengths observations `__: How to use RLlib's `Repeated` space to handle variable length observations. - |old_stack| `How to write a custom Keras model `__: - Example of using a custom Keras model. -- |old_stack| `How to register a custom model with supervised loss `__: Example of defining and registering a custom model with a supervised loss. - |old_stack| `How to train with batch normalization `__: - Example of adding batch norm layers to a custom model. -- |old_stack| `How to write a custom model with its custom API `__: Shows how to define a custom Model API in RLlib, such that it can be used inside certain algorithms. - |old_stack| `How to write a "trajectory ciew API" utilizing model `__: An example on how a model can use the trajectory view API to specify its own input. diff --git a/doc/source/rllib/rllib-models.rst b/doc/source/rllib/rllib-models.rst index 5e3badd3b8e3f..717c6bb196c6e 100644 --- a/doc/source/rllib/rllib-models.rst +++ b/doc/source/rllib/rllib-models.rst @@ -364,59 +364,7 @@ calculating head on top of your policy model. In order to expand a Model's API, define and implement a new method (e.g. ``get_q_values()``) in your TF- or TorchModelV2 sub-class. You can now wrap this new API either around RLlib's default models or around -your custom (``forward()``-overriding) model classes. Here are two examples that illustrate how to do this: - -**The Q-head API: Adding a dueling layer on top of a default RLlib model**. - -The following code adds a ``get_q_values()`` method to the automatically chosen -default Model (e.g. a ``FullyConnectedNetwork`` if the observation space is a 1D Box -or Discrete): - -.. literalinclude:: ../../../rllib/examples/_old_api_stack/models/custom_model_api.py - :language: python - :start-after: __sphinx_doc_model_api_1_begin__ - :end-before: __sphinx_doc_model_api_1_end__ - -Now, for your algorithm that needs to have this model API to work properly (e.g. DQN), -you use this following code to construct the complete final Model using the -``ModelCatalog.get_model_v2`` factory function (`code here `__): - -.. literalinclude:: ../../../rllib/examples/custom_model_api.py - :language: python - :start-after: __sphinx_doc_model_construct_1_begin__ - :end-before: __sphinx_doc_model_construct_1_end__ - -With the model object constructed above, you can get the underlying intermediate output (before the dueling head) -by calling ``my_dueling_model`` directly (``out = my_dueling_model([input_dict])``), and then passing ``out`` into -your custom ``get_q_values`` method: ``q_values = my_dueling_model.get_q_values(out)``. - - -**The single Q-value API for SAC**. - -Our DQN model from above takes an observation and outputs one Q-value per (discrete) action. -Continuous SAC - on the other hand - uses Models that calculate one Q-value only -for a single (**continuous**) action, given an observation and that particular action. - -Let's take a look at how we would construct this API and wrap it around a custom model: - -.. literalinclude:: ../../../rllib/examples/_old_api_stack/models/custom_model_api.py - :language: python - :start-after: __sphinx_doc_model_api_2_begin__ - :end-before: __sphinx_doc_model_api_2_end__ - -Now, for your algorithm that needs to have this model API to work properly (e.g. SAC), -you use this following code to construct the complete final Model using the -``ModelCatalog.get_model_v2`` factory function (`code here `__): - -.. literalinclude:: ../../../rllib/examples/custom_model_api.py - :language: python - :start-after: __sphinx_doc_model_construct_2_begin__ - :end-before: __sphinx_doc_model_construct_2_end__ - -With the model object constructed above, you can get the underlying intermediate output (before the q-head) -by calling ``my_cont_action_q_model`` directly (``out = my_cont_action_q_model([input_dict])``), and then passing ``out`` -and some action into your custom ``get_single_q_value`` method: -``q_value = my_cont_action_q_model.get_signle_q_value(out, action)``. +your custom (``forward()``-overriding) model classes. More examples for Building Custom Models @@ -505,7 +453,7 @@ Supervised Model Losses You can mix supervised losses into any RLlib algorithm through custom models. For example, you can add an imitation learning loss on expert experiences, or a self-supervised autoencoder loss within the model. These losses can be defined over either policy evaluation inputs, or data read from `offline storage `__. -**TensorFlow**: To add a supervised loss to a custom TF model, you need to override the ``custom_loss()`` method. This method takes in the existing policy loss for the algorithm, which you can add your own supervised loss to before returning. For debugging, you can also return a dictionary of scalar tensors in the ``metrics()`` method. Here is a `runnable example `__ of adding an imitation loss to CartPole training that is defined over a `offline dataset `__. +**TensorFlow**: To add a supervised loss to a custom TF model, you need to override the ``custom_loss()`` method. This method takes in the existing policy loss for the algorithm, which you can add your own supervised loss to before returning. For debugging, you can also return a dictionary of scalar tensors in the ``metrics()`` method. **PyTorch**: There is no explicit API for adding losses to custom torch models. However, you can modify the loss in the policy definition directly. Like for TF models, offline datasets can be incorporated by creating an input reader and calling ``reader.next()`` in the loss forward pass. diff --git a/python/ray/_private/ray_logging/logging_config.py b/python/ray/_private/ray_logging/logging_config.py index b944625aa8810..6788571f9e6aa 100644 --- a/python/ray/_private/ray_logging/logging_config.py +++ b/python/ray/_private/ray_logging/logging_config.py @@ -56,6 +56,13 @@ def __init__(self): "level": log_level, "handlers": ["console"], }, + "loggers": { + "ray": { + "level": log_level, + "handlers": ["console"], + "propagate": False, + } + }, } } diff --git a/python/ray/_private/runtime_env/packaging.py b/python/ray/_private/runtime_env/packaging.py index 2766e6a7acab9..dd9212556c47d 100644 --- a/python/ray/_private/runtime_env/packaging.py +++ b/python/ray/_private/runtime_env/packaging.py @@ -135,6 +135,56 @@ def _dir_travel( excludes.pop() +def _hash_file_content_or_directory_name( + filepath: Path, + relative_path: Path, + logger: Optional[logging.Logger] = default_logger, +) -> bytes: + """Helper function to create hash of a single file or directory. + + This function hashes the path of the file or directory, + and if it's a file, then it hashes its content too. + """ + + BUF_SIZE = 4096 * 1024 + + sha1 = hashlib.sha1() + sha1.update(str(filepath.relative_to(relative_path)).encode()) + if not filepath.is_dir(): + try: + f = filepath.open("rb") + except Exception as e: + logger.debug( + f"Skipping contents of file {filepath} when calculating package hash " + f"because the file couldn't be opened: {e}" + ) + else: + try: + data = f.read(BUF_SIZE) + while len(data) != 0: + sha1.update(data) + data = f.read(BUF_SIZE) + finally: + f.close() + + return sha1.digest() + + +def _hash_file( + filepath: Path, + relative_path: Path, + logger: Optional[logging.Logger] = default_logger, +) -> bytes: + """Helper function to create hash of a single file. + + It hashes the path of the file and its content to create a hash value. + """ + file_hash = _hash_file_content_or_directory_name( + filepath, relative_path, logger=logger + ) + return _xor_bytes(file_hash, b"0" * 8) + + def _hash_directory( root: Path, relative_path: Path, @@ -147,30 +197,13 @@ def _hash_directory( hash(file_name, file_content) to create a hash value. """ hash_val = b"0" * 8 - BUF_SIZE = 4096 * 1024 def handler(path: Path): - sha1 = hashlib.sha1() - sha1.update(str(path.relative_to(relative_path)).encode()) - if not path.is_dir(): - try: - f = path.open("rb") - except Exception as e: - logger.debug( - f"Skipping contents of file {path} when calculating package hash " - f"because the file could not be opened: {e}" - ) - else: - try: - data = f.read(BUF_SIZE) - while len(data) != 0: - sha1.update(data) - data = f.read(BUF_SIZE) - finally: - f.close() - + file_hash = _hash_file_content_or_directory_name( + path, relative_path, logger=logger + ) nonlocal hash_val - hash_val = _xor_bytes(hash_val, sha1.digest()) + hash_val = _xor_bytes(hash_val, file_hash) excludes = [] if excludes is None else [excludes] _dir_travel(root, excludes, handler, logger=logger) @@ -378,16 +411,16 @@ def _get_local_path(base_directory: str, pkg_uri: str) -> str: return os.path.join(base_directory, pkg_name) -def _zip_directory( - directory: str, +def _zip_files( + path_str: str, excludes: List[str], output_path: str, include_parent_dir: bool = False, logger: Optional[logging.Logger] = default_logger, ) -> None: - """Zip the target directory and write it to the output_path. + """Zip the target file or directory and write it to the output_path. - directory: The directory to zip. + path_str: The file or directory to zip. excludes (List(str)): The directories or file to be excluded. output_path: The output path for the zip file. include_parent_dir: If true, includes the top-level directory as a @@ -396,7 +429,10 @@ def _zip_directory( pkg_file = Path(output_path).absolute() with ZipFile(pkg_file, "w", strict_timestamps=False) as zip_handler: # Put all files in the directory into the zip file. - dir_path = Path(directory).absolute() + file_path = Path(path_str).absolute() + dir_path = file_path + if file_path.is_file(): + dir_path = file_path.parent def handler(path: Path): # Pack this path if it's an empty directory or it's a file. @@ -415,8 +451,8 @@ def handler(path: Path): to_path = dir_path.name / to_path zip_handler.write(path, to_path) - excludes = [_get_excludes(dir_path, excludes)] - _dir_travel(dir_path, excludes, handler, logger=logger) + excludes = [_get_excludes(file_path, excludes)] + _dir_travel(file_path, excludes, handler, logger=logger) def package_exists(pkg_uri: str) -> bool: @@ -451,14 +487,47 @@ def get_uri_for_package(package: Path) -> str: ) +def get_uri_for_file(file: str) -> str: + """Get a content-addressable URI from a file's content. + + This function generates the name of the package by the file. + The final package name is _ray_pkg_.zip of this package, + where HASH_VAL is the hash value of the file. + For example: _ray_pkg_029f88d5ecc55e1e4d64fc6e388fd103.zip + + Examples: + + >>> get_uri_for_file("/my_file.py") # doctest: +SKIP + _ray_pkg_af2734982a741.zip + + Args: + file: The file. + + Returns: + URI (str) + + Raises: + ValueError if the file doesn't exist. + """ + filepath = Path(file).absolute() + if not filepath.exists() or not filepath.is_file(): + raise ValueError(f"File {filepath} must be an existing file") + + hash_val = _hash_file(filepath, filepath.parent) + + return "{protocol}://{pkg_name}.zip".format( + protocol=Protocol.GCS.value, pkg_name=RAY_PKG_PREFIX + hash_val.hex() + ) + + def get_uri_for_directory(directory: str, excludes: Optional[List[str]] = None) -> str: """Get a content-addressable URI from a directory's contents. - This function will generate the name of the package by the directory. + This function generates the name of the package by the directory. It'll go through all the files in the directory and hash the contents of the files to get the hash value of the package. - The final package name is: _ray_pkg_.zip of this package. - e.g., _ray_pkg_029f88d5ecc55e1e4d64fc6e388fd103.zip + The final package name is _ray_pkg_.zip of this package. + For example: _ray_pkg_029f88d5ecc55e1e4d64fc6e388fd103.zip Examples: @@ -515,7 +584,7 @@ def upload_package_to_gcs(pkg_uri: str, pkg_bytes: bytes) -> None: def create_package( - directory: str, + module_path: str, target_path: Path, include_parent_dir: bool = False, excludes: Optional[List[str]] = None, @@ -528,11 +597,11 @@ def create_package( logger = default_logger if not target_path.exists(): - logger.info(f"Creating a file package for local directory '{directory}'.") - _zip_directory( - directory, + logger.info(f"Creating a file package for local module '{module_path}'.") + _zip_files( + module_path, excludes, - target_path, + str(target_path), include_parent_dir=include_parent_dir, logger=logger, ) @@ -541,7 +610,7 @@ def create_package( def upload_package_if_needed( pkg_uri: str, base_directory: str, - directory: str, + module_path: str, include_parent_dir: bool = False, excludes: Optional[List[str]] = None, logger: Optional[logging.Logger] = default_logger, @@ -556,7 +625,7 @@ def upload_package_if_needed( Args: pkg_uri: URI of the package to upload. base_directory: Directory where package files are stored. - directory: Directory to be uploaded. + module_path: The module to be uploaded, either a single .py file or a directory. include_parent_dir: If true, includes the top-level directory as a directory inside the zip file. excludes: List specifying files to exclude. @@ -586,7 +655,7 @@ def upload_package_if_needed( f"{time.time_ns()}_{os.getpid()}_{package_file.name}" ) create_package( - directory, + module_path, package_file, include_parent_dir=include_parent_dir, excludes=excludes, diff --git a/python/ray/_private/runtime_env/py_modules.py b/python/ray/_private/runtime_env/py_modules.py index 551e22b3650a8..1066cbe6126df 100644 --- a/python/ray/_private/runtime_env/py_modules.py +++ b/python/ray/_private/runtime_env/py_modules.py @@ -10,6 +10,7 @@ delete_package, download_and_unpack_package, get_local_dir_from_uri, + get_uri_for_file, get_uri_for_directory, get_uri_for_package, install_wheel_package, @@ -71,15 +72,20 @@ def upload_py_modules_if_needed( elif isinstance(module, Path): module_path = str(module) elif isinstance(module, ModuleType): - # NOTE(edoakes): Python allows some installed Python packages to - # be split into multiple directories. We could probably handle - # this, but it seems tricky & uncommon. If it's a problem for - # users, we can add this support on demand. - if len(module.__path__) > 1: - raise ValueError( - "py_modules only supports modules whose __path__ has length 1." - ) - [module_path] = module.__path__ + if not hasattr(module, "__path__"): + # This is a single-file module. + module_path = module.__file__ + else: + # NOTE(edoakes): Python allows some installed Python packages to + # be split into multiple directories. We could probably handle + # this, but it seems tricky & uncommon. If it's a problem for + # users, we can add this support on demand. + if len(module.__path__) > 1: + raise ValueError( + "py_modules only supports modules whose __path__" + " has length 1 or those who are single-file." + ) + [module_path] = module.__path__ else: raise TypeError( "py_modules must be a list of file paths, URIs, " @@ -90,9 +96,13 @@ def upload_py_modules_if_needed( module_uri = module_path else: # module_path is a local path. - if Path(module_path).is_dir(): + if Path(module_path).is_dir() or Path(module_path).suffix == ".py": + is_dir = Path(module_path).is_dir() excludes = runtime_env.get("excludes", None) - module_uri = get_uri_for_directory(module_path, excludes=excludes) + if is_dir: + module_uri = get_uri_for_directory(module_path, excludes=excludes) + else: + module_uri = get_uri_for_file(module_path) if upload_fn is None: try: upload_package_if_needed( @@ -100,7 +110,7 @@ def upload_py_modules_if_needed( scratch_dir, module_path, excludes=excludes, - include_parent_dir=True, + include_parent_dir=is_dir, logger=logger, ) except Exception as e: @@ -136,7 +146,8 @@ def upload_py_modules_if_needed( upload_fn(module_path, excludes=None, is_file=True) else: raise ValueError( - "py_modules entry must be a directory or a .whl file; " + "py_modules entry must be a .py file, " + "a directory, or a .whl file; " f"got {module_path}" ) diff --git a/python/ray/_private/worker.py b/python/ray/_private/worker.py index e51315e70624f..0abfb5757692c 100644 --- a/python/ray/_private/worker.py +++ b/python/ray/_private/worker.py @@ -485,11 +485,17 @@ def __init__(self): # Cache the job id from initialize_job_config() to optimize lookups. # This is on the critical path of ray.get()/put() calls. self._cached_job_id = None + # Indicates whether the worker is connected to the Ray cluster. + # It should be set to True in `connect` and False in `disconnect`. + self._is_connected: bool = False @property def connected(self): """bool: True if Ray has been started and False otherwise.""" - return self.node is not None + return self._is_connected + + def set_is_connected(self, is_connected: bool): + self._is_connected = is_connected @property def node_ip_address(self): @@ -567,6 +573,17 @@ def debugger_port(self): worker_id = self.core_worker.get_worker_id() return ray._private.state.get_worker_debugger_port(worker_id) + @property + def job_logging_config(self): + """Get the job's logging config for this worker""" + if not hasattr(self, "core_worker"): + return None + job_config = self.core_worker.get_job_config() + if not job_config.serialized_py_logging_config: + return None + logging_config = pickle.loads(job_config.serialized_py_logging_config) + return logging_config + def set_debugger_port(self, port): worker_id = self.core_worker.get_worker_id() ray._private.state.update_worker_debugger_port(worker_id, port) @@ -1921,7 +1938,7 @@ def custom_excepthook(type, value, tb): sys.excepthook = custom_excepthook -def print_to_stdstream(data): +def print_to_stdstream(data, ignore_prefix: bool): should_dedup = data.get("pid") not in ["autoscaler"] if data["is_err"]: @@ -1938,7 +1955,7 @@ def print_to_stdstream(data): sink = sys.stdout for batch in batches: - print_worker_logs(batch, sink) + print_worker_logs(batch, sink, ignore_prefix) # Start time of this process, used for relative time logs. @@ -2029,7 +2046,9 @@ def time_string() -> str: _worker_logs_enabled = True -def print_worker_logs(data: Dict[str, str], print_file: Any): +def print_worker_logs( + data: Dict[str, str], print_file: Any, ignore_prefix: bool = False +): if not _worker_logs_enabled: return @@ -2109,11 +2128,19 @@ def color_for(data: Dict[str, str], line: str) -> str: else: color_pre = color_for(data, line) color_post = colorama.Style.RESET_ALL - print( - f"{color_pre}({prefix_for(data)}{pid}{ip_prefix}){color_post} " - f"{message_for(data, line)}", - file=print_file, - ) + + if ignore_prefix: + print( + f"{message_for(data, line)}", + file=print_file, + ) + else: + print( + f"{color_pre}({prefix_for(data)}{pid}{ip_prefix}){color_post} " + f"{message_for(data, line)}", + file=print_file, + ) + # Restore once at end of batch to avoid excess hiding/unhiding of tqdm. restore_tqdm() @@ -2163,7 +2190,6 @@ def listen_error_messages(worker, threads_stopped): error_message = _internal_kv_get(ray_constants.DEBUG_AUTOSCALING_ERROR) if error_message is not None: logger.warning(error_message.decode()) - while True: # Exit if received a signal that the thread should stop. if threads_stopped.is_set(): @@ -2184,7 +2210,8 @@ def listen_error_messages(worker, threads_stopped): "lines": [error_message], "pid": "raylet", "is_err": False, - } + }, + ignore_prefix=False, ) except (OSError, ConnectionError) as e: logger.error(f"listen_error_messages: {e}") @@ -2462,9 +2489,14 @@ def connect( ) worker.listener_thread.daemon = True worker.listener_thread.start() + # If the job's logging config is set, don't add the prefix + # (task/actor's name and its PID) to the logs. + ignore_prefix = global_worker.job_logging_config is not None + if log_to_driver: global_worker_stdstream_dispatcher.add_handler( - "ray_print_logs", print_to_stdstream + "ray_print_logs", + functools.partial(print_to_stdstream, ignore_prefix=ignore_prefix), ) worker.logger_thread = threading.Thread( target=worker.print_logs, name="ray_print_logs" @@ -2483,6 +2515,9 @@ def connect( _setup_tracing() ray.__traced__ = True + # Mark the worker as connected. + worker.set_is_connected(True) + def disconnect(exiting_interpreter=False): """Disconnect this worker from the raylet and object store.""" @@ -2506,10 +2541,12 @@ def disconnect(exiting_interpreter=False): worker.logger_thread.join() worker.threads_stopped.clear() + # Ignore the prefix if the logging config is set. + ignore_prefix = worker.job_logging_config is not None for leftover in stdout_deduplicator.flush(): - print_worker_logs(leftover, sys.stdout) + print_worker_logs(leftover, sys.stdout, ignore_prefix) for leftover in stderr_deduplicator.flush(): - print_worker_logs(leftover, sys.stderr) + print_worker_logs(leftover, sys.stderr, ignore_prefix) global_worker_stdstream_dispatcher.remove_handler("ray_print_logs") worker.node = None # Disconnect the worker from the node. @@ -2521,6 +2558,9 @@ def disconnect(exiting_interpreter=False): if ray_actor is not None: ray_actor._ActorClassMethodMetadata.reset_cache() + # Mark the worker as disconnected. + worker.set_is_connected(False) + @contextmanager def _changeproctitle(title, next_title): diff --git a/python/ray/air/integrations/wandb.py b/python/ray/air/integrations/wandb.py index fcd683ec0e7f7..2d2d5ffe27296 100644 --- a/python/ray/air/integrations/wandb.py +++ b/python/ray/air/integrations/wandb.py @@ -642,6 +642,11 @@ def log_trial_start(self, trial: "Trial"): def _start_logging_actor( self, trial: "Trial", exclude_results: List[str], **wandb_init_kwargs ): + # Reuse actor if one already exists. + # This can happen if the trial is restarted. + if trial in self._trial_logging_futures: + return + if not self._remote_logger_class: env_vars = {} # API key env variable is not set if authenticating through `wandb login` diff --git a/python/ray/air/tests/test_integration_wandb.py b/python/ray/air/tests/test_integration_wandb.py index abf1576407d79..f2f88aa523e32 100644 --- a/python/ray/air/tests/test_integration_wandb.py +++ b/python/ray/air/tests/test_integration_wandb.py @@ -483,6 +483,26 @@ def _handle_result(self, result): state = ray.get(actor.get_state.remote()) assert [metrics["training_iteration"] for metrics in state.logs] == [4, 5] + def test_wandb_restart(self, trial): + """Test that the WandbLoggerCallback reuses actors for trial restarts.""" + + logger = WandbLoggerCallback(project="test_project", api_key="1234") + logger._logger_actor_cls = _MockWandbLoggingActor + logger.setup() + + assert len(logger._trial_logging_futures) == 0 + assert len(logger._logging_future_to_trial) == 0 + + logger.log_trial_start(trial) + + assert len(logger._trial_logging_futures) == 1 + assert len(logger._logging_future_to_trial) == 1 + + logger.log_trial_start(trial) + + assert len(logger._trial_logging_futures) == 1 + assert len(logger._logging_future_to_trial) == 1 + def test_wandb_logging_process_run_info_hook(monkeypatch): """ diff --git a/python/ray/dag/tests/experimental/test_execution_schedule.py b/python/ray/dag/tests/experimental/test_execution_schedule.py index 9477177fdc912..8633107a2c25b 100644 --- a/python/ray/dag/tests/experimental/test_execution_schedule.py +++ b/python/ray/dag/tests/experimental/test_execution_schedule.py @@ -35,12 +35,12 @@ def mock_init(self): pass -def generate_dag_graph_nodes(local_idx, dag_idx, actor_handle, requires_nccl): +def generate_dag_graph_nodes(exec_task_idx, task_idx, actor_handle, requires_nccl): graph_nodes = {} for op_type in _DAGNodeOperationType: graph_nodes[op_type] = _DAGOperationGraphNode( - _DAGNodeOperation(local_idx, op_type), - dag_idx, + _DAGNodeOperation(exec_task_idx, op_type), + task_idx, actor_handle, requires_nccl, ) @@ -52,8 +52,8 @@ class TestSelectNextNodes: Test whether `_select_next_nodes` function selects the next nodes for topological sort to generate execution schedule correctly. - dag_idx: Each DAG node has a unique global index. - local_idx: The DAG node's index in the actor's `executable_tasks` list. + task_idx: Each DAG node has a unique global index. + exec_task_idx: The DAG node's index in the actor's `executable_tasks` list. """ def test_two_candidates_on_same_actor(self, monkeypatch): @@ -73,19 +73,19 @@ def test_two_candidates_on_same_actor(self, monkeypatch): fake_actor = ActorHandle("fake_actor") # The DAG node has a global index of 1, and its index in the # actor's `executable_tasks` list is 0. - dag_idx_1 = 1 + task_idx_1 = 1 dag_node_1 = _DAGOperationGraphNode( _DAGNodeOperation(0, _DAGNodeOperationType.READ), - dag_idx_1, + task_idx_1, fake_actor, False, ) # The DAG node has a global index of 2, and its index in the # actor's `executable_tasks` list is 1. - dag_idx_2 = 2 + task_idx_2 = 2 dag_node_2 = _DAGOperationGraphNode( _DAGNodeOperation(1, _DAGNodeOperationType.READ), - dag_idx_2, + task_idx_2, fake_actor, False, ) @@ -113,38 +113,38 @@ def test_only_one_nccl_write(self, monkeypatch): execution schedule. """ monkeypatch.setattr(ActorHandle, "__init__", mock_actor_handle_init) - fake_actor_1, dag_idx_1, local_idx_1 = ActorHandle("fake_actor_1"), 1, 0 - fake_actor_2, dag_idx_2, local_idx_2 = ActorHandle("fake_actor_2"), 2, 0 + fake_actor_1, task_idx_1, exec_task_idx_1 = ActorHandle("fake_actor_1"), 1, 0 + fake_actor_2, task_idx_2, exec_task_idx_2 = ActorHandle("fake_actor_2"), 2, 0 mock_graph = { - dag_idx_1: generate_dag_graph_nodes( - local_idx_1, dag_idx_1, fake_actor_1, True + task_idx_1: generate_dag_graph_nodes( + exec_task_idx_1, task_idx_1, fake_actor_1, True ), - dag_idx_2: generate_dag_graph_nodes( - local_idx_2, dag_idx_2, fake_actor_2, False + task_idx_2: generate_dag_graph_nodes( + exec_task_idx_2, task_idx_2, fake_actor_2, False ), } - del mock_graph[dag_idx_1][_DAGNodeOperationType.READ] - del mock_graph[dag_idx_1][_DAGNodeOperationType.COMPUTE] + del mock_graph[task_idx_1][_DAGNodeOperationType.READ] + del mock_graph[task_idx_1][_DAGNodeOperationType.COMPUTE] _add_edge( - mock_graph[dag_idx_1][_DAGNodeOperationType.WRITE], - mock_graph[dag_idx_2][_DAGNodeOperationType.READ], + mock_graph[task_idx_1][_DAGNodeOperationType.WRITE], + mock_graph[task_idx_2][_DAGNodeOperationType.READ], ) _add_edge( - mock_graph[dag_idx_2][_DAGNodeOperationType.READ], - mock_graph[dag_idx_2][_DAGNodeOperationType.COMPUTE], + mock_graph[task_idx_2][_DAGNodeOperationType.READ], + mock_graph[task_idx_2][_DAGNodeOperationType.COMPUTE], ) _add_edge( - mock_graph[dag_idx_2][_DAGNodeOperationType.COMPUTE], - mock_graph[dag_idx_2][_DAGNodeOperationType.WRITE], + mock_graph[task_idx_2][_DAGNodeOperationType.COMPUTE], + mock_graph[task_idx_2][_DAGNodeOperationType.WRITE], ) mock_actor_to_candidates = { - fake_actor_1: [mock_graph[dag_idx_1][_DAGNodeOperationType.WRITE]], + fake_actor_1: [mock_graph[task_idx_1][_DAGNodeOperationType.WRITE]], } next_nodes = _select_next_nodes(mock_actor_to_candidates, mock_graph) assert len(next_nodes) == 2 - assert next_nodes[0] == mock_graph[dag_idx_1][_DAGNodeOperationType.WRITE] - assert next_nodes[1] == mock_graph[dag_idx_2][_DAGNodeOperationType.READ] + assert next_nodes[0] == mock_graph[task_idx_1][_DAGNodeOperationType.WRITE] + assert next_nodes[1] == mock_graph[task_idx_2][_DAGNodeOperationType.READ] def test_two_nccl_writes(self, monkeypatch): """ @@ -164,67 +164,69 @@ def test_two_nccl_writes(self, monkeypatch): monkeypatch.setattr(ActorHandle, "__init__", mock_actor_handle_init) fake_actor_1 = ActorHandle("fake_actor_1") - dag_idx_1_0, local_idx_1_0 = 1, 0 - dag_idx_1_1, local_idx_1_1 = 3, 1 + task_idx_1_0, exec_task_idx_1_0 = 1, 0 + task_idx_1_1, exec_task_idx_1_1 = 3, 1 fake_actor_2 = ActorHandle("fake_actor_2") - dag_idx_2_0, local_idx_2_0 = 2, 0 - dag_idx_2_1, local_idx_2_1 = 4, 1 + task_idx_2_0, exec_task_idx_2_0 = 2, 0 + task_idx_2_1, exec_task_idx_2_1 = 4, 1 # Run the test 10 times to ensure that the result of `_select_next_nodes` # is deterministic. for _ in range(20): mock_graph = { - dag_idx_1_0: generate_dag_graph_nodes( - local_idx_1_0, dag_idx_1_0, fake_actor_1, True + task_idx_1_0: generate_dag_graph_nodes( + exec_task_idx_1_0, task_idx_1_0, fake_actor_1, True ), - dag_idx_1_1: generate_dag_graph_nodes( - local_idx_1_1, dag_idx_1_1, fake_actor_1, False + task_idx_1_1: generate_dag_graph_nodes( + exec_task_idx_1_1, task_idx_1_1, fake_actor_1, False ), - dag_idx_2_0: generate_dag_graph_nodes( - local_idx_2_0, dag_idx_2_0, fake_actor_2, True + task_idx_2_0: generate_dag_graph_nodes( + exec_task_idx_2_0, task_idx_2_0, fake_actor_2, True ), - dag_idx_2_1: generate_dag_graph_nodes( - local_idx_2_1, dag_idx_2_1, fake_actor_2, False + task_idx_2_1: generate_dag_graph_nodes( + exec_task_idx_2_1, task_idx_2_1, fake_actor_2, False ), } - del mock_graph[dag_idx_1_0][_DAGNodeOperationType.READ] - del mock_graph[dag_idx_1_0][_DAGNodeOperationType.COMPUTE] - del mock_graph[dag_idx_2_0][_DAGNodeOperationType.READ] - del mock_graph[dag_idx_2_0][_DAGNodeOperationType.COMPUTE] + del mock_graph[task_idx_1_0][_DAGNodeOperationType.READ] + del mock_graph[task_idx_1_0][_DAGNodeOperationType.COMPUTE] + del mock_graph[task_idx_2_0][_DAGNodeOperationType.READ] + del mock_graph[task_idx_2_0][_DAGNodeOperationType.COMPUTE] _add_edge( - mock_graph[dag_idx_1_0][_DAGNodeOperationType.WRITE], - mock_graph[dag_idx_2_1][_DAGNodeOperationType.READ], + mock_graph[task_idx_1_0][_DAGNodeOperationType.WRITE], + mock_graph[task_idx_2_1][_DAGNodeOperationType.READ], ) _add_edge( - mock_graph[dag_idx_2_0][_DAGNodeOperationType.WRITE], - mock_graph[dag_idx_1_1][_DAGNodeOperationType.READ], + mock_graph[task_idx_2_0][_DAGNodeOperationType.WRITE], + mock_graph[task_idx_1_1][_DAGNodeOperationType.READ], ) _add_edge( - mock_graph[dag_idx_2_1][_DAGNodeOperationType.READ], - mock_graph[dag_idx_2_1][_DAGNodeOperationType.COMPUTE], + mock_graph[task_idx_2_1][_DAGNodeOperationType.READ], + mock_graph[task_idx_2_1][_DAGNodeOperationType.COMPUTE], ) _add_edge( - mock_graph[dag_idx_2_1][_DAGNodeOperationType.COMPUTE], - mock_graph[dag_idx_2_1][_DAGNodeOperationType.WRITE], + mock_graph[task_idx_2_1][_DAGNodeOperationType.COMPUTE], + mock_graph[task_idx_2_1][_DAGNodeOperationType.WRITE], ) _add_edge( - mock_graph[dag_idx_1_1][_DAGNodeOperationType.READ], - mock_graph[dag_idx_1_1][_DAGNodeOperationType.COMPUTE], + mock_graph[task_idx_1_1][_DAGNodeOperationType.READ], + mock_graph[task_idx_1_1][_DAGNodeOperationType.COMPUTE], ) _add_edge( - mock_graph[dag_idx_1_1][_DAGNodeOperationType.COMPUTE], - mock_graph[dag_idx_1_1][_DAGNodeOperationType.WRITE], + mock_graph[task_idx_1_1][_DAGNodeOperationType.COMPUTE], + mock_graph[task_idx_1_1][_DAGNodeOperationType.WRITE], ) mock_actor_to_candidates = { - fake_actor_1: [mock_graph[dag_idx_1_0][_DAGNodeOperationType.WRITE]], - fake_actor_2: [mock_graph[dag_idx_2_0][_DAGNodeOperationType.WRITE]], + fake_actor_1: [mock_graph[task_idx_1_0][_DAGNodeOperationType.WRITE]], + fake_actor_2: [mock_graph[task_idx_2_0][_DAGNodeOperationType.WRITE]], } next_nodes = _select_next_nodes(mock_actor_to_candidates, mock_graph) assert len(next_nodes) == 2 - assert next_nodes[0] == mock_graph[dag_idx_1_0][_DAGNodeOperationType.WRITE] - assert next_nodes[1] == mock_graph[dag_idx_2_1][_DAGNodeOperationType.READ] + assert ( + next_nodes[0] == mock_graph[task_idx_1_0][_DAGNodeOperationType.WRITE] + ) + assert next_nodes[1] == mock_graph[task_idx_2_1][_DAGNodeOperationType.READ] class TestBuildDAGNodeOperationGraph: @@ -237,7 +239,7 @@ class TestBuildDAGNodeOperationGraph: def check_edges_between_read_compute_write( self, graph: Dict[int, Dict[_DAGNodeOperationType, _DAGOperationGraphNode]], - dag_idx: int, + task_idx: int, expected_num_edges: List[Tuple[int, int]], ): """ @@ -246,54 +248,54 @@ def check_edges_between_read_compute_write( Args: graph: The operation graph generated by `_build_dag_node_operation_graph`. - dag_idx: The global index of the task used to access the task in + task_idx: The global index of the task used to access the task in `idx_to_task`. expected_num_edges: A list of tuples where each tuple contains the expected number of in-edges and out-edges for READ, COMPUTE, and WRITE operations. """ assert len(expected_num_edges) == 3 - assert len(graph[dag_idx]) == 3 - read_node = graph[dag_idx][_DAGNodeOperationType.READ] - compute_node = graph[dag_idx][_DAGNodeOperationType.COMPUTE] - write_node = graph[dag_idx][_DAGNodeOperationType.WRITE] + assert len(graph[task_idx]) == 3 + read_node = graph[task_idx][_DAGNodeOperationType.READ] + compute_node = graph[task_idx][_DAGNodeOperationType.COMPUTE] + write_node = graph[task_idx][_DAGNodeOperationType.WRITE] for idx, node in enumerate([read_node, compute_node, write_node]): assert node.in_degree == expected_num_edges[idx][0] assert len(node.out_edges) == expected_num_edges[idx][1] - assert (dag_idx, _DAGNodeOperationType.COMPUTE) in read_node.out_edges - assert (dag_idx, _DAGNodeOperationType.READ) in compute_node.in_edges - assert (dag_idx, _DAGNodeOperationType.WRITE) in compute_node.out_edges - assert (dag_idx, _DAGNodeOperationType.COMPUTE) in write_node.in_edges + assert (task_idx, _DAGNodeOperationType.COMPUTE) in read_node.out_edges + assert (task_idx, _DAGNodeOperationType.READ) in compute_node.in_edges + assert (task_idx, _DAGNodeOperationType.WRITE) in compute_node.out_edges + assert (task_idx, _DAGNodeOperationType.COMPUTE) in write_node.in_edges def check_edge_between_writer_and_reader( self, graph: Dict[int, Dict[_DAGNodeOperationType, _DAGOperationGraphNode]], - writer_dag_idx: int, - reader_dag_idx: int, + writer_task_idx: int, + reader_task_idx: int, ): """ Check whether the edge from writer's WRITE to reader's READ operation is added. Args: graph: The operation graph generated by `_build_dag_node_operation_graph`. - writer_dag_idx: The index of the task used to access the task + writer_task_idx: The index of the task used to access the task that the writer belongs to in `idx_to_task`. - reader_dag_idx: The index of the task used to access the task + reader_task_idx: The index of the task used to access the task that the reader belongs to in `idx_to_task`. """ - write_node = graph[writer_dag_idx][_DAGNodeOperationType.WRITE] - read_node = graph[reader_dag_idx][_DAGNodeOperationType.READ] + write_node = graph[writer_task_idx][_DAGNodeOperationType.WRITE] + read_node = graph[reader_task_idx][_DAGNodeOperationType.READ] - assert (reader_dag_idx, _DAGNodeOperationType.READ) in write_node.out_edges - assert (writer_dag_idx, _DAGNodeOperationType.WRITE) in read_node.in_edges + assert (reader_task_idx, _DAGNodeOperationType.READ) in write_node.out_edges + assert (writer_task_idx, _DAGNodeOperationType.WRITE) in read_node.in_edges def check_edge_between_compute_nodes( self, graph: Dict[int, Dict[_DAGNodeOperationType, _DAGOperationGraphNode]], - dag_idx_1: int, - dag_idx_2: int, + task_idx_1: int, + task_idx_2: int, ): """ Check whether the edge from COMPUTE with `bind_index` i to COMPUTE with @@ -301,18 +303,18 @@ def check_edge_between_compute_nodes( Args: graph: The operation graph generated by `_build_dag_node_operation_graph`. - dag_idx_1: The index of the task used to access the task in + task_idx_1: The index of the task used to access the task in `idx_to_task`. - dag_idx_2: The index of the task used to access the task in + task_idx_2: The index of the task used to access the task in `idx_to_task`. Note that both tasks belong to the same actor, and the `bind_index` of the second task is equal to the `bind_index` of the first task plus one. """ - compute_node_1 = graph[dag_idx_1][_DAGNodeOperationType.COMPUTE] - compute_node_2 = graph[dag_idx_2][_DAGNodeOperationType.COMPUTE] + compute_node_1 = graph[task_idx_1][_DAGNodeOperationType.COMPUTE] + compute_node_2 = graph[task_idx_2][_DAGNodeOperationType.COMPUTE] - assert (dag_idx_2, _DAGNodeOperationType.COMPUTE) in compute_node_1.out_edges - assert (dag_idx_1, _DAGNodeOperationType.COMPUTE) in compute_node_2.in_edges + assert (task_idx_2, _DAGNodeOperationType.COMPUTE) in compute_node_1.out_edges + assert (task_idx_1, _DAGNodeOperationType.COMPUTE) in compute_node_2.in_edges def test_edges_between_read_compute_write(self, monkeypatch): """ @@ -331,17 +333,17 @@ def test_edges_between_read_compute_write(self, monkeypatch): } fake_actor = "fake_actor" - dag_idx = 1 + task_idx = 1 actor_to_operation_nodes = { fake_actor: [ - list(generate_dag_graph_nodes(0, dag_idx, fake_actor, False).values()) + list(generate_dag_graph_nodes(0, task_idx, fake_actor, False).values()) ] } graph = _build_dag_node_operation_graph(idx_to_task, actor_to_operation_nodes) assert len(graph) == 1 self.check_edges_between_read_compute_write( - graph, dag_idx, [(0, 1), (1, 1), (1, 0)] + graph, task_idx, [(0, 1), (1, 1), (1, 0)] ) def test_edge_between_writer_and_reader(self, monkeypatch): @@ -354,8 +356,8 @@ def test_edge_between_writer_and_reader(self, monkeypatch): monkeypatch.setattr(ClassMethodNode, "__init__", mock_class_method_call_init) monkeypatch.setattr(MultiOutputNode, "__init__", mock_init) - fake_actor_1, dag_idx_1 = "fake_actor_1", 1 - fake_actor_2, dag_idx_2 = "fake_actor_2", 2 + fake_actor_1, task_idx_1 = "fake_actor_1", 1 + fake_actor_2, task_idx_2 = "fake_actor_2", 2 idx_to_task = { 0: CompiledTask(0, InputNode()), 1: CompiledTask(1, ClassMethodNode()), @@ -367,12 +369,16 @@ def test_edge_between_writer_and_reader(self, monkeypatch): actor_to_operation_nodes = { fake_actor_1: [ list( - generate_dag_graph_nodes(0, dag_idx_1, fake_actor_1, False).values() + generate_dag_graph_nodes( + 0, task_idx_1, fake_actor_1, False + ).values() ) ], fake_actor_2: [ list( - generate_dag_graph_nodes(0, dag_idx_2, fake_actor_2, False).values() + generate_dag_graph_nodes( + 0, task_idx_2, fake_actor_2, False + ).values() ) ], } @@ -380,12 +386,12 @@ def test_edge_between_writer_and_reader(self, monkeypatch): assert len(graph) == 2 self.check_edges_between_read_compute_write( - graph, dag_idx_1, [(0, 1), (1, 1), (1, 1)] + graph, task_idx_1, [(0, 1), (1, 1), (1, 1)] ) self.check_edges_between_read_compute_write( - graph, dag_idx_2, [(1, 1), (1, 1), (1, 0)] + graph, task_idx_2, [(1, 1), (1, 1), (1, 0)] ) - self.check_edge_between_writer_and_reader(graph, dag_idx_1, dag_idx_2) + self.check_edge_between_writer_and_reader(graph, task_idx_1, task_idx_2) def test_edge_between_compute_nodes(self, monkeypatch): """ @@ -399,22 +405,22 @@ def test_edge_between_compute_nodes(self, monkeypatch): monkeypatch.setattr(MultiOutputNode, "__init__", mock_init) fake_actor = "fake_actor" - dag_idx_1, dag_idx_2 = 1, 2 + task_idx_1, task_idx_2 = 1, 2 idx_to_task = { 0: CompiledTask(0, InputNode()), - dag_idx_1: CompiledTask(dag_idx_1, ClassMethodNode()), - dag_idx_2: CompiledTask(dag_idx_2, ClassMethodNode()), + task_idx_1: CompiledTask(task_idx_1, ClassMethodNode()), + task_idx_2: CompiledTask(task_idx_2, ClassMethodNode()), 3: CompiledTask(3, MultiOutputNode()), } - idx_to_task[dag_idx_1].downstream_task_idxs = {dag_idx_2: fake_actor} + idx_to_task[task_idx_1].downstream_task_idxs = {task_idx_2: fake_actor} actor_to_operation_nodes = { fake_actor: [ list( - generate_dag_graph_nodes(0, dag_idx_1, fake_actor, False).values() + generate_dag_graph_nodes(0, task_idx_1, fake_actor, False).values() ), list( - generate_dag_graph_nodes(1, dag_idx_2, fake_actor, False).values() + generate_dag_graph_nodes(1, task_idx_2, fake_actor, False).values() ), ], } @@ -422,13 +428,13 @@ def test_edge_between_compute_nodes(self, monkeypatch): assert len(graph) == 2 self.check_edges_between_read_compute_write( - graph, dag_idx_1, [(0, 1), (1, 2), (1, 1)] + graph, task_idx_1, [(0, 1), (1, 2), (1, 1)] ) self.check_edges_between_read_compute_write( - graph, dag_idx_2, [(1, 1), (2, 1), (1, 0)] + graph, task_idx_2, [(1, 1), (2, 1), (1, 0)] ) - self.check_edge_between_writer_and_reader(graph, dag_idx_1, dag_idx_2) - self.check_edge_between_compute_nodes(graph, dag_idx_1, dag_idx_2) + self.check_edge_between_writer_and_reader(graph, task_idx_1, task_idx_2) + self.check_edge_between_compute_nodes(graph, task_idx_1, task_idx_2) def test_two_actors(self, monkeypatch): """ @@ -443,35 +449,43 @@ def test_two_actors(self, monkeypatch): monkeypatch.setattr(ClassMethodNode, "__init__", mock_class_method_call_init) monkeypatch.setattr(MultiOutputNode, "__init__", mock_init) - fake_actor_1, dag_idx_1, dag_idx_3 = "fake_actor_1", 1, 3 - fake_actor_2, dag_idx_2, dag_idx_4 = "fake_actor_2", 2, 4 + fake_actor_1, task_idx_1, task_idx_3 = "fake_actor_1", 1, 3 + fake_actor_2, task_idx_2, task_idx_4 = "fake_actor_2", 2, 4 idx_to_task = { 0: CompiledTask(0, InputNode()), - dag_idx_1: CompiledTask(dag_idx_1, ClassMethodNode()), - dag_idx_2: CompiledTask(dag_idx_2, ClassMethodNode()), - dag_idx_3: CompiledTask(dag_idx_3, ClassMethodNode()), - dag_idx_4: CompiledTask(dag_idx_4, ClassMethodNode()), + task_idx_1: CompiledTask(task_idx_1, ClassMethodNode()), + task_idx_2: CompiledTask(task_idx_2, ClassMethodNode()), + task_idx_3: CompiledTask(task_idx_3, ClassMethodNode()), + task_idx_4: CompiledTask(task_idx_4, ClassMethodNode()), 5: CompiledTask(5, MultiOutputNode()), } - idx_to_task[dag_idx_1].downstream_task_idxs = {dag_idx_4: fake_actor_2} - idx_to_task[dag_idx_2].downstream_task_idxs = {dag_idx_3: fake_actor_1} + idx_to_task[task_idx_1].downstream_task_idxs = {task_idx_4: fake_actor_2} + idx_to_task[task_idx_2].downstream_task_idxs = {task_idx_3: fake_actor_1} actor_to_operation_nodes = { fake_actor_1: [ list( - generate_dag_graph_nodes(0, dag_idx_1, fake_actor_1, False).values() + generate_dag_graph_nodes( + 0, task_idx_1, fake_actor_1, False + ).values() ), list( - generate_dag_graph_nodes(1, dag_idx_3, fake_actor_1, False).values() + generate_dag_graph_nodes( + 1, task_idx_3, fake_actor_1, False + ).values() ), ], fake_actor_2: [ list( - generate_dag_graph_nodes(0, dag_idx_2, fake_actor_2, False).values() + generate_dag_graph_nodes( + 0, task_idx_2, fake_actor_2, False + ).values() ), list( - generate_dag_graph_nodes(1, dag_idx_4, fake_actor_2, False).values() + generate_dag_graph_nodes( + 1, task_idx_4, fake_actor_2, False + ).values() ), ], } @@ -479,19 +493,19 @@ def test_two_actors(self, monkeypatch): assert len(graph) == 4 self.check_edges_between_read_compute_write( - graph, dag_idx_1, [(0, 1), (1, 2), (1, 1)] + graph, task_idx_1, [(0, 1), (1, 2), (1, 1)] ) self.check_edges_between_read_compute_write( - graph, dag_idx_2, [(0, 1), (1, 2), (1, 1)] + graph, task_idx_2, [(0, 1), (1, 2), (1, 1)] ) self.check_edges_between_read_compute_write( - graph, dag_idx_3, [(1, 1), (2, 1), (1, 0)] + graph, task_idx_3, [(1, 1), (2, 1), (1, 0)] ) self.check_edges_between_read_compute_write( - graph, dag_idx_4, [(1, 1), (2, 1), (1, 0)] + graph, task_idx_4, [(1, 1), (2, 1), (1, 0)] ) - self.check_edge_between_writer_and_reader(graph, dag_idx_1, dag_idx_4) - self.check_edge_between_writer_and_reader(graph, dag_idx_2, dag_idx_3) + self.check_edge_between_writer_and_reader(graph, task_idx_1, task_idx_4) + self.check_edge_between_writer_and_reader(graph, task_idx_2, task_idx_3) class TestGenerateActorToExecutionSchedule: @@ -564,7 +578,7 @@ def add_control_dependency( def test_single_actor_1(self, monkeypatch): """ - driver -> fake_actor.op (dag_idx_1) -> fake_actor.op (dag_idx_2) -> driver + driver -> fake_actor.op (task_idx_1) -> fake_actor.op (task_idx_2) -> driver Test the case where there is only one actor and no NCCL operations. Because there is no NCCL operation, all operations with smaller @@ -574,90 +588,90 @@ def test_single_actor_1(self, monkeypatch): monkeypatch.setattr(ActorHandle, "__init__", mock_actor_handle_init) fake_actor = ActorHandle("fake_actor") - dag_idx_1, local_idx_1 = 1, 0 - dag_idx_2, local_idx_2 = 2, 1 + task_idx_1, exec_task_idx_1 = 1, 0 + task_idx_2, exec_task_idx_2 = 2, 1 graph = { - dag_idx_1: generate_dag_graph_nodes( - local_idx_1, dag_idx_1, fake_actor, False + task_idx_1: generate_dag_graph_nodes( + exec_task_idx_1, task_idx_1, fake_actor, False ), - dag_idx_2: generate_dag_graph_nodes( - local_idx_2, dag_idx_2, fake_actor, False + task_idx_2: generate_dag_graph_nodes( + exec_task_idx_2, task_idx_2, fake_actor, False ), } - self.add_edge_between_read_compute_write(graph[dag_idx_1]) - self.add_edge_between_read_compute_write(graph[dag_idx_2]) - self.add_data_dependeny(graph[dag_idx_1], graph[dag_idx_2]) - self.add_control_dependency(graph[dag_idx_1], graph[dag_idx_2]) + self.add_edge_between_read_compute_write(graph[task_idx_1]) + self.add_edge_between_read_compute_write(graph[task_idx_2]) + self.add_data_dependeny(graph[task_idx_1], graph[task_idx_2]) + self.add_control_dependency(graph[task_idx_1], graph[task_idx_2]) actor_to_execution_schedule = _generate_actor_to_execution_schedule(graph) assert len(actor_to_execution_schedule) == 1 assert len(actor_to_execution_schedule[fake_actor]) == 6 assert actor_to_execution_schedule[fake_actor] == [ - graph[dag_idx_1][_DAGNodeOperationType.READ].operation, - graph[dag_idx_1][_DAGNodeOperationType.COMPUTE].operation, - graph[dag_idx_1][_DAGNodeOperationType.WRITE].operation, - graph[dag_idx_2][_DAGNodeOperationType.READ].operation, - graph[dag_idx_2][_DAGNodeOperationType.COMPUTE].operation, - graph[dag_idx_2][_DAGNodeOperationType.WRITE].operation, + graph[task_idx_1][_DAGNodeOperationType.READ].operation, + graph[task_idx_1][_DAGNodeOperationType.COMPUTE].operation, + graph[task_idx_1][_DAGNodeOperationType.WRITE].operation, + graph[task_idx_2][_DAGNodeOperationType.READ].operation, + graph[task_idx_2][_DAGNodeOperationType.COMPUTE].operation, + graph[task_idx_2][_DAGNodeOperationType.WRITE].operation, ] def test_single_actor_2(self, monkeypatch): """ - driver -> fake_actor.op (dag_idx_1) -> fake_actor.op (dag_idx_2) -> driver + driver -> fake_actor.op (task_idx_1) -> fake_actor.op (task_idx_2) -> driver | | - -> fake_actor.op (dag_idx_3) - + -> fake_actor.op (task_idx_3) - - When the `dad_idx_1.WRITE` operation is picked, both `dag_idx_2.READ` and - `dag_idx_3.READ` operations should be zero in-degree. In this case, the one + When the `dad_idx_1.WRITE` operation is picked, both `task_idx_2.READ` and + `task_idx_3.READ` operations should be zero in-degree. In this case, the one with the smaller `bind_index` should be selected first. That is, - `dag_idx_2.READ` should be selected first. + `task_idx_2.READ` should be selected first. """ monkeypatch.setattr(ActorHandle, "__init__", mock_actor_handle_init) fake_actor = ActorHandle("fake_actor") - dag_idx_1, local_idx_1 = 1, 0 - dag_idx_2, local_idx_2 = 2, 1 - dag_idx_3, local_idx_3 = 3, 2 + task_idx_1, exec_task_idx_1 = 1, 0 + task_idx_2, exec_task_idx_2 = 2, 1 + task_idx_3, exec_task_idx_3 = 3, 2 graph = { - dag_idx_1: generate_dag_graph_nodes( - local_idx_1, dag_idx_1, fake_actor, False + task_idx_1: generate_dag_graph_nodes( + exec_task_idx_1, task_idx_1, fake_actor, False ), - dag_idx_2: generate_dag_graph_nodes( - local_idx_2, dag_idx_2, fake_actor, False + task_idx_2: generate_dag_graph_nodes( + exec_task_idx_2, task_idx_2, fake_actor, False ), - dag_idx_3: generate_dag_graph_nodes( - local_idx_3, dag_idx_3, fake_actor, False + task_idx_3: generate_dag_graph_nodes( + exec_task_idx_3, task_idx_3, fake_actor, False ), } - self.add_edge_between_read_compute_write(graph[dag_idx_1]) - self.add_edge_between_read_compute_write(graph[dag_idx_2]) - self.add_edge_between_read_compute_write(graph[dag_idx_3]) - self.add_data_dependeny(graph[dag_idx_1], graph[dag_idx_2]) - self.add_data_dependeny(graph[dag_idx_1], graph[dag_idx_3]) - self.add_control_dependency(graph[dag_idx_1], graph[dag_idx_2]) - self.add_control_dependency(graph[dag_idx_2], graph[dag_idx_3]) + self.add_edge_between_read_compute_write(graph[task_idx_1]) + self.add_edge_between_read_compute_write(graph[task_idx_2]) + self.add_edge_between_read_compute_write(graph[task_idx_3]) + self.add_data_dependeny(graph[task_idx_1], graph[task_idx_2]) + self.add_data_dependeny(graph[task_idx_1], graph[task_idx_3]) + self.add_control_dependency(graph[task_idx_1], graph[task_idx_2]) + self.add_control_dependency(graph[task_idx_2], graph[task_idx_3]) actor_to_execution_schedule = _generate_actor_to_execution_schedule(graph) assert len(actor_to_execution_schedule) == 1 assert len(actor_to_execution_schedule[fake_actor]) == 9 assert actor_to_execution_schedule[fake_actor] == [ - graph[dag_idx_1][_DAGNodeOperationType.READ].operation, - graph[dag_idx_1][_DAGNodeOperationType.COMPUTE].operation, - graph[dag_idx_1][_DAGNodeOperationType.WRITE].operation, - graph[dag_idx_2][_DAGNodeOperationType.READ].operation, - graph[dag_idx_2][_DAGNodeOperationType.COMPUTE].operation, - graph[dag_idx_2][_DAGNodeOperationType.WRITE].operation, - graph[dag_idx_3][_DAGNodeOperationType.READ].operation, - graph[dag_idx_3][_DAGNodeOperationType.COMPUTE].operation, - graph[dag_idx_3][_DAGNodeOperationType.WRITE].operation, + graph[task_idx_1][_DAGNodeOperationType.READ].operation, + graph[task_idx_1][_DAGNodeOperationType.COMPUTE].operation, + graph[task_idx_1][_DAGNodeOperationType.WRITE].operation, + graph[task_idx_2][_DAGNodeOperationType.READ].operation, + graph[task_idx_2][_DAGNodeOperationType.COMPUTE].operation, + graph[task_idx_2][_DAGNodeOperationType.WRITE].operation, + graph[task_idx_3][_DAGNodeOperationType.READ].operation, + graph[task_idx_3][_DAGNodeOperationType.COMPUTE].operation, + graph[task_idx_3][_DAGNodeOperationType.WRITE].operation, ] def test_two_actors_no_nccl(self, monkeypatch): """ - driver -> actor_1.op (dag_idx_1_1) -> actor_2.op (dag_idx_2_2) -> driver + driver -> actor_1.op (task_idx_1_1) -> actor_2.op (task_idx_2_2) -> driver | | - -> actor_2.op (dag_idx_2_1) -> actor_1.op (dag_idx_1_2) - + -> actor_2.op (task_idx_2_1) -> actor_1.op (task_idx_1_2) - Test the case where there are two actors and no NCCL operations. Because there is no NCCL operation, all operations with smaller @@ -667,35 +681,35 @@ def test_two_actors_no_nccl(self, monkeypatch): monkeypatch.setattr(ActorHandle, "__init__", mock_actor_handle_init) fake_actor_1 = ActorHandle("fake_actor_1") - dag_idx_1_1, local_idx_1_1 = 1, 0 - dag_idx_1_2, local_idx_1_2 = 4, 1 + task_idx_1_1, exec_task_idx_1_1 = 1, 0 + task_idx_1_2, exec_task_idx_1_2 = 4, 1 fake_actor_2 = ActorHandle("fake_actor_2") - dag_idx_2_1, local_idx_2_1 = 2, 0 - dag_idx_2_2, local_idx_2_2 = 3, 1 + task_idx_2_1, exec_task_idx_2_1 = 2, 0 + task_idx_2_2, exec_task_idx_2_2 = 3, 1 graph = { - dag_idx_1_1: generate_dag_graph_nodes( - local_idx_1_1, dag_idx_1_1, fake_actor_1, False + task_idx_1_1: generate_dag_graph_nodes( + exec_task_idx_1_1, task_idx_1_1, fake_actor_1, False ), - dag_idx_2_1: generate_dag_graph_nodes( - local_idx_2_1, dag_idx_2_1, fake_actor_2, False + task_idx_2_1: generate_dag_graph_nodes( + exec_task_idx_2_1, task_idx_2_1, fake_actor_2, False ), - dag_idx_2_2: generate_dag_graph_nodes( - local_idx_2_2, dag_idx_2_2, fake_actor_2, False + task_idx_2_2: generate_dag_graph_nodes( + exec_task_idx_2_2, task_idx_2_2, fake_actor_2, False ), - dag_idx_1_2: generate_dag_graph_nodes( - local_idx_1_2, dag_idx_1_2, fake_actor_1, False + task_idx_1_2: generate_dag_graph_nodes( + exec_task_idx_1_2, task_idx_1_2, fake_actor_1, False ), } - self.add_edge_between_read_compute_write(graph[dag_idx_1_1]) - self.add_edge_between_read_compute_write(graph[dag_idx_1_2]) - self.add_edge_between_read_compute_write(graph[dag_idx_2_1]) - self.add_edge_between_read_compute_write(graph[dag_idx_2_2]) - self.add_data_dependeny(graph[dag_idx_1_1], graph[dag_idx_2_2]) - self.add_data_dependeny(graph[dag_idx_2_1], graph[dag_idx_1_2]) - self.add_control_dependency(graph[dag_idx_1_1], graph[dag_idx_1_2]) - self.add_control_dependency(graph[dag_idx_2_1], graph[dag_idx_2_2]) + self.add_edge_between_read_compute_write(graph[task_idx_1_1]) + self.add_edge_between_read_compute_write(graph[task_idx_1_2]) + self.add_edge_between_read_compute_write(graph[task_idx_2_1]) + self.add_edge_between_read_compute_write(graph[task_idx_2_2]) + self.add_data_dependeny(graph[task_idx_1_1], graph[task_idx_2_2]) + self.add_data_dependeny(graph[task_idx_2_1], graph[task_idx_1_2]) + self.add_control_dependency(graph[task_idx_1_1], graph[task_idx_1_2]) + self.add_control_dependency(graph[task_idx_2_1], graph[task_idx_2_2]) actor_to_execution_schedule = _generate_actor_to_execution_schedule(graph) assert len(actor_to_execution_schedule) == 2 @@ -703,64 +717,64 @@ def test_two_actors_no_nccl(self, monkeypatch): assert len(actor_to_execution_schedule[fake_actor_2]) == 6 assert actor_to_execution_schedule[fake_actor_1] == [ - graph[dag_idx_1_1][_DAGNodeOperationType.READ].operation, - graph[dag_idx_1_1][_DAGNodeOperationType.COMPUTE].operation, - graph[dag_idx_1_1][_DAGNodeOperationType.WRITE].operation, - graph[dag_idx_1_2][_DAGNodeOperationType.READ].operation, - graph[dag_idx_1_2][_DAGNodeOperationType.COMPUTE].operation, - graph[dag_idx_1_2][_DAGNodeOperationType.WRITE].operation, + graph[task_idx_1_1][_DAGNodeOperationType.READ].operation, + graph[task_idx_1_1][_DAGNodeOperationType.COMPUTE].operation, + graph[task_idx_1_1][_DAGNodeOperationType.WRITE].operation, + graph[task_idx_1_2][_DAGNodeOperationType.READ].operation, + graph[task_idx_1_2][_DAGNodeOperationType.COMPUTE].operation, + graph[task_idx_1_2][_DAGNodeOperationType.WRITE].operation, ] assert actor_to_execution_schedule[fake_actor_2] == [ - graph[dag_idx_2_1][_DAGNodeOperationType.READ].operation, - graph[dag_idx_2_1][_DAGNodeOperationType.COMPUTE].operation, - graph[dag_idx_2_1][_DAGNodeOperationType.WRITE].operation, - graph[dag_idx_2_2][_DAGNodeOperationType.READ].operation, - graph[dag_idx_2_2][_DAGNodeOperationType.COMPUTE].operation, - graph[dag_idx_2_2][_DAGNodeOperationType.WRITE].operation, + graph[task_idx_2_1][_DAGNodeOperationType.READ].operation, + graph[task_idx_2_1][_DAGNodeOperationType.COMPUTE].operation, + graph[task_idx_2_1][_DAGNodeOperationType.WRITE].operation, + graph[task_idx_2_2][_DAGNodeOperationType.READ].operation, + graph[task_idx_2_2][_DAGNodeOperationType.COMPUTE].operation, + graph[task_idx_2_2][_DAGNodeOperationType.WRITE].operation, ] def test_two_actors_with_nccl(self, monkeypatch): """ - driver -> actor_1.op (dag_idx_1_1) -> actor_2.op (dag_idx_2_2) -> driver + driver -> actor_1.op (task_idx_1_1) -> actor_2.op (task_idx_2_2) -> driver | | - -> actor_2.op (dag_idx_2_1) -> actor_1.op (dag_idx_1_2) - + -> actor_2.op (task_idx_2_1) -> actor_1.op (task_idx_1_2) - In this test, the communication between fake_actor_1 and fake_actor_2 is done - using NCCL. When the dag_idx_1.WRITE operation is picked, the dag_idx_2.READ + using NCCL. When the task_idx_1.WRITE operation is picked, the task_idx_2.READ operation is also added to the execution schedule because of the NCCL operation. """ monkeypatch.setattr(ActorHandle, "__init__", mock_actor_handle_init) fake_actor_1 = ActorHandle("fake_actor_1") - dag_idx_1_1, local_idx_1_1 = 1, 0 - dag_idx_1_2, local_idx_1_2 = 4, 1 + task_idx_1_1, exec_task_idx_1_1 = 1, 0 + task_idx_1_2, exec_task_idx_1_2 = 4, 1 fake_actor_2 = ActorHandle("fake_actor_2") - dag_idx_2_1, local_idx_2_1 = 2, 0 - dag_idx_2_2, local_idx_2_2 = 3, 1 + task_idx_2_1, exec_task_idx_2_1 = 2, 0 + task_idx_2_2, exec_task_idx_2_2 = 3, 1 graph = { - dag_idx_1_1: generate_dag_graph_nodes( - local_idx_1_1, dag_idx_1_1, fake_actor_1, True + task_idx_1_1: generate_dag_graph_nodes( + exec_task_idx_1_1, task_idx_1_1, fake_actor_1, True ), - dag_idx_2_1: generate_dag_graph_nodes( - local_idx_2_1, dag_idx_2_1, fake_actor_2, True + task_idx_2_1: generate_dag_graph_nodes( + exec_task_idx_2_1, task_idx_2_1, fake_actor_2, True ), - dag_idx_2_2: generate_dag_graph_nodes( - local_idx_2_2, dag_idx_2_2, fake_actor_2, False + task_idx_2_2: generate_dag_graph_nodes( + exec_task_idx_2_2, task_idx_2_2, fake_actor_2, False ), - dag_idx_1_2: generate_dag_graph_nodes( - local_idx_1_2, dag_idx_1_2, fake_actor_1, False + task_idx_1_2: generate_dag_graph_nodes( + exec_task_idx_1_2, task_idx_1_2, fake_actor_1, False ), } - self.add_edge_between_read_compute_write(graph[dag_idx_1_1]) - self.add_edge_between_read_compute_write(graph[dag_idx_1_2]) - self.add_edge_between_read_compute_write(graph[dag_idx_2_1]) - self.add_edge_between_read_compute_write(graph[dag_idx_2_2]) - self.add_data_dependeny(graph[dag_idx_1_1], graph[dag_idx_2_2]) - self.add_data_dependeny(graph[dag_idx_2_1], graph[dag_idx_1_2]) - self.add_control_dependency(graph[dag_idx_1_1], graph[dag_idx_1_2]) - self.add_control_dependency(graph[dag_idx_2_1], graph[dag_idx_2_2]) + self.add_edge_between_read_compute_write(graph[task_idx_1_1]) + self.add_edge_between_read_compute_write(graph[task_idx_1_2]) + self.add_edge_between_read_compute_write(graph[task_idx_2_1]) + self.add_edge_between_read_compute_write(graph[task_idx_2_2]) + self.add_data_dependeny(graph[task_idx_1_1], graph[task_idx_2_2]) + self.add_data_dependeny(graph[task_idx_2_1], graph[task_idx_1_2]) + self.add_control_dependency(graph[task_idx_1_1], graph[task_idx_1_2]) + self.add_control_dependency(graph[task_idx_2_1], graph[task_idx_2_2]) actor_to_execution_schedule = _generate_actor_to_execution_schedule(graph) assert len(actor_to_execution_schedule) == 2 @@ -768,21 +782,21 @@ def test_two_actors_with_nccl(self, monkeypatch): assert len(actor_to_execution_schedule[fake_actor_2]) == 6 assert actor_to_execution_schedule[fake_actor_1] == [ - graph[dag_idx_1_1][_DAGNodeOperationType.READ].operation, - graph[dag_idx_1_1][_DAGNodeOperationType.COMPUTE].operation, - graph[dag_idx_1_1][_DAGNodeOperationType.WRITE].operation, - graph[dag_idx_1_2][_DAGNodeOperationType.READ].operation, - graph[dag_idx_1_2][_DAGNodeOperationType.COMPUTE].operation, - graph[dag_idx_1_2][_DAGNodeOperationType.WRITE].operation, + graph[task_idx_1_1][_DAGNodeOperationType.READ].operation, + graph[task_idx_1_1][_DAGNodeOperationType.COMPUTE].operation, + graph[task_idx_1_1][_DAGNodeOperationType.WRITE].operation, + graph[task_idx_1_2][_DAGNodeOperationType.READ].operation, + graph[task_idx_1_2][_DAGNodeOperationType.COMPUTE].operation, + graph[task_idx_1_2][_DAGNodeOperationType.WRITE].operation, ] assert actor_to_execution_schedule[fake_actor_2] == [ - graph[dag_idx_2_1][_DAGNodeOperationType.READ].operation, - graph[dag_idx_2_1][_DAGNodeOperationType.COMPUTE].operation, - # The order of `dag_idx_2_2.READ` and `dag_idx_2_2.COMPUTE` is important. - graph[dag_idx_2_2][_DAGNodeOperationType.READ].operation, - graph[dag_idx_2_1][_DAGNodeOperationType.WRITE].operation, - graph[dag_idx_2_2][_DAGNodeOperationType.COMPUTE].operation, - graph[dag_idx_2_2][_DAGNodeOperationType.WRITE].operation, + graph[task_idx_2_1][_DAGNodeOperationType.READ].operation, + graph[task_idx_2_1][_DAGNodeOperationType.COMPUTE].operation, + # The order of `task_idx_2_2.READ` and `task_idx_2_2.COMPUTE` is important. + graph[task_idx_2_2][_DAGNodeOperationType.READ].operation, + graph[task_idx_2_1][_DAGNodeOperationType.WRITE].operation, + graph[task_idx_2_2][_DAGNodeOperationType.COMPUTE].operation, + graph[task_idx_2_2][_DAGNodeOperationType.WRITE].operation, ] def test_simulate_pp_2workers_2batches_1f1b_with_nccl(self, monkeypatch): @@ -799,94 +813,94 @@ def test_simulate_pp_2workers_2batches_1f1b_with_nccl(self, monkeypatch): monkeypatch.setattr(ActorHandle, "__init__", mock_actor_handle_init) worker_1 = ActorHandle("worker_1") - dag_idx_1_1, local_idx_1_1 = 1, 0 - dag_idx_1_2, local_idx_1_2 = 2, 1 - dag_idx_1_3, local_idx_1_3 = 3, 2 - dag_idx_1_4, local_idx_1_4 = 4, 3 + task_idx_1_1, exec_task_idx_1_1 = 1, 0 + task_idx_1_2, exec_task_idx_1_2 = 2, 1 + task_idx_1_3, exec_task_idx_1_3 = 3, 2 + task_idx_1_4, exec_task_idx_1_4 = 4, 3 worker_2 = ActorHandle("worker_2") - dag_idx_2_1, local_idx_2_1 = 5, 0 - dag_idx_2_2, local_idx_2_2 = 6, 1 - dag_idx_2_3, local_idx_2_3 = 7, 2 - dag_idx_2_4, local_idx_2_4 = 8, 3 + task_idx_2_1, exec_task_idx_2_1 = 5, 0 + task_idx_2_2, exec_task_idx_2_2 = 6, 1 + task_idx_2_3, exec_task_idx_2_3 = 7, 2 + task_idx_2_4, exec_task_idx_2_4 = 8, 3 graph = { - dag_idx_1_1: generate_dag_graph_nodes( - local_idx_1_1, dag_idx_1_1, worker_1, True + task_idx_1_1: generate_dag_graph_nodes( + exec_task_idx_1_1, task_idx_1_1, worker_1, True ), - dag_idx_1_2: generate_dag_graph_nodes( - local_idx_1_2, dag_idx_1_2, worker_1, True + task_idx_1_2: generate_dag_graph_nodes( + exec_task_idx_1_2, task_idx_1_2, worker_1, True ), - dag_idx_1_3: generate_dag_graph_nodes( - local_idx_1_3, dag_idx_1_3, worker_1, False + task_idx_1_3: generate_dag_graph_nodes( + exec_task_idx_1_3, task_idx_1_3, worker_1, False ), - dag_idx_1_4: generate_dag_graph_nodes( - local_idx_1_4, dag_idx_1_4, worker_1, False + task_idx_1_4: generate_dag_graph_nodes( + exec_task_idx_1_4, task_idx_1_4, worker_1, False ), - dag_idx_2_1: generate_dag_graph_nodes( - local_idx_2_1, dag_idx_2_1, worker_2, False + task_idx_2_1: generate_dag_graph_nodes( + exec_task_idx_2_1, task_idx_2_1, worker_2, False ), - dag_idx_2_2: generate_dag_graph_nodes( - local_idx_2_2, dag_idx_2_2, worker_2, True + task_idx_2_2: generate_dag_graph_nodes( + exec_task_idx_2_2, task_idx_2_2, worker_2, True ), - dag_idx_2_3: generate_dag_graph_nodes( - local_idx_2_3, dag_idx_2_3, worker_2, False + task_idx_2_3: generate_dag_graph_nodes( + exec_task_idx_2_3, task_idx_2_3, worker_2, False ), - dag_idx_2_4: generate_dag_graph_nodes( - local_idx_2_4, dag_idx_2_4, worker_2, True + task_idx_2_4: generate_dag_graph_nodes( + exec_task_idx_2_4, task_idx_2_4, worker_2, True ), } - self.add_edge_between_read_compute_write(graph[dag_idx_1_1]) - self.add_edge_between_read_compute_write(graph[dag_idx_1_2]) - self.add_edge_between_read_compute_write(graph[dag_idx_1_3]) - self.add_edge_between_read_compute_write(graph[dag_idx_1_4]) - self.add_edge_between_read_compute_write(graph[dag_idx_2_1]) - self.add_edge_between_read_compute_write(graph[dag_idx_2_2]) - self.add_edge_between_read_compute_write(graph[dag_idx_2_3]) - self.add_edge_between_read_compute_write(graph[dag_idx_2_4]) - self.add_data_dependeny(graph[dag_idx_1_1], graph[dag_idx_2_1]) - self.add_data_dependeny(graph[dag_idx_2_1], graph[dag_idx_2_2]) - self.add_data_dependeny(graph[dag_idx_2_2], graph[dag_idx_1_3]) - self.add_data_dependeny(graph[dag_idx_1_2], graph[dag_idx_2_3]) - self.add_data_dependeny(graph[dag_idx_2_3], graph[dag_idx_2_4]) - self.add_data_dependeny(graph[dag_idx_2_4], graph[dag_idx_1_4]) - self.add_control_dependency(graph[dag_idx_1_1], graph[dag_idx_1_2]) - self.add_control_dependency(graph[dag_idx_1_2], graph[dag_idx_1_3]) - self.add_control_dependency(graph[dag_idx_1_3], graph[dag_idx_1_4]) - self.add_control_dependency(graph[dag_idx_2_1], graph[dag_idx_2_2]) - self.add_control_dependency(graph[dag_idx_2_2], graph[dag_idx_2_3]) - self.add_control_dependency(graph[dag_idx_2_3], graph[dag_idx_2_4]) + self.add_edge_between_read_compute_write(graph[task_idx_1_1]) + self.add_edge_between_read_compute_write(graph[task_idx_1_2]) + self.add_edge_between_read_compute_write(graph[task_idx_1_3]) + self.add_edge_between_read_compute_write(graph[task_idx_1_4]) + self.add_edge_between_read_compute_write(graph[task_idx_2_1]) + self.add_edge_between_read_compute_write(graph[task_idx_2_2]) + self.add_edge_between_read_compute_write(graph[task_idx_2_3]) + self.add_edge_between_read_compute_write(graph[task_idx_2_4]) + self.add_data_dependeny(graph[task_idx_1_1], graph[task_idx_2_1]) + self.add_data_dependeny(graph[task_idx_2_1], graph[task_idx_2_2]) + self.add_data_dependeny(graph[task_idx_2_2], graph[task_idx_1_3]) + self.add_data_dependeny(graph[task_idx_1_2], graph[task_idx_2_3]) + self.add_data_dependeny(graph[task_idx_2_3], graph[task_idx_2_4]) + self.add_data_dependeny(graph[task_idx_2_4], graph[task_idx_1_4]) + self.add_control_dependency(graph[task_idx_1_1], graph[task_idx_1_2]) + self.add_control_dependency(graph[task_idx_1_2], graph[task_idx_1_3]) + self.add_control_dependency(graph[task_idx_1_3], graph[task_idx_1_4]) + self.add_control_dependency(graph[task_idx_2_1], graph[task_idx_2_2]) + self.add_control_dependency(graph[task_idx_2_2], graph[task_idx_2_3]) + self.add_control_dependency(graph[task_idx_2_3], graph[task_idx_2_4]) actor_to_execution_schedule = _generate_actor_to_execution_schedule(graph) assert len(actor_to_execution_schedule) == 2 assert len(actor_to_execution_schedule[worker_1]) == 12 assert len(actor_to_execution_schedule[worker_2]) == 12 assert actor_to_execution_schedule[worker_1] == [ - graph[dag_idx_1_1][_DAGNodeOperationType.READ].operation, - graph[dag_idx_1_1][_DAGNodeOperationType.COMPUTE].operation, - graph[dag_idx_1_1][_DAGNodeOperationType.WRITE].operation, - graph[dag_idx_1_2][_DAGNodeOperationType.READ].operation, - graph[dag_idx_1_2][_DAGNodeOperationType.COMPUTE].operation, - graph[dag_idx_1_2][_DAGNodeOperationType.WRITE].operation, - graph[dag_idx_1_3][_DAGNodeOperationType.READ].operation, - graph[dag_idx_1_3][_DAGNodeOperationType.COMPUTE].operation, - graph[dag_idx_1_3][_DAGNodeOperationType.WRITE].operation, - graph[dag_idx_1_4][_DAGNodeOperationType.READ].operation, - graph[dag_idx_1_4][_DAGNodeOperationType.COMPUTE].operation, - graph[dag_idx_1_4][_DAGNodeOperationType.WRITE].operation, + graph[task_idx_1_1][_DAGNodeOperationType.READ].operation, + graph[task_idx_1_1][_DAGNodeOperationType.COMPUTE].operation, + graph[task_idx_1_1][_DAGNodeOperationType.WRITE].operation, + graph[task_idx_1_2][_DAGNodeOperationType.READ].operation, + graph[task_idx_1_2][_DAGNodeOperationType.COMPUTE].operation, + graph[task_idx_1_2][_DAGNodeOperationType.WRITE].operation, + graph[task_idx_1_3][_DAGNodeOperationType.READ].operation, + graph[task_idx_1_3][_DAGNodeOperationType.COMPUTE].operation, + graph[task_idx_1_3][_DAGNodeOperationType.WRITE].operation, + graph[task_idx_1_4][_DAGNodeOperationType.READ].operation, + graph[task_idx_1_4][_DAGNodeOperationType.COMPUTE].operation, + graph[task_idx_1_4][_DAGNodeOperationType.WRITE].operation, ] assert actor_to_execution_schedule[worker_2] == [ - graph[dag_idx_2_1][_DAGNodeOperationType.READ].operation, - graph[dag_idx_2_1][_DAGNodeOperationType.COMPUTE].operation, - graph[dag_idx_2_1][_DAGNodeOperationType.WRITE].operation, - graph[dag_idx_2_2][_DAGNodeOperationType.READ].operation, - graph[dag_idx_2_2][_DAGNodeOperationType.COMPUTE].operation, - # The order of `dag_idx_2_3.READ` and `dag_idx_2_2.WRITE` is important. - graph[dag_idx_2_3][_DAGNodeOperationType.READ].operation, - graph[dag_idx_2_2][_DAGNodeOperationType.WRITE].operation, - graph[dag_idx_2_3][_DAGNodeOperationType.COMPUTE].operation, - graph[dag_idx_2_3][_DAGNodeOperationType.WRITE].operation, - graph[dag_idx_2_4][_DAGNodeOperationType.READ].operation, - graph[dag_idx_2_4][_DAGNodeOperationType.COMPUTE].operation, - graph[dag_idx_2_4][_DAGNodeOperationType.WRITE].operation, + graph[task_idx_2_1][_DAGNodeOperationType.READ].operation, + graph[task_idx_2_1][_DAGNodeOperationType.COMPUTE].operation, + graph[task_idx_2_1][_DAGNodeOperationType.WRITE].operation, + graph[task_idx_2_2][_DAGNodeOperationType.READ].operation, + graph[task_idx_2_2][_DAGNodeOperationType.COMPUTE].operation, + # The order of `task_idx_2_3.READ` and `task_idx_2_2.WRITE` is important. + graph[task_idx_2_3][_DAGNodeOperationType.READ].operation, + graph[task_idx_2_2][_DAGNodeOperationType.WRITE].operation, + graph[task_idx_2_3][_DAGNodeOperationType.COMPUTE].operation, + graph[task_idx_2_3][_DAGNodeOperationType.WRITE].operation, + graph[task_idx_2_4][_DAGNodeOperationType.READ].operation, + graph[task_idx_2_4][_DAGNodeOperationType.COMPUTE].operation, + graph[task_idx_2_4][_DAGNodeOperationType.WRITE].operation, ] def test_simulate_pp_2workers_2batches_1f1b_no_nccl(self, monkeypatch): @@ -904,97 +918,97 @@ def test_simulate_pp_2workers_2batches_1f1b_no_nccl(self, monkeypatch): monkeypatch.setattr(ActorHandle, "__init__", mock_actor_handle_init) worker_1 = ActorHandle("worker_1") - dag_idx_1_1, local_idx_1_1 = 1, 0 - dag_idx_1_2, local_idx_1_2 = 2, 1 - dag_idx_1_3, local_idx_1_3 = 3, 2 - dag_idx_1_4, local_idx_1_4 = 4, 3 + task_idx_1_1, exec_task_idx_1_1 = 1, 0 + task_idx_1_2, exec_task_idx_1_2 = 2, 1 + task_idx_1_3, exec_task_idx_1_3 = 3, 2 + task_idx_1_4, exec_task_idx_1_4 = 4, 3 worker_2 = ActorHandle("worker_2") - dag_idx_2_1, local_idx_2_1 = 5, 0 - dag_idx_2_2, local_idx_2_2 = 6, 1 - dag_idx_2_3, local_idx_2_3 = 7, 2 - dag_idx_2_4, local_idx_2_4 = 8, 3 + task_idx_2_1, exec_task_idx_2_1 = 5, 0 + task_idx_2_2, exec_task_idx_2_2 = 6, 1 + task_idx_2_3, exec_task_idx_2_3 = 7, 2 + task_idx_2_4, exec_task_idx_2_4 = 8, 3 # No NCCL operation. graph = { - dag_idx_1_1: generate_dag_graph_nodes( - local_idx_1_1, dag_idx_1_1, worker_1, False + task_idx_1_1: generate_dag_graph_nodes( + exec_task_idx_1_1, task_idx_1_1, worker_1, False ), - dag_idx_1_2: generate_dag_graph_nodes( - local_idx_1_2, dag_idx_1_2, worker_1, False + task_idx_1_2: generate_dag_graph_nodes( + exec_task_idx_1_2, task_idx_1_2, worker_1, False ), - dag_idx_1_3: generate_dag_graph_nodes( - local_idx_1_3, dag_idx_1_3, worker_1, False + task_idx_1_3: generate_dag_graph_nodes( + exec_task_idx_1_3, task_idx_1_3, worker_1, False ), - dag_idx_1_4: generate_dag_graph_nodes( - local_idx_1_4, dag_idx_1_4, worker_1, False + task_idx_1_4: generate_dag_graph_nodes( + exec_task_idx_1_4, task_idx_1_4, worker_1, False ), - dag_idx_2_1: generate_dag_graph_nodes( - local_idx_2_1, dag_idx_2_1, worker_2, False + task_idx_2_1: generate_dag_graph_nodes( + exec_task_idx_2_1, task_idx_2_1, worker_2, False ), - dag_idx_2_2: generate_dag_graph_nodes( - local_idx_2_2, dag_idx_2_2, worker_2, False + task_idx_2_2: generate_dag_graph_nodes( + exec_task_idx_2_2, task_idx_2_2, worker_2, False ), - dag_idx_2_3: generate_dag_graph_nodes( - local_idx_2_3, dag_idx_2_3, worker_2, False + task_idx_2_3: generate_dag_graph_nodes( + exec_task_idx_2_3, task_idx_2_3, worker_2, False ), - dag_idx_2_4: generate_dag_graph_nodes( - local_idx_2_4, dag_idx_2_4, worker_2, False + task_idx_2_4: generate_dag_graph_nodes( + exec_task_idx_2_4, task_idx_2_4, worker_2, False ), } - self.add_edge_between_read_compute_write(graph[dag_idx_1_1]) - self.add_edge_between_read_compute_write(graph[dag_idx_1_2]) - self.add_edge_between_read_compute_write(graph[dag_idx_1_3]) - self.add_edge_between_read_compute_write(graph[dag_idx_1_4]) - self.add_edge_between_read_compute_write(graph[dag_idx_2_1]) - self.add_edge_between_read_compute_write(graph[dag_idx_2_2]) - self.add_edge_between_read_compute_write(graph[dag_idx_2_3]) - self.add_edge_between_read_compute_write(graph[dag_idx_2_4]) - self.add_data_dependeny(graph[dag_idx_1_1], graph[dag_idx_2_1]) - self.add_data_dependeny(graph[dag_idx_2_1], graph[dag_idx_2_2]) - self.add_data_dependeny(graph[dag_idx_2_2], graph[dag_idx_1_3]) - self.add_data_dependeny(graph[dag_idx_1_2], graph[dag_idx_2_3]) - self.add_data_dependeny(graph[dag_idx_2_3], graph[dag_idx_2_4]) - self.add_data_dependeny(graph[dag_idx_2_4], graph[dag_idx_1_4]) - self.add_control_dependency(graph[dag_idx_1_1], graph[dag_idx_1_2]) - self.add_control_dependency(graph[dag_idx_1_2], graph[dag_idx_1_3]) - self.add_control_dependency(graph[dag_idx_1_3], graph[dag_idx_1_4]) - self.add_control_dependency(graph[dag_idx_2_1], graph[dag_idx_2_2]) - self.add_control_dependency(graph[dag_idx_2_2], graph[dag_idx_2_3]) - self.add_control_dependency(graph[dag_idx_2_3], graph[dag_idx_2_4]) + self.add_edge_between_read_compute_write(graph[task_idx_1_1]) + self.add_edge_between_read_compute_write(graph[task_idx_1_2]) + self.add_edge_between_read_compute_write(graph[task_idx_1_3]) + self.add_edge_between_read_compute_write(graph[task_idx_1_4]) + self.add_edge_between_read_compute_write(graph[task_idx_2_1]) + self.add_edge_between_read_compute_write(graph[task_idx_2_2]) + self.add_edge_between_read_compute_write(graph[task_idx_2_3]) + self.add_edge_between_read_compute_write(graph[task_idx_2_4]) + self.add_data_dependeny(graph[task_idx_1_1], graph[task_idx_2_1]) + self.add_data_dependeny(graph[task_idx_2_1], graph[task_idx_2_2]) + self.add_data_dependeny(graph[task_idx_2_2], graph[task_idx_1_3]) + self.add_data_dependeny(graph[task_idx_1_2], graph[task_idx_2_3]) + self.add_data_dependeny(graph[task_idx_2_3], graph[task_idx_2_4]) + self.add_data_dependeny(graph[task_idx_2_4], graph[task_idx_1_4]) + self.add_control_dependency(graph[task_idx_1_1], graph[task_idx_1_2]) + self.add_control_dependency(graph[task_idx_1_2], graph[task_idx_1_3]) + self.add_control_dependency(graph[task_idx_1_3], graph[task_idx_1_4]) + self.add_control_dependency(graph[task_idx_2_1], graph[task_idx_2_2]) + self.add_control_dependency(graph[task_idx_2_2], graph[task_idx_2_3]) + self.add_control_dependency(graph[task_idx_2_3], graph[task_idx_2_4]) actor_to_execution_schedule = _generate_actor_to_execution_schedule(graph) assert len(actor_to_execution_schedule) == 2 assert len(actor_to_execution_schedule[worker_1]) == 12 assert len(actor_to_execution_schedule[worker_2]) == 12 assert actor_to_execution_schedule[worker_1] == [ - graph[dag_idx_1_1][_DAGNodeOperationType.READ].operation, - graph[dag_idx_1_1][_DAGNodeOperationType.COMPUTE].operation, - graph[dag_idx_1_1][_DAGNodeOperationType.WRITE].operation, - graph[dag_idx_1_2][_DAGNodeOperationType.READ].operation, - graph[dag_idx_1_2][_DAGNodeOperationType.COMPUTE].operation, - graph[dag_idx_1_2][_DAGNodeOperationType.WRITE].operation, - graph[dag_idx_1_3][_DAGNodeOperationType.READ].operation, - graph[dag_idx_1_3][_DAGNodeOperationType.COMPUTE].operation, - graph[dag_idx_1_3][_DAGNodeOperationType.WRITE].operation, - graph[dag_idx_1_4][_DAGNodeOperationType.READ].operation, - graph[dag_idx_1_4][_DAGNodeOperationType.COMPUTE].operation, - graph[dag_idx_1_4][_DAGNodeOperationType.WRITE].operation, + graph[task_idx_1_1][_DAGNodeOperationType.READ].operation, + graph[task_idx_1_1][_DAGNodeOperationType.COMPUTE].operation, + graph[task_idx_1_1][_DAGNodeOperationType.WRITE].operation, + graph[task_idx_1_2][_DAGNodeOperationType.READ].operation, + graph[task_idx_1_2][_DAGNodeOperationType.COMPUTE].operation, + graph[task_idx_1_2][_DAGNodeOperationType.WRITE].operation, + graph[task_idx_1_3][_DAGNodeOperationType.READ].operation, + graph[task_idx_1_3][_DAGNodeOperationType.COMPUTE].operation, + graph[task_idx_1_3][_DAGNodeOperationType.WRITE].operation, + graph[task_idx_1_4][_DAGNodeOperationType.READ].operation, + graph[task_idx_1_4][_DAGNodeOperationType.COMPUTE].operation, + graph[task_idx_1_4][_DAGNodeOperationType.WRITE].operation, ] assert actor_to_execution_schedule[worker_2] == [ - graph[dag_idx_2_1][_DAGNodeOperationType.READ].operation, - graph[dag_idx_2_1][_DAGNodeOperationType.COMPUTE].operation, - graph[dag_idx_2_1][_DAGNodeOperationType.WRITE].operation, - graph[dag_idx_2_2][_DAGNodeOperationType.READ].operation, - graph[dag_idx_2_2][_DAGNodeOperationType.COMPUTE].operation, - # The order of `dag_idx_2_3.READ` and `dag_idx_2_2.WRITE` is important. + graph[task_idx_2_1][_DAGNodeOperationType.READ].operation, + graph[task_idx_2_1][_DAGNodeOperationType.COMPUTE].operation, + graph[task_idx_2_1][_DAGNodeOperationType.WRITE].operation, + graph[task_idx_2_2][_DAGNodeOperationType.READ].operation, + graph[task_idx_2_2][_DAGNodeOperationType.COMPUTE].operation, + # The order of `task_idx_2_3.READ` and `task_idx_2_2.WRITE` is important. # It is different from the case where there is an NCCL operation. - graph[dag_idx_2_2][_DAGNodeOperationType.WRITE].operation, - graph[dag_idx_2_3][_DAGNodeOperationType.READ].operation, - graph[dag_idx_2_3][_DAGNodeOperationType.COMPUTE].operation, - graph[dag_idx_2_3][_DAGNodeOperationType.WRITE].operation, - graph[dag_idx_2_4][_DAGNodeOperationType.READ].operation, - graph[dag_idx_2_4][_DAGNodeOperationType.COMPUTE].operation, - graph[dag_idx_2_4][_DAGNodeOperationType.WRITE].operation, + graph[task_idx_2_2][_DAGNodeOperationType.WRITE].operation, + graph[task_idx_2_3][_DAGNodeOperationType.READ].operation, + graph[task_idx_2_3][_DAGNodeOperationType.COMPUTE].operation, + graph[task_idx_2_3][_DAGNodeOperationType.WRITE].operation, + graph[task_idx_2_4][_DAGNodeOperationType.READ].operation, + graph[task_idx_2_4][_DAGNodeOperationType.COMPUTE].operation, + graph[task_idx_2_4][_DAGNodeOperationType.WRITE].operation, ] diff --git a/python/ray/dashboard/state_aggregator.py b/python/ray/dashboard/state_aggregator.py index c4ba56c954c20..b934f4c5c9e68 100644 --- a/python/ray/dashboard/state_aggregator.py +++ b/python/ray/dashboard/state_aggregator.py @@ -65,7 +65,7 @@ def _convert_filters_type( filter: List[Tuple[str, PredicateType, SupportedFilterType]], schema: StateSchema, -) -> List[Tuple[str, SupportedFilterType]]: +) -> List[Tuple[str, PredicateType, SupportedFilterType]]: """Convert the given filter's type to SupportedFilterType. This method is necessary because click can only accept a single type @@ -155,7 +155,7 @@ def data_source_client(self): def _filter( self, data: List[dict], - filters: List[Tuple[str, SupportedFilterType]], + filters: List[Tuple[str, PredicateType, SupportedFilterType]], state_dataclass: StateSchema, detail: bool, ) -> List[dict]: @@ -181,6 +181,8 @@ def _filter( if filter_column not in filterable_columns: raise ValueError( f"The given filter column {filter_column} is not supported. " + "Enter filters with –-filter key=value " + "or –-filter key!=value " f"Supported filter columns: {filterable_columns}" ) diff --git a/python/ray/data/_internal/datasource/bigquery_datasink.py b/python/ray/data/_internal/datasource/bigquery_datasink.py index 7491540a5d73a..651216362ca8d 100644 --- a/python/ray/data/_internal/datasource/bigquery_datasink.py +++ b/python/ray/data/_internal/datasource/bigquery_datasink.py @@ -3,7 +3,7 @@ import tempfile import time import uuid -from typing import Any, Iterable, Optional +from typing import Iterable, Optional import pyarrow.parquet as pq @@ -70,7 +70,7 @@ def write( self, blocks: Iterable[Block], ctx: TaskContext, - ) -> Any: + ) -> None: def _write_single_block(block: Block, project_id: str, dataset: str) -> None: from google.api_core import exceptions from google.cloud import bigquery @@ -127,5 +127,3 @@ def _write_single_block(block: Block, project_id: str, dataset: str) -> None: for block in blocks ] ) - - return "ok" diff --git a/python/ray/data/_internal/datasource/mongo_datasink.py b/python/ray/data/_internal/datasource/mongo_datasink.py index 5f731134f808e..78d56c81f0753 100644 --- a/python/ray/data/_internal/datasource/mongo_datasink.py +++ b/python/ray/data/_internal/datasource/mongo_datasink.py @@ -1,5 +1,5 @@ import logging -from typing import Any, Iterable +from typing import Iterable from ray.data._internal.datasource.mongo_datasource import ( _validate_database_collection_exist, @@ -26,7 +26,7 @@ def write( self, blocks: Iterable[Block], ctx: TaskContext, - ) -> Any: + ) -> None: import pymongo _validate_database_collection_exist( @@ -46,5 +46,3 @@ def write_block(uri: str, database: str, collection: str, block: Block): block = builder.build() write_block(self.uri, self.database, self.collection, block) - - return "ok" diff --git a/python/ray/data/_internal/datasource/parquet_datasink.py b/python/ray/data/_internal/datasource/parquet_datasink.py index 796b3f48c4ae4..4dffa939d7727 100644 --- a/python/ray/data/_internal/datasource/parquet_datasink.py +++ b/python/ray/data/_internal/datasource/parquet_datasink.py @@ -57,13 +57,13 @@ def write( self, blocks: Iterable[Block], ctx: TaskContext, - ) -> Any: + ) -> None: import pyarrow.parquet as pq blocks = list(blocks) if all(BlockAccessor.for_block(block).num_rows() == 0 for block in blocks): - return "skip" + return filename = self.filename_provider.get_filename_for_block( blocks[0], ctx.task_idx, 0 @@ -90,8 +90,6 @@ def write_blocks_to_path(): max_backoff_s=WRITE_FILE_RETRY_MAX_BACKOFF_SECONDS, ) - return "ok" - @property def num_rows_per_write(self) -> Optional[int]: return self.num_rows_per_file diff --git a/python/ray/data/_internal/datasource/parquet_datasource.py b/python/ray/data/_internal/datasource/parquet_datasource.py index 8b06ecfaad60f..b688c2630d686 100644 --- a/python/ray/data/_internal/datasource/parquet_datasource.py +++ b/python/ray/data/_internal/datasource/parquet_datasource.py @@ -230,6 +230,14 @@ def __init__( # duplicating the partition data, we disable PyArrow's partitioning. dataset_kwargs["partitioning"] = None + # `read_schema` is the schema object that will be used to perform + # read operations. + # It should be None, unless user has specified the schema or columns. + # We don't use the inferred schema for read, because the pyarrow only infers + # schema based on the first file. Thus, files with different schemas will end + # up producing blocks with wrong schema. + # See https://github.com/ray-project/ray/issues/47960 for more context. + read_schema = schema pq_ds = get_parquet_dataset(paths, filesystem, dataset_kwargs) if schema is None: @@ -240,6 +248,7 @@ def __init__( schema = pa.schema( [schema.field(column) for column in columns], schema.metadata ) + read_schema = schema check_for_legacy_tensor_type(schema) @@ -247,17 +256,13 @@ def __init__( # Try to infer dataset schema by passing dummy table through UDF. dummy_table = schema.empty_table() try: - inferred_schema = _block_udf(dummy_table).schema - inferred_schema = inferred_schema.with_metadata(schema.metadata) + schema = _block_udf(dummy_table).schema.with_metadata(schema.metadata) except Exception: logger.debug( "Failed to infer schema of dataset by passing dummy table " "through UDF due to the following exception:", exc_info=True, ) - inferred_schema = schema - else: - inferred_schema = schema try: prefetch_remote_args = {} @@ -291,10 +296,10 @@ def __init__( self._pq_fragments = [SerializedFragment(p) for p in pq_ds.fragments] self._pq_paths = [p.path for p in pq_ds.fragments] self._meta_provider = meta_provider - self._inferred_schema = inferred_schema self._block_udf = _block_udf self._to_batches_kwargs = to_batch_kwargs self._columns = columns + self._read_schema = read_schema self._schema = schema self._file_metadata_shuffler = None self._include_paths = include_paths @@ -306,7 +311,7 @@ def __init__( self._pq_fragments, to_batches_kwargs=to_batch_kwargs, columns=columns, - schema=schema, + schema=self._read_schema, local_scheduling=self._local_scheduling, ) self._encoding_ratio = estimate_files_encoding_ratio(sample_infos) @@ -358,7 +363,7 @@ def get_read_tasks(self, parallelism: int) -> List[ReadTask]: meta = self._meta_provider( paths, - self._inferred_schema, + self._schema, num_fragments=len(fragments), prefetched_metadata=metadata, ) @@ -375,7 +380,7 @@ def get_read_tasks(self, parallelism: int) -> List[ReadTask]: to_batches_kwargs, default_read_batch_size_rows, columns, - schema, + read_schema, include_paths, partitioning, ) = ( @@ -383,7 +388,7 @@ def get_read_tasks(self, parallelism: int) -> List[ReadTask]: self._to_batches_kwargs, self._default_read_batch_size_rows, self._columns, - self._schema, + self._read_schema, self._include_paths, self._partitioning, ) @@ -394,7 +399,7 @@ def get_read_tasks(self, parallelism: int) -> List[ReadTask]: to_batches_kwargs, default_read_batch_size_rows, columns, - schema, + read_schema, f, include_paths, partitioning, diff --git a/python/ray/data/_internal/datasource/sql_datasink.py b/python/ray/data/_internal/datasource/sql_datasink.py index 5efd6edb79277..dbf49a145714c 100644 --- a/python/ray/data/_internal/datasource/sql_datasink.py +++ b/python/ray/data/_internal/datasource/sql_datasink.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Iterable +from typing import Callable, Iterable from ray.data._internal.datasource.sql_datasource import Connection, _connect from ray.data._internal.execution.interfaces import TaskContext @@ -18,7 +18,7 @@ def write( self, blocks: Iterable[Block], ctx: TaskContext, - ) -> Any: + ) -> None: with _connect(self.connection_factory) as cursor: for block in blocks: block_accessor = BlockAccessor.for_block(block) @@ -33,5 +33,3 @@ def write( if values: cursor.executemany(self.sql, values) - - return "ok" diff --git a/python/ray/data/_internal/execution/interfaces/physical_operator.py b/python/ray/data/_internal/execution/interfaces/physical_operator.py index fdb426264096b..be9bcc88ee8d5 100644 --- a/python/ray/data/_internal/execution/interfaces/physical_operator.py +++ b/python/ray/data/_internal/execution/interfaces/physical_operator.py @@ -12,7 +12,7 @@ ExecutionResources, ) from ray.data._internal.execution.interfaces.op_runtime_metrics import OpRuntimeMetrics -from ray.data._internal.logical.interfaces import Operator +from ray.data._internal.logical.interfaces import LogicalOperator, Operator from ray.data._internal.stats import StatsDict from ray.data.context import DataContext @@ -188,6 +188,9 @@ def __init__( self._estimated_num_output_bundles = None self._estimated_output_num_rows = None self._execution_completed = False + # The LogicalOperator(s) which were translated to create this PhysicalOperator. + # Set via `PhysicalOperator.set_logical_operators()`. + self._logical_operators: List[LogicalOperator] = [] def __reduce__(self): raise ValueError("Operator is not serializable.") @@ -205,6 +208,12 @@ def output_dependencies(self) -> List["PhysicalOperator"]: def post_order_iter(self) -> Iterator["PhysicalOperator"]: return super().post_order_iter() # type: ignore + def set_logical_operators( + self, + *logical_ops: LogicalOperator, + ): + self._logical_operators = list(logical_ops) + @property def target_max_block_size(self) -> Optional[int]: """ diff --git a/python/ray/data/_internal/execution/operators/map_operator.py b/python/ray/data/_internal/execution/operators/map_operator.py index a3d98bbeee3bd..6a42e0c760af1 100644 --- a/python/ray/data/_internal/execution/operators/map_operator.py +++ b/python/ray/data/_internal/execution/operators/map_operator.py @@ -1,6 +1,7 @@ import copy import functools import itertools +import logging from abc import ABC, abstractmethod from collections import defaultdict, deque from typing import Any, Callable, Deque, Dict, Iterator, List, Optional, Set, Union @@ -37,6 +38,8 @@ from ray.data.context import DataContext from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy +logger = logging.getLogger(__name__) + class MapOperator(OneToOneOperator, ABC): """A streaming operator that maps input bundles 1:1 to output bundles. @@ -645,16 +648,16 @@ def _canonicalize_ray_remote_args(ray_remote_args: Dict[str, Any]) -> Dict[str, and should not be a serious limitation for users. """ ray_remote_args = ray_remote_args.copy() + + if ray_remote_args.get("num_cpus") and ray_remote_args.get("num_gpus"): + logger.warning( + "Specifying both num_cpus and num_gpus for map tasks is experimental, " + "and may result in scheduling or stability issues. " + "Please report any issues to the Ray team: " + "https://github.com/ray-project/ray/issues/new/choose" + ) + if "num_cpus" not in ray_remote_args and "num_gpus" not in ray_remote_args: ray_remote_args["num_cpus"] = 1 - if ray_remote_args.get("num_gpus", 0) > 0: - if ray_remote_args.get("num_cpus", 0) != 0: - raise ValueError( - "It is not allowed to specify both num_cpus and num_gpus for map tasks." - ) - elif ray_remote_args.get("num_cpus", 0) > 0: - if ray_remote_args.get("num_gpus", 0) != 0: - raise ValueError( - "It is not allowed to specify both num_cpus and num_gpus for map tasks." - ) + return ray_remote_args diff --git a/python/ray/data/_internal/logging.py b/python/ray/data/_internal/logging.py index 0a9a5a8f7093a..f62357940ad38 100644 --- a/python/ray/data/_internal/logging.py +++ b/python/ray/data/_internal/logging.py @@ -11,6 +11,16 @@ os.path.join(os.path.dirname(__file__), "logging.yaml") ) +# Dictionary of substitutions to be performed when using JSON mode. Handlers with names +# corresponding to keys will be replaced by those corresponding to values. +RAY_DATA_LOG_HANDLER_JSON_SUBSTITUTIONS = {"file": "file_json"} + +# Env. variable to specify the encoding of the file logs when using the default config. +RAY_DATA_LOG_ENCODING_ENV_VAR_NAME = "RAY_DATA_LOG_ENCODING" + +# Env. variable to specify the logging config path use defaults if not set +RAY_DATA_LOGGING_CONFIG_ENV_VAR_NAME = "RAY_DATA_LOGGING_CONFIG" + # To facilitate debugging, Ray Data writes debug logs to a file. However, if Ray Data # logs every scheduler loop, logging might impact performance. So, we add a "TRACE" # level where logs aren't written by default. @@ -89,15 +99,47 @@ def _try_create_handler(self): def configure_logging() -> None: """Configure the Python logger named 'ray.data'. - This function loads the configration YAML specified by the "RAY_DATA_LOGGING_CONFIG" - environment variable. If the variable isn't set, this function loads the + This function loads the configration YAML specified by "RAY_DATA_LOGGING_CONFIG" + environment variable. If the variable isn't set, this function loads the default "logging.yaml" file that is adjacent to this module. + + If "RAY_DATA_LOG_ENCODING" is specified as "JSON" we will enable JSON logging mode + if using the default logging config. """ - config_path = os.environ.get("RAY_DATA_LOGGING_CONFIG", DEFAULT_CONFIG_PATH) - with open(config_path) as file: - config = yaml.safe_load(file) + + def _load_logging_config(config_path: str): + with open(config_path) as file: + config = yaml.safe_load(file) + return config + + # Dynamically load env vars + config_path = os.environ.get(RAY_DATA_LOGGING_CONFIG_ENV_VAR_NAME) + log_encoding = os.environ.get(RAY_DATA_LOG_ENCODING_ENV_VAR_NAME) + + if config_path is not None: + config = _load_logging_config(config_path) + else: + config = _load_logging_config(DEFAULT_CONFIG_PATH) + if log_encoding is not None and log_encoding.upper() == "JSON": + for logger in config["loggers"].values(): + for ( + old_handler_name, + new_handler_name, + ) in RAY_DATA_LOG_HANDLER_JSON_SUBSTITUTIONS.items(): + logger["handlers"].remove(old_handler_name) + logger["handlers"].append(new_handler_name) + logging.config.dictConfig(config) + # After configuring logger, warn if RAY_DATA_LOGGING_CONFIG is used with + # RAY_DATA_LOG_ENCODING, because they are not both supported together. + if config_path is not None and log_encoding is not None: + logger = logging.getLogger(__name__) + logger.warning( + "Using `RAY_DATA_LOG_ENCODING` is not supported with " + + "`RAY_DATA_LOGGING_CONFIG`" + ) + def reset_logging() -> None: """Reset the logger named 'ray.data' to its initial state. diff --git a/python/ray/data/_internal/logging.yaml b/python/ray/data/_internal/logging.yaml index f72abf356f6dc..170d7c6605d85 100644 --- a/python/ray/data/_internal/logging.yaml +++ b/python/ray/data/_internal/logging.yaml @@ -4,16 +4,25 @@ disable_existing_loggers: False formatters: ray: format: "%(asctime)s\t%(levelname)s %(filename)s:%(lineno)s -- %(message)s" + ray_json: + class: ray._private.ray_logging.formatters.JSONFormatter filters: console_filter: (): ray.data._internal.logging.HiddenRecordFilter + core_context_filter: + (): ray._private.ray_logging.filters.CoreContextFilter handlers: file: class: ray.data._internal.logging.SessionFileHandler formatter: ray filename: ray-data.log + file_json: + class: ray.data._internal.logging.SessionFileHandler + formatter: ray_json + filename: ray-data.log + filters: [core_context_filter] console: class: ray._private.log.PlainRayHandler formatter: ray diff --git a/python/ray/data/_internal/logical/optimizers.py b/python/ray/data/_internal/logical/optimizers.py index 8371d70b70fcf..3c89d658f1511 100644 --- a/python/ray/data/_internal/logical/optimizers.py +++ b/python/ray/data/_internal/logical/optimizers.py @@ -1,4 +1,4 @@ -from typing import List, Type +from typing import List, Optional, Type from ray.data._internal.logical.interfaces import ( LogicalPlan, @@ -31,13 +31,19 @@ @DeveloperAPI -def register_logical_rule(cls: Type[Rule]): - _LOGICAL_RULES.append(cls) +def register_logical_rule(cls: Type[Rule], insert_index: Optional[int] = None): + if insert_index is None: + _LOGICAL_RULES.append(cls) + else: + _LOGICAL_RULES.insert(insert_index, cls) @DeveloperAPI -def register_physical_rule(cls: Type[Rule]): - _PHYSICAL_RULES.append(cls) +def register_physical_rule(cls: Type[Rule], insert_index: Optional[int] = None): + if insert_index is None: + _PHYSICAL_RULES.append(cls) + else: + _PHYSICAL_RULES.insert(insert_index, cls) class LogicalOptimizer(Optimizer): diff --git a/python/ray/data/_internal/logical/rules/operator_fusion.py b/python/ray/data/_internal/logical/rules/operator_fusion.py index 30fdb2dc4fa11..ab6730bc63dcd 100644 --- a/python/ray/data/_internal/logical/rules/operator_fusion.py +++ b/python/ray/data/_internal/logical/rules/operator_fusion.py @@ -325,6 +325,7 @@ def _get_fused_map_operator( ray_remote_args=ray_remote_args, ray_remote_args_fn=ray_remote_args_fn, ) + op.set_logical_operators(*up_op._logical_operators, *down_op._logical_operators) # Build a map logical operator to be used as a reference for further fusion. # TODO(Scott): This is hacky, remove this once we push fusion to be purely based diff --git a/python/ray/data/_internal/planner/plan_write_op.py b/python/ray/data/_internal/planner/plan_write_op.py index 94d1d744b9f9b..69c35df1b6c13 100644 --- a/python/ray/data/_internal/planner/plan_write_op.py +++ b/python/ray/data/_internal/planner/plan_write_op.py @@ -1,3 +1,4 @@ +import itertools from typing import Callable, Iterator, List, Union from ray.data._internal.compute import TaskPoolStrategy @@ -9,32 +10,49 @@ MapTransformer, ) from ray.data._internal.logical.operators.write_operator import Write -from ray.data.block import Block -from ray.data.datasource.datasink import Datasink +from ray.data.block import Block, BlockAccessor +from ray.data.datasource.datasink import Datasink, WriteResult from ray.data.datasource.datasource import Datasource def generate_write_fn( datasink_or_legacy_datasource: Union[Datasink, Datasource], **write_args ) -> Callable[[Iterator[Block], TaskContext], Iterator[Block]]: - # If the write op succeeds, the resulting Dataset is a list of - # arbitrary objects (one object per write task). Otherwise, an error will - # be raised. The Datasource can handle execution outcomes with the - # on_write_complete() and on_write_failed(). def fn(blocks: Iterator[Block], ctx) -> Iterator[Block]: + """Writes the blocks to the given datasink or legacy datasource. + + Outputs the original blocks to be written.""" + # Create a copy of the iterator, so we can return the original blocks. + it1, it2 = itertools.tee(blocks, 2) if isinstance(datasink_or_legacy_datasource, Datasink): - write_result = datasink_or_legacy_datasource.write(blocks, ctx) + datasink_or_legacy_datasource.write(it1, ctx) else: - write_result = datasink_or_legacy_datasource.write( - blocks, ctx, **write_args - ) + datasink_or_legacy_datasource.write(it1, ctx, **write_args) + return it2 + + return fn + + +def generate_collect_write_stats_fn() -> Callable[ + [Iterator[Block], TaskContext], Iterator[Block] +]: + # If the write op succeeds, the resulting Dataset is a list of + # one Block which contain stats/metrics about the write. + # Otherwise, an error will be raised. The Datasource can handle + # execution outcomes with `on_write_complete()`` and `on_write_failed()``. + def fn(blocks: Iterator[Block], ctx) -> Iterator[Block]: + """Handles stats collection for block writes.""" + block_accessors = [BlockAccessor.for_block(block) for block in blocks] + total_num_rows = sum(ba.num_rows() for ba in block_accessors) + total_size_bytes = sum(ba.size_bytes() for ba in block_accessors) # NOTE: Write tasks can return anything, so we need to wrap it in a valid block # type. import pandas as pd + write_result = WriteResult(num_rows=total_num_rows, size_bytes=total_size_bytes) block = pd.DataFrame({"write_result": [write_result]}) - return [block] + return iter([block]) return fn @@ -46,9 +64,11 @@ def plan_write_op( input_physical_dag = physical_children[0] write_fn = generate_write_fn(op._datasink_or_legacy_datasource, **op._write_args) + collect_stats_fn = generate_collect_write_stats_fn() # Create a MapTransformer for a write operator transform_fns = [ BlockMapTransformFn(write_fn), + BlockMapTransformFn(collect_stats_fn), ] map_transformer = MapTransformer(transform_fns) return MapOperator.create( diff --git a/python/ray/data/_internal/planner/planner.py b/python/ray/data/_internal/planner/planner.py index d47afc43b7158..3cc97f4db6cb0 100644 --- a/python/ray/data/_internal/planner/planner.py +++ b/python/ray/data/_internal/planner/planner.py @@ -127,5 +127,18 @@ def _plan(self, logical_op: LogicalOperator) -> PhysicalOperator: f"Found unknown logical operator during planning: {logical_op}" ) + # Traverse up the DAG, and set the mapping from physical to logical operators. + # At this point, all physical operators without logical operators set + # must have been created by the current logical operator. + queue = [physical_op] + while queue: + curr_physical_op = queue.pop() + # Once we find an operator with a logical operator set, we can stop. + if curr_physical_op._logical_operators: + break + + curr_physical_op.set_logical_operators(logical_op) + queue.extend(physical_op.input_dependencies) + self._physical_op_to_logical_op[physical_op] = logical_op return physical_op diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py index 94c96a6a97210..ea0efdca8a14f 100644 --- a/python/ray/data/dataset.py +++ b/python/ray/data/dataset.py @@ -272,6 +272,12 @@ def map( If your transformation is vectorized like most NumPy or pandas operations, :meth:`~Dataset.map_batches` might be faster. + .. warning:: + Specifying both ``num_cpus`` and ``num_gpus`` for map tasks is experimental, + and may result in scheduling or stability issues. Please + `report any issues `_ + to the Ray team. + Examples: .. testcode:: @@ -417,6 +423,12 @@ def map_batches( If ``fn`` doesn't mutate its input, set ``zero_copy_batch=True`` to improve performance and decrease memory utilization. + .. warning:: + Specifying both ``num_cpus`` and ``num_gpus`` for map tasks is experimental, + and may result in scheduling or stability issues. Please + `report any issues `_ + to the Ray team. + Examples: Call :meth:`~Dataset.map_batches` to transform your data. @@ -973,6 +985,12 @@ def flat_map( transformation is vectorized like most NumPy and pandas operations, it might be faster. + .. warning:: + Specifying both ``num_cpus`` and ``num_gpus`` for map tasks is experimental, + and may result in scheduling or stability issues. Please + `report any issues `_ + to the Ray team. + Examples: .. testcode:: @@ -3709,13 +3727,15 @@ def write_datasink( datasink.on_write_start() self._write_ds = Dataset(plan, logical_plan).materialize() - blocks = ray.get(self._write_ds._plan.execute().block_refs) + # TODO: Get and handle the blocks with an iterator instead of getting + # everything in a blocking way, so some blocks can be freed earlier. + raw_write_results = ray.get(self._write_ds._plan.execute().block_refs) assert all( - isinstance(block, pd.DataFrame) and len(block) == 1 for block in blocks + isinstance(block, pd.DataFrame) and len(block) == 1 + for block in raw_write_results ) - write_results = [block["write_result"][0] for block in blocks] + datasink.on_write_complete(raw_write_results) - datasink.on_write_complete(write_results) except Exception as e: datasink.on_write_failed(e) raise diff --git a/python/ray/data/datasource/datasink.py b/python/ray/data/datasource/datasink.py index c523b5cd06c0c..0832e0539fd16 100644 --- a/python/ray/data/datasource/datasink.py +++ b/python/ray/data/datasource/datasink.py @@ -1,10 +1,51 @@ -from typing import Any, Iterable, List, Optional +import logging +from dataclasses import dataclass, fields +from typing import Iterable, List, Optional import ray from ray.data._internal.execution.interfaces import TaskContext from ray.data.block import Block, BlockAccessor from ray.util.annotations import DeveloperAPI +logger = logging.getLogger(__name__) + + +@dataclass +@DeveloperAPI +class WriteResult: + """Result of a write operation, containing stats/metrics + on the written data. + + Attributes: + total_num_rows: The total number of rows written. + total_size_bytes: The total size of the written data in bytes. + """ + + num_rows: int = 0 + size_bytes: int = 0 + + @staticmethod + def aggregate_write_results(write_results: List["WriteResult"]) -> "WriteResult": + """Aggregate a list of write results. + + Args: + write_results: A list of write results. + + Returns: + A single write result that aggregates the input results. + """ + total_num_rows = 0 + total_size_bytes = 0 + + for write_result in write_results: + total_num_rows += write_result.num_rows + total_size_bytes += write_result.size_bytes + + return WriteResult( + num_rows=total_num_rows, + size_bytes=total_size_bytes, + ) + @DeveloperAPI class Datasink: @@ -26,20 +67,16 @@ def write( self, blocks: Iterable[Block], ctx: TaskContext, - ) -> Any: + ) -> None: """Write blocks. This is used by a single write task. Args: blocks: Generator of data blocks. ctx: ``TaskContext`` for the write task. - - Returns: - A user-defined output. Can be anything, and the returned value is passed to - :meth:`~ray.data.Datasink.on_write_complete`. """ raise NotImplementedError - def on_write_complete(self, write_results: List[Any]) -> None: + def on_write_complete(self, write_result_blocks: List[Block]) -> WriteResult: """Callback for when a write job completes. This can be used to "commit" a write output. This method must @@ -47,10 +84,27 @@ def on_write_complete(self, write_results: List[Any]) -> None: method fails, then ``on_write_failed()`` is called. Args: - write_results: The objects returned by every - :meth:`~ray.data.Datasink.write` task. + write_result_blocks: The blocks resulting from executing + the Write operator, containing write results and stats. + Returns: + A ``WriteResult`` object containing the aggregated stats of all + the input write results. """ - pass + write_results = [ + result["write_result"].iloc[0] for result in write_result_blocks + ] + aggregated_write_results = WriteResult.aggregate_write_results(write_results) + + aggregated_results_str = "" + for k in fields(aggregated_write_results.__class__): + v = getattr(aggregated_write_results, k.name) + aggregated_results_str += f"\t{k}: {v}\n" + + logger.info( + f"Write operation succeeded. Aggregated write results:\n" + f"{aggregated_results_str}" + ) + return aggregated_write_results def on_write_failed(self, error: Exception) -> None: """Callback for when a write job fails. @@ -111,10 +165,9 @@ def __init__(self): self.rows_written = 0 self.enabled = True - def write(self, block: Block) -> str: + def write(self, block: Block) -> None: block = BlockAccessor.for_block(block) self.rows_written += block.num_rows() - return "ok" def get_rows_written(self): return self.rows_written @@ -128,18 +181,18 @@ def write( self, blocks: Iterable[Block], ctx: TaskContext, - ) -> Any: + ) -> None: tasks = [] if not self.enabled: raise ValueError("disabled") for b in blocks: tasks.append(self.data_sink.write.remote(b)) ray.get(tasks) - return "ok" - def on_write_complete(self, write_results: List[Any]) -> None: - assert all(w == "ok" for w in write_results), write_results + def on_write_complete(self, write_result_blocks: List[Block]) -> WriteResult: self.num_ok += 1 + aggregated_results = super().on_write_complete(write_result_blocks) + return aggregated_results def on_write_failed(self, error: Exception) -> None: self.num_failed += 1 diff --git a/python/ray/data/datasource/file_datasink.py b/python/ray/data/datasource/file_datasink.py index b09090f513af2..79d106f39ba3d 100644 --- a/python/ray/data/datasource/file_datasink.py +++ b/python/ray/data/datasource/file_datasink.py @@ -9,7 +9,7 @@ from ray.data._internal.util import _is_local_scheme, call_with_retry from ray.data.block import Block, BlockAccessor from ray.data.context import DataContext -from ray.data.datasource.datasink import Datasink +from ray.data.datasource.datasink import Datasink, WriteResult from ray.data.datasource.filename_provider import ( FilenameProvider, _DefaultFilenameProvider, @@ -114,7 +114,7 @@ def write( self, blocks: Iterable[Block], ctx: TaskContext, - ) -> Any: + ) -> None: builder = DelegatingBlockBuilder() for block in blocks: builder.add_block(block) @@ -123,22 +123,20 @@ def write( if block_accessor.num_rows() == 0: logger.warning(f"Skipped writing empty block to {self.path}") - return "skip" + return self.write_block(block_accessor, 0, ctx) - # TODO: decide if we want to return richer object when the task - # succeeds. - return "ok" def write_block(self, block: BlockAccessor, block_index: int, ctx: TaskContext): raise NotImplementedError - def on_write_complete(self, write_results: List[Any]) -> None: - if not self.has_created_dir: - return + def on_write_complete(self, write_result_blocks: List[Block]) -> WriteResult: + aggregated_results = super().on_write_complete(write_result_blocks) - if all(write_results == "skip" for write_results in write_results): + # If no rows were written, we can delete the directory. + if self.has_created_dir and aggregated_results.num_rows == 0: self.filesystem.delete_dir(self.path) + return aggregated_results @property def supports_distributed_writes(self) -> bool: diff --git a/python/ray/data/grouped_data.py b/python/ray/data/grouped_data.py index c76d6cee7615f..e479908136a29 100644 --- a/python/ray/data/grouped_data.py +++ b/python/ray/data/grouped_data.py @@ -129,6 +129,12 @@ def map_groups( In general, prefer to use aggregate() instead of map_groups(). + .. warning:: + Specifying both ``num_cpus`` and ``num_gpus`` for map tasks is experimental, + and may result in scheduling or stability issues. Please + `report any issues `_ + to the Ray team. + Examples: >>> # Return a single record per group (list of multiple records in, >>> # list of a single record out). diff --git a/python/ray/data/tests/test_bigquery.py b/python/ray/data/tests/test_bigquery.py index 325de3eaa586d..67266a822c425 100644 --- a/python/ray/data/tests/test_bigquery.py +++ b/python/ray/data/tests/test_bigquery.py @@ -1,3 +1,4 @@ +from typing import Iterator from unittest import mock import pyarrow as pa @@ -10,6 +11,9 @@ import ray from ray.data._internal.datasource.bigquery_datasink import BigQueryDatasink from ray.data._internal.datasource.bigquery_datasource import BigQueryDatasource +from ray.data._internal.planner.plan_write_op import generate_collect_write_stats_fn +from ray.data.block import Block +from ray.data.datasource.datasink import WriteResult from ray.data.tests.conftest import * # noqa from ray.data.tests.mock_http_server import * # noqa from ray.tests.conftest import * # noqa @@ -197,6 +201,9 @@ def test_create_reader_table_not_found(self): class TestWriteBigQuery: """Tests for BigQuery Write.""" + def _extract_write_result(self, stats: Iterator[Block]): + return dict(next(stats).iloc[0])["write_result"] + def test_write(self, ray_get_mock): bq_datasink = BigQueryDatasink( project_id=_TEST_GCP_PROJECT_ID, @@ -204,11 +211,15 @@ def test_write(self, ray_get_mock): ) arr = pa.array([2, 4, 5, 100]) block = pa.Table.from_arrays([arr], names=["data"]) - status = bq_datasink.write( + bq_datasink.write( blocks=[block], ctx=None, ) - assert status == "ok" + + collect_stats_fn = generate_collect_write_stats_fn() + stats = collect_stats_fn([block], None) + write_result = self._extract_write_result(stats) + assert write_result == WriteResult(num_rows=4, size_bytes=32) def test_write_dataset_exists(self, ray_get_mock): bq_datasink = BigQueryDatasink( @@ -217,11 +228,14 @@ def test_write_dataset_exists(self, ray_get_mock): ) arr = pa.array([2, 4, 5, 100]) block = pa.Table.from_arrays([arr], names=["data"]) - status = bq_datasink.write( + bq_datasink.write( blocks=[block], ctx=None, ) - assert status == "ok" + collect_stats_fn = generate_collect_write_stats_fn() + stats = collect_stats_fn([block], None) + write_result = self._extract_write_result(stats) + assert write_result == WriteResult(num_rows=4, size_bytes=32) if __name__ == "__main__": diff --git a/python/ray/data/tests/test_datasink.py b/python/ray/data/tests/test_datasink.py index 7720784906017..714f03c6dfe30 100644 --- a/python/ray/data/tests/test_datasink.py +++ b/python/ray/data/tests/test_datasink.py @@ -1,4 +1,4 @@ -from typing import Any, Iterable +from typing import Iterable import pytest @@ -14,7 +14,7 @@ class MockDatasink(Datasink): def __init__(self, num_rows_per_write): self._num_rows_per_write = num_rows_per_write - def write(self, blocks: Iterable[Block], ctx: TaskContext) -> Any: + def write(self, blocks: Iterable[Block], ctx: TaskContext) -> None: assert sum(len(block) for block in blocks) == self._num_rows_per_write @property diff --git a/python/ray/data/tests/test_execution_optimizer.py b/python/ray/data/tests/test_execution_optimizer.py index 913c7ee1822bf..04da5508b1fbf 100644 --- a/python/ray/data/tests/test_execution_optimizer.py +++ b/python/ray/data/tests/test_execution_optimizer.py @@ -111,6 +111,9 @@ def test_read_operator(ray_start_regular_shared): physical_op.actual_target_max_block_size == DataContext.get_current().target_max_block_size ) + # Check that the linked logical operator is the same the input op. + assert physical_op._logical_operators == [op] + assert physical_op.input_dependencies[0]._logical_operators == [op] def test_read_operator_emits_warning_for_large_read_tasks(): @@ -182,6 +185,9 @@ def test_from_operators(ray_start_regular_shared): assert isinstance(physical_op, InputDataBuffer) assert len(physical_op.input_dependencies) == 0 + # Check that the linked logical operator is the same the input op. + assert physical_op._logical_operators == [op] + def test_from_items_e2e(ray_start_regular_shared): data = ["Hello", "World"] @@ -253,6 +259,9 @@ def test_map_batches_operator(ray_start_regular_shared): assert len(physical_op.input_dependencies) == 1 assert isinstance(physical_op.input_dependencies[0], MapOperator) + # Check that the linked logical operator is the same the input op. + assert physical_op._logical_operators == [op] + def test_map_batches_e2e(ray_start_regular_shared): ds = ray.data.range(5) @@ -393,6 +402,9 @@ def test_random_shuffle_operator(ray_start_regular_shared): == DataContext.get_current().target_shuffle_max_block_size ) + # Check that the linked logical operator is the same the input op. + assert physical_op._logical_operators == [op] + def test_random_shuffle_e2e(ray_start_regular_shared, use_push_based_shuffle): ds = ray.data.range(12, override_num_blocks=4) @@ -430,6 +442,9 @@ def test_repartition_operator(ray_start_regular_shared, shuffle): == DataContext.get_current().target_max_block_size ) + # Check that the linked logical operator is the same the input op. + assert physical_op._logical_operators == [op] + @pytest.mark.parametrize( "shuffle", @@ -506,6 +521,9 @@ def test_union_operator(ray_start_regular_shared, preserve_order): == DataContext.get_current().target_max_block_size ) + # Check that the linked logical operator is the same the input op. + assert physical_op._logical_operators == [union_op] + @pytest.mark.parametrize("preserve_order", (True, False)) def test_union_e2e(ray_start_regular_shared, preserve_order): @@ -578,22 +596,23 @@ def test_read_map_batches_operator_fusion(ray_start_regular_shared): physical_op.actual_target_max_block_size == DataContext.get_current().target_max_block_size ) + assert physical_op._logical_operators == [read_op, op] def test_read_map_chain_operator_fusion(ray_start_regular_shared): # Test that a chain of different map operators are fused. planner = Planner() read_op = get_parquet_read_logical_op(parallelism=1) - op = MapRows(read_op, lambda x: x) - op = MapBatches(op, lambda x: x) - op = FlatMap(op, lambda x: x) - op = Filter(op, lambda x: x) - logical_plan = LogicalPlan(op) + map1 = MapRows(read_op, lambda x: x) + map2 = MapBatches(map1, lambda x: x) + map3 = FlatMap(map2, lambda x: x) + map4 = Filter(map3, lambda x: x) + logical_plan = LogicalPlan(map4) physical_plan = planner.plan(logical_plan) physical_plan = PhysicalOptimizer().optimize(physical_plan) physical_op = physical_plan.dag - assert op.name == "Filter()" + assert map4.name == "Filter()" assert ( physical_op.name == "ReadParquet->Map()->MapBatches()" "->FlatMap()->Filter()" @@ -605,6 +624,7 @@ def test_read_map_chain_operator_fusion(ray_start_regular_shared): physical_op.actual_target_max_block_size == DataContext.get_current().target_max_block_size ) + assert physical_op._logical_operators == [read_op, map1, map2, map3, map4] def test_read_map_batches_operator_fusion_compatible_remote_args( @@ -1009,6 +1029,9 @@ def test_write_operator(ray_start_regular_shared, tmp_path): assert len(physical_op.input_dependencies) == 1 assert isinstance(physical_op.input_dependencies[0], MapOperator) + # Check that the linked logical operator is the same the input op. + assert physical_op._logical_operators == [op] + def test_sort_operator( ray_start_regular_shared, @@ -1105,6 +1128,9 @@ def test_aggregate_operator(ray_start_regular_shared): == DataContext.get_current().target_shuffle_max_block_size ) + # Check that the linked logical operator is the same the input op. + assert physical_op._logical_operators == [op] + def test_aggregate_e2e(ray_start_regular_shared, use_push_based_shuffle): ds = ray.data.range(100, override_num_blocks=4) @@ -1171,6 +1197,9 @@ def test_zip_operator(ray_start_regular_shared): == DataContext.get_current().target_max_block_size ) + # Check that the linked logical operator is the same the input op. + assert physical_op._logical_operators == [op] + @pytest.mark.parametrize( "num_blocks1,num_blocks2", @@ -1620,5 +1649,88 @@ def test_zero_copy_fusion_eliminate_build_output_blocks(ray_start_regular_shared ) +def test_insert_logical_optimization_rules(): + class FakeRule1: + pass + + class FakeRule2: + pass + + from ray.data._internal.logical.optimizers import ( + _LOGICAL_RULES, + register_logical_rule, + ) + from ray.data._internal.logical.rules.randomize_blocks import ( + ReorderRandomizeBlocksRule, + ) + + register_logical_rule(FakeRule1) + assert _LOGICAL_RULES == [ReorderRandomizeBlocksRule, FakeRule1] + + register_logical_rule(FakeRule2, 1) + assert _LOGICAL_RULES == [ReorderRandomizeBlocksRule, FakeRule2, FakeRule1] + + register_logical_rule(FakeRule1, 0) + assert _LOGICAL_RULES == [ + FakeRule1, + ReorderRandomizeBlocksRule, + FakeRule2, + FakeRule1, + ] + + +def test_insert_physical_optimization_rules(): + class FakeRule1: + pass + + class FakeRule2: + pass + + from ray.data._internal.logical.optimizers import ( + _PHYSICAL_RULES, + register_physical_rule, + ) + from ray.data._internal.logical.rules.inherit_target_max_block_size import ( + InheritTargetMaxBlockSizeRule, + ) + from ray.data._internal.logical.rules.operator_fusion import OperatorFusionRule + from ray.data._internal.logical.rules.set_read_parallelism import ( + SetReadParallelismRule, + ) + from ray.data._internal.logical.rules.zero_copy_map_fusion import ( + EliminateBuildOutputBlocks, + ) + + register_physical_rule(FakeRule1) + assert _PHYSICAL_RULES == [ + InheritTargetMaxBlockSizeRule, + SetReadParallelismRule, + OperatorFusionRule, + EliminateBuildOutputBlocks, + FakeRule1, + ] + + register_physical_rule(FakeRule2, 2) + assert _PHYSICAL_RULES == [ + InheritTargetMaxBlockSizeRule, + SetReadParallelismRule, + FakeRule2, + OperatorFusionRule, + EliminateBuildOutputBlocks, + FakeRule1, + ] + + register_physical_rule(FakeRule1, 0) + assert _PHYSICAL_RULES == [ + FakeRule1, + InheritTargetMaxBlockSizeRule, + SetReadParallelismRule, + FakeRule2, + OperatorFusionRule, + EliminateBuildOutputBlocks, + FakeRule1, + ] + + if __name__ == "__main__": sys.exit(pytest.main(["-v", __file__])) diff --git a/python/ray/data/tests/test_executor_resource_management.py b/python/ray/data/tests/test_executor_resource_management.py index d4d73956d6840..8bf3984ba7ccb 100644 --- a/python/ray/data/tests/test_executor_resource_management.py +++ b/python/ray/data/tests/test_executor_resource_management.py @@ -123,14 +123,18 @@ def test_resource_canonicalization(ray_start_10_cpus_shared): ) assert op._ray_remote_args == {"num_gpus": 2} - with pytest.raises(ValueError): - MapOperator.create( - _mul2_map_data_prcessor, - input_op=input_op, - name="TestMapper", - compute_strategy=TaskPoolStrategy(), - ray_remote_args={"num_gpus": 2, "num_cpus": 1}, - ) + op = MapOperator.create( + _mul2_map_data_prcessor, + input_op=input_op, + name="TestMapper", + compute_strategy=TaskPoolStrategy(), + ray_remote_args={"num_gpus": 2, "num_cpus": 1}, + ) + assert op.base_resource_usage() == ExecutionResources() + assert op.incremental_resource_usage() == ExecutionResources( + cpu=1, gpu=2, object_store_memory=inc_obj_store_mem + ) + assert op._ray_remote_args == {"num_gpus": 2, "num_cpus": 1} def test_execution_options_resource_limit(): diff --git a/python/ray/data/tests/test_formats.py b/python/ray/data/tests/test_formats.py index 943e7d19bcffd..52cc95fc0335a 100644 --- a/python/ray/data/tests/test_formats.py +++ b/python/ray/data/tests/test_formats.py @@ -1,6 +1,6 @@ import os import sys -from typing import Any, Iterable, List +from typing import Iterable, List import pandas as pd import pyarrow as pa @@ -14,6 +14,7 @@ from ray.data._internal.execution.interfaces import TaskContext from ray.data.block import Block, BlockAccessor from ray.data.datasource import Datasink, DummyOutputDatasink +from ray.data.datasource.datasink import WriteResult from ray.data.datasource.file_meta_provider import _handle_read_os_error from ray.data.tests.conftest import * # noqa from ray.data.tests.mock_http_server import * # noqa @@ -239,7 +240,6 @@ def write(self, node_id: str, block: Block) -> str: block = BlockAccessor.for_block(block) self.rows_written += block.num_rows() self.node_ids.add(node_id) - return "ok" def get_rows_written(self): return self.rows_written @@ -255,7 +255,7 @@ def write( self, blocks: Iterable[Block], ctx: TaskContext, - ) -> Any: + ) -> None: data_sink = self.data_sink def write(b): @@ -266,11 +266,11 @@ def write(b): for b in blocks: tasks.append(write(b)) ray.get(tasks) - return "ok" - def on_write_complete(self, write_results: List[Any]) -> None: - assert all(w == "ok" for w in write_results), write_results + def on_write_complete(self, write_result_blocks: List[Block]) -> WriteResult: self.num_ok += 1 + aggregated_results = super().on_write_complete(write_result_blocks) + return aggregated_results def on_write_failed(self, error: Exception) -> None: self.num_failed += 1 diff --git a/python/ray/data/tests/test_logging.py b/python/ray/data/tests/test_logging.py index de18e44523845..bc6c0ecc0c037 100644 --- a/python/ray/data/tests/test_logging.py +++ b/python/ray/data/tests/test_logging.py @@ -124,6 +124,49 @@ def test_custom_config(reset_logging, monkeypatch, tmp_path): assert isinstance(logger.handlers[0], logging.StreamHandler) +def test_json_logging_configuration( + capsys, reset_logging, monkeypatch, shutdown_only, propagate_logs +): + import json + + monkeypatch.setenv("RAY_DATA_LOG_ENCODING", "JSON") + ray.init() + + configure_logging() + + logger = logging.getLogger("ray.data") + + # Ensure handlers correctly setup + handlers = logger.handlers + assert sum(handler.name == "file_json" for handler in handlers) == 1 + assert sum(handler.name == "console" for handler in handlers) == 1 + + logger.info("ham") + logger.debug("turkey") + + log_path = os.path.join(get_log_directory(), "ray-data.log") + with open(log_path) as file: + log_contents = file.read() + + # Validate the log is in JSON format (a basic check for JSON) + messages = [] + for log_line in log_contents.splitlines(): + log_dict = json.loads(log_line) # will error if not a json line + messages.append(log_dict["message"]) + + assert "ham" in messages + assert "turkey" in messages + + # Validate console logs are in text mode + console_log_output = capsys.readouterr().err + for log_line in console_log_output.splitlines(): + with pytest.raises(json.JSONDecodeError): + json.loads(log_line) + + assert "ham" in console_log_output + assert "turkey" not in console_log_output + + if __name__ == "__main__": import sys diff --git a/python/ray/data/tests/test_parquet.py b/python/ray/data/tests/test_parquet.py index 15629e5ef0f79..23969d736f046 100644 --- a/python/ray/data/tests/test_parquet.py +++ b/python/ray/data/tests/test_parquet.py @@ -1291,6 +1291,32 @@ def _assert_equal(rows, expected): _assert_equal(ds.take_all(), expected_tuples) +def test_multiple_files_with_ragged_arrays(ray_start_regular_shared, tmp_path): + # Test reading multiple parquet files, each of which has different-shaped + # ndarrays in the same column. + # See https://github.com/ray-project/ray/issues/47960 for more context. + num_rows = 3 + ds = ray.data.range(num_rows) + + def map(row): + id = row["id"] + 1 + row["data"] = np.zeros((id * 100, id * 100), dtype=np.int8) + return row + + # Write 3 parquet files with different-shaped ndarray values in the + # "data" column. + ds.map(map).repartition(num_rows).write_parquet(tmp_path) + + # Read these 3 files, check that the result is correct. + ds2 = ray.data.read_parquet(tmp_path, override_num_blocks=1) + res = ds2.take_all() + res = sorted(res, key=lambda row: row["id"]) + assert len(res) == num_rows + for index, item in enumerate(res): + assert item["id"] == index + assert item["data"].shape == (100 * (index + 1), 100 * (index + 1)) + + if __name__ == "__main__": import sys diff --git a/python/ray/exceptions.py b/python/ray/exceptions.py index d69cde6a283dc..33bb7a225c79e 100644 --- a/python/ray/exceptions.py +++ b/python/ray/exceptions.py @@ -1,7 +1,8 @@ import logging import os +import sys from traceback import format_exception -from typing import Optional, Type, Union +from typing import Optional, Union import colorama @@ -147,9 +148,21 @@ def __init__( assert traceback_str is not None - def make_dual_exception_type(self) -> Type: - """Makes a Type that inherits from both RayTaskError and the type of + def make_dual_exception_instance(self) -> "RayTaskError": + """Makes a object instance that inherits from both RayTaskError and the type of `self.cause`. Raises TypeError if the cause class can't be subclassed""" + # For normal user Exceptions, we subclass from both + # RayTaskError and the user exception. For ExceptionGroup, + # we special handle it because it has a different __new__() + # signature from Exception. + # Ref: https://docs.python.org/3/library/exceptions.html#exception-groups + if sys.version_info >= (3, 11) and isinstance( + self.cause, ExceptionGroup # noqa: F821 + ): + return self._make_exceptiongroup_dual_exception_instance() + return self._make_normal_dual_exception_instance() + + def _make_normal_dual_exception_instance(self) -> "RayTaskError": cause_cls = self.cause.__class__ error_msg = str(self) @@ -171,7 +184,35 @@ def __str__(self): cls.__name__ = name cls.__qualname__ = name - return cls + return cls(self.cause) + + def _make_exceptiongroup_dual_exception_instance(self) -> "RayTaskError": + cause_cls = self.cause.__class__ + error_msg = str(self) + + class cls(RayTaskError, cause_cls): + def __new__(cls, cause): + self = super().__new__(cls, cause.message, cause.exceptions) + return self + + def __init__(self, cause): + self.cause = cause + # BaseException implements a __reduce__ method that returns + # a tuple with the type and the value of self.args. + # https://stackoverflow.com/a/49715949/2213289 + self.args = (cause,) + + def __getattr__(self, name): + return getattr(self.cause, name) + + def __str__(self): + return error_msg + + name = f"RayTaskError({cause_cls.__name__})" + cls.__name__ = name + cls.__qualname__ = name + + return cls(self.cause) def as_instanceof_cause(self): """Returns an exception that's an instance of the cause's class. @@ -187,8 +228,7 @@ def as_instanceof_cause(self): return self # already satisfied try: - dual_cls = self.make_dual_exception_type() - return dual_cls(self.cause) + return self.make_dual_exception_instance() except TypeError as e: logger.warning( f"User exception type {type(self.cause)} in RayTaskError can't" diff --git a/python/ray/serve/_private/default_impl.py b/python/ray/serve/_private/default_impl.py index 71eb9296eb36f..6ff1b83056f39 100644 --- a/python/ray/serve/_private/default_impl.py +++ b/python/ray/serve/_private/default_impl.py @@ -10,6 +10,7 @@ DefaultDeploymentScheduler, DeploymentScheduler, ) +from ray.serve._private.grpc_util import gRPCServer from ray.serve._private.utils import get_head_node_id # NOTE: Please read carefully before changing! @@ -35,3 +36,8 @@ def create_deployment_scheduler( create_placement_group_fn=create_placement_group_fn_override or ray.util.placement_group, ) + + +def add_grpc_address(grpc_server: gRPCServer, server_address: str): + """Helper function to add a address to gRPC server.""" + grpc_server.add_insecure_port(server_address) diff --git a/python/ray/serve/_private/proxy.py b/python/ray/serve/_private/proxy.py index 157855c549895..ee0e048fef5be 100644 --- a/python/ray/serve/_private/proxy.py +++ b/python/ray/serve/_private/proxy.py @@ -7,7 +7,7 @@ import time from abc import ABC, abstractmethod from functools import partial -from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Type +from typing import Any, Callable, Dict, Generator, List, Optional, Set, Tuple, Type import grpc import starlette @@ -27,6 +27,7 @@ DeploymentID, EndpointInfo, NodeId, + ReplicaID, RequestMetadata, RequestProtocol, ) @@ -40,6 +41,7 @@ SERVE_MULTIPLEXED_MODEL_ID, SERVE_NAMESPACE, ) +from ray.serve._private.default_impl import add_grpc_address from ray.serve._private.grpc_util import DummyServicer, create_serve_grpc_server from ray.serve._private.http_util import ( MessageQueue, @@ -1289,6 +1291,10 @@ def _get_logging_config(self) -> Tuple: log_file_path = handler.baseFilename return log_file_path + def _dump_ingress_replicas_for_testing(self, route: str) -> Set[ReplicaID]: + _, handle, _ = self.http_proxy.proxy_router.match_route(route) + return handle._router._replica_scheduler._replica_id_set + def should_start_grpc_service(self) -> bool: """Determine whether gRPC service should be started. @@ -1408,7 +1414,7 @@ async def run_grpc_server(self): service_handler_factory=self.grpc_proxy.service_handler_factory, ) - grpc_server.add_insecure_port(f"[::]:{self.grpc_port}") + add_grpc_address(grpc_server, f"[::]:{self.grpc_port}") # Dummy servicer is used to be callable for the gRPC server. Serve have a # custom gRPC server implementation to redirect calls into gRPCProxy. diff --git a/python/ray/serve/tests/test_failure.py b/python/ray/serve/tests/test_failure.py index 9763b591c63e6..3255241abd195 100644 --- a/python/ray/serve/tests/test_failure.py +++ b/python/ray/serve/tests/test_failure.py @@ -13,7 +13,12 @@ from ray.exceptions import RayActorError from ray.serve._private.common import DeploymentID from ray.serve._private.constants import SERVE_DEFAULT_APP_NAME -from ray.serve._private.test_utils import Counter, get_deployment_details, tlog +from ray.serve._private.test_utils import ( + Counter, + check_num_replicas_eq, + get_deployment_details, + tlog, +) def request_with_retries(endpoint, timeout=30): @@ -305,6 +310,8 @@ def check_health(self): tlog(f"Killing replica {replica_to_kill}") ray.kill(ray.get_actor(replica_to_kill, namespace="serve")) + wait_for_condition(check_num_replicas_eq, name="Dummy", target=1) + # The controller just health checked both of them, so it should not # be able to health check and notify the handle router in time. Then # we test that router can properly recognize that the replica has diff --git a/python/ray/serve/tests/test_gcs_failure.py b/python/ray/serve/tests/test_gcs_failure.py index 6bcacd239abe5..437722f4b3bde 100644 --- a/python/ray/serve/tests/test_gcs_failure.py +++ b/python/ray/serve/tests/test_gcs_failure.py @@ -1,6 +1,7 @@ import importlib import os import sys +from typing import Callable, Optional import pytest import requests @@ -8,10 +9,9 @@ import ray from ray import serve from ray._private.test_utils import wait_for_condition -from ray.serve._private.common import DeploymentID, ReplicaState from ray.serve._private.constants import SERVE_DEFAULT_APP_NAME from ray.serve._private.storage.kv_store import KVStoreError, RayInternalKVStore -from ray.serve._private.test_utils import check_apps_running, check_replica_counts +from ray.serve._private.test_utils import check_apps_running from ray.serve.context import _get_global_client from ray.serve.handle import DeploymentHandle from ray.serve.schema import ServeDeploySchema @@ -112,11 +112,20 @@ def call(): def router_populated_with_replicas( - handle: DeploymentHandle, threshold: int, + handle: Optional[DeploymentHandle] = None, + get_replicas_func: Optional[Callable] = None, check_cache_populated: bool = False, ): - replicas = handle._router._replica_scheduler._replica_id_set + """Either get router's replica set from `handle` directly, or use + `get_replicas_func` to get replica set. Then check that the number + of replicas in set is at least `threshold`. + """ + if handle: + replicas = handle._router._replica_scheduler._replica_id_set + else: + replicas = get_replicas_func() + assert len(replicas) >= threshold # Return early if we don't need to check cache @@ -142,6 +151,8 @@ def test_new_router_on_gcs_failure(serve_ha, use_proxy: bool): sent to replicas during GCS downtime. """ + _, client = serve_ha + @serve.deployment class Dummy: def __call__(self): @@ -161,7 +172,18 @@ def __call__(self): # waiting for the first request h._get_or_create_router() - wait_for_condition(router_populated_with_replicas, handle=h, threshold=1) + if use_proxy: + proxy_handles = ray.get(client._controller.get_proxies.remote()) + proxy_handle = list(proxy_handles.values())[0] + wait_for_condition( + router_populated_with_replicas, + threshold=2, + get_replicas_func=lambda: ray.get( + proxy_handle._dump_ingress_replicas_for_testing.remote("/") + ), + ) + else: + wait_for_condition(router_populated_with_replicas, threshold=2, handle=h) # Kill GCS server before a single request is sent. ray.worker._global_node.kill_gcs_server() @@ -208,8 +230,8 @@ def test_handle_router_updated_replicas_then_gcs_failure(serve_ha): wait_for_condition( router_populated_with_replicas, - handle=h, threshold=2, + handle=h, check_cache_populated=True, ) @@ -250,15 +272,15 @@ def test_proxy_router_updated_replicas_then_gcs_failure(serve_ha): config["deployments"][0]["num_replicas"] = 2 client.deploy_apps(ServeDeploySchema(**{"applications": [config]})) - # There is no way to directly check if proxy has received updated replicas, - # so just check for the status. After controller updates status with new - # replicas, proxy should instantly receive updates from long poll + proxy_handles = ray.get(client._controller.get_proxies.remote()) + proxy_handle = list(proxy_handles.values())[0] + wait_for_condition( - check_replica_counts, - controller=client._controller, - deployment_id=DeploymentID("GetPID", "default"), - total=2, - by_state=[(ReplicaState.RUNNING, 2, None)], + router_populated_with_replicas, + threshold=2, + get_replicas_func=lambda: ray.get( + proxy_handle._dump_ingress_replicas_for_testing.remote("/") + ), ) # Kill GCS server before router gets to send request to second replica diff --git a/python/ray/serve/tests/unit/test_grpc_util.py b/python/ray/serve/tests/unit/test_grpc_util.py index 20d63dedaa097..6270ecac31a7c 100644 --- a/python/ray/serve/tests/unit/test_grpc_util.py +++ b/python/ray/serve/tests/unit/test_grpc_util.py @@ -6,6 +6,7 @@ from google.protobuf.any_pb2 import Any as AnyProto from ray import cloudpickle +from ray.serve._private.default_impl import add_grpc_address from ray.serve._private.grpc_util import ( DummyServicer, create_serve_grpc_server, @@ -15,6 +16,14 @@ from ray.serve.grpc_util import RayServegRPCContext +class FakeGrpcServer: + def __init__(self): + self.address = None + + def add_insecure_port(self, address): + self.address = address + + def fake_service_handler_factory(service_method: str, stream: bool) -> Callable: def foo() -> bytes: return f"{'stream' if stream else 'unary'} call from {service_method}".encode() @@ -120,6 +129,15 @@ def test_ray_serve_grpc_context_serializable(): assert deserialized_context.__dict__ == context.__dict__ +def test_add_grpc_address(): + """Test `add_grpc_address` adds the address to the gRPC server.""" + fake_grpc_server = FakeGrpcServer() + grpc_address = "fake_address:50051" + assert fake_grpc_server.address is None + add_grpc_address(fake_grpc_server, grpc_address) + assert fake_grpc_server.address == grpc_address + + if __name__ == "__main__": import sys diff --git a/python/ray/serve/tests/unit/test_pow_2_replica_scheduler.py b/python/ray/serve/tests/unit/test_pow_2_replica_scheduler.py index e078885486c51..237a6c046a007 100644 --- a/python/ray/serve/tests/unit/test_pow_2_replica_scheduler.py +++ b/python/ray/serve/tests/unit/test_pow_2_replica_scheduler.py @@ -1253,7 +1253,7 @@ async def test_multiple_queries_with_different_model_ids(self, pow_2_scheduler): ), ] - done, _ = await asyncio.wait(tasks, timeout=0.01) + done, _ = await asyncio.wait(tasks, timeout=0.1) assert len(done) == len(tasks) assert all( @@ -1600,7 +1600,7 @@ async def test_queue_len_cache_replica_at_capacity_is_probed(pow_2_scheduler): s.replica_queue_len_cache.update(r1.replica_id, DEFAULT_MAX_ONGOING_REQUESTS) task = loop.create_task(s.choose_replica_for_request(fake_pending_request())) - done, _ = await asyncio.wait([task], timeout=0.01) + done, _ = await asyncio.wait([task], timeout=0.1) assert len(done) == 0 # 1 probe from scheduling requests # + 1 probe from when the replica set was updated with replica r1 @@ -1608,7 +1608,7 @@ async def test_queue_len_cache_replica_at_capacity_is_probed(pow_2_scheduler): # Now let the replica respond and accept the request, it should be scheduled. r1.set_queue_len_response(DEFAULT_MAX_ONGOING_REQUESTS - 1) - done, _ = await asyncio.wait([task], timeout=0.01) + done, _ = await asyncio.wait([task], timeout=0.1) assert len(done) == 1 assert (await task) == r1 @@ -1636,7 +1636,7 @@ async def test_queue_len_cache_background_probing(pow_2_scheduler): s.replica_queue_len_cache.update(r1.replica_id, 0) task = loop.create_task(s.choose_replica_for_request(fake_pending_request())) - done, _ = await asyncio.wait([task], timeout=0.01) + done, _ = await asyncio.wait([task], timeout=0.1) assert len(done) == 1 assert (await task) == r1 # 0 probes from scheduling requests diff --git a/python/ray/tests/BUILD b/python/ray/tests/BUILD index 06641d0d57776..7b3be66ef2373 100644 --- a/python/ray/tests/BUILD +++ b/python/ray/tests/BUILD @@ -291,6 +291,7 @@ py_test_module_list( "test_debug_tools.py", "test_distributed_sort.py", "test_environ.py", + "test_exceptiongroup.py", "test_get_or_create_actor.py", "test_ids.py", "test_list_actors.py", diff --git a/python/ray/tests/conftest.py b/python/ray/tests/conftest.py index ead0ec9648adb..4ee34d2b73c97 100644 --- a/python/ray/tests/conftest.py +++ b/python/ray/tests/conftest.py @@ -688,6 +688,11 @@ def tmp_working_dir(): with hello_file.open(mode="w") as f: f.write("world") + test_file_module = path / "file_module.py" + with test_file_module.open(mode="w") as f: + f.write("def hello():\n") + f.write(" return 'hello'\n") + module_path = path / "test_module" module_path.mkdir(parents=True) diff --git a/python/ray/tests/test_basic_4.py b/python/ray/tests/test_basic_4.py index 47dae1fdeafb5..269062e789516 100644 --- a/python/ray/tests/test_basic_4.py +++ b/python/ray/tests/test_basic_4.py @@ -81,7 +81,11 @@ def get_num_workers(): time_waited = time.time() - start print(f"Waited {time_waited} for debug_state.txt to be updated") - # Check that no more workers started for a while. + # Check that no more workers started for a while. Note at initializtion there can + # be more workers prestarted and then idle-killed, so we tolerate at most one spike + # in the number of workers. + high_watermark = 16 + prev = high_watermark for i in range(100): # Sometimes the debug state file can be empty. Retry if needed. for _ in range(3): @@ -91,8 +95,17 @@ def get_num_workers(): time.sleep(0.05) else: break - assert num == 16 + if num >= high_watermark: + # spike climbing + high_watermark = num + prev = num + else: + # spike falling + assert num <= prev + prev = num time.sleep(0.1) + print(f"High watermark: {high_watermark}, prev: {prev}, num: {num}") + assert num == 16 @pytest.mark.skipif( diff --git a/python/ray/tests/test_exceptiongroup.py b/python/ray/tests/test_exceptiongroup.py new file mode 100644 index 0000000000000..e8516a0a20574 --- /dev/null +++ b/python/ray/tests/test_exceptiongroup.py @@ -0,0 +1,196 @@ +import os +import sys +from textwrap import dedent + +import pytest + +import ray +from ray.exceptions import RayTaskError + +pytestmark = pytest.mark.skipif( + sys.version_info < (3, 11), + reason="ExceptionGroup is only available in Python 3.11+", +) + + +def test_baseexceptiongroup_task(ray_start_regular): + baseexceptiongroup = BaseExceptionGroup( # noqa: F821 + "test baseexceptiongroup", [BaseException("abc")] + ) + + @ray.remote + def task(): + raise baseexceptiongroup + + with pytest.raises(ray.exceptions.WorkerCrashedError): + ray.get(task.remote()) + + +def test_baseexceptiongroup_actor(ray_start_regular): + baseexceptiongroup = BaseExceptionGroup( # noqa: F821 + "test baseexceptiongroup", [BaseException("abc")] + ) + + @ray.remote + class Actor: + def f(self): + raise baseexceptiongroup + + with pytest.raises(ray.exceptions.ActorDiedError): + a = Actor.remote() + ray.get(a.f.remote()) + + +def test_except_exceptiongroup(ray_start_regular): + exceptiongroup = ExceptionGroup( # noqa: F821 + "test exceptiongroup", [ValueError(), TypeError()] + ) + + @ray.remote + def task(): + raise exceptiongroup + + @ray.remote + class Actor: + def f(self): + raise exceptiongroup + + try: + ray.get(task.remote()) + except Exception as ex: + assert isinstance(ex, RayTaskError) + assert isinstance(ex, ExceptionGroup) # noqa: F821 + assert len(ex.exceptions) == 2 + assert isinstance(ex.exceptions[0], ValueError) + assert isinstance(ex.exceptions[1], TypeError) + + try: + a = Actor.remote() + ray.get(a.f.remote()) + except Exception as ex: + assert isinstance(ex, RayTaskError) + assert isinstance(ex, ExceptionGroup) # noqa: F821 + assert len(ex.exceptions) == 2 + assert isinstance(ex.exceptions[0], ValueError) + assert isinstance(ex.exceptions[1], TypeError) + + +def test_except_star_exception(ray_start_regular): + @ray.remote + def task(): + raise ValueError + + @ray.remote + class Actor: + def f(self): + raise ValueError + + # TODO: Don't use exec() when we only support Python 3.11+ + # Here the exec() is used to avoid SyntaxError for except* for Python < 3.11 + python_code = dedent( + """\ + try: + ray.get(task.remote()) + except* RayTaskError as ex: + assert isinstance(ex, ExceptionGroup) + assert len(ex.exceptions) == 1 + assert isinstance(ex.exceptions[0], RayTaskError) + assert isinstance(ex.exceptions[0], ValueError) + + try: + ray.get(task.remote()) + except* ValueError as ex: + assert isinstance(ex, ExceptionGroup) + assert len(ex.exceptions) == 1 + assert isinstance(ex.exceptions[0], RayTaskError) + assert isinstance(ex.exceptions[0], ValueError) + + try: + a = Actor.remote() + ray.get(a.f.remote()) + except* RayTaskError as ex: + assert isinstance(ex, ExceptionGroup) + assert len(ex.exceptions) == 1 + assert isinstance(ex.exceptions[0], RayTaskError) + assert isinstance(ex.exceptions[0], ValueError) + + try: + a = Actor.remote() + ray.get(a.f.remote()) + except* ValueError as ex: + assert isinstance(ex, ExceptionGroup) + assert len(ex.exceptions) == 1 + assert isinstance(ex.exceptions[0], RayTaskError) + assert isinstance(ex.exceptions[0], ValueError) + """ + ) + exec(python_code) + + +def test_except_star_exceptiongroup(ray_start_regular): + exceptiongroup = ExceptionGroup( # noqa: F821 + "test exceptiongroup", [ValueError(), TypeError()] + ) + + @ray.remote + def task(): + raise exceptiongroup + + @ray.remote + class Actor: + def f(self): + raise exceptiongroup + + # TODO: Don't use exec() when we only support Python 3.11+ + # Here the exec() is used to avoid SyntaxError for except* for Python < 3.11 + python_code = dedent( + """\ + try: + ray.get(task.remote()) + except* RayTaskError as ex: + assert isinstance(ex, ExceptionGroup) + assert len(ex.exceptions) == 2 + assert isinstance(ex.exceptions[0], ValueError) + assert isinstance(ex.exceptions[1], TypeError) + + try: + ray.get(task.remote()) + except* ValueError as ex: + assert isinstance(ex, ExceptionGroup) + assert len(ex.exceptions) == 1 + assert isinstance(ex.exceptions[0], ValueError) + except* TypeError as ex: + assert isinstance(ex, ExceptionGroup) + assert len(ex.exceptions) == 1 + assert isinstance(ex.exceptions[0], TypeError) + + try: + a = Actor.remote() + ray.get(a.f.remote()) + except* RayTaskError as ex: + assert isinstance(ex, ExceptionGroup) + assert len(ex.exceptions) == 2 + assert isinstance(ex.exceptions[0], ValueError) + assert isinstance(ex.exceptions[1], TypeError) + + try: + a = Actor.remote() + ray.get(a.f.remote()) + except* ValueError as ex: + assert isinstance(ex, ExceptionGroup) + assert len(ex.exceptions) == 1 + assert isinstance(ex.exceptions[0], ValueError) + except* TypeError as ex: + assert isinstance(ex, ExceptionGroup) + assert len(ex.exceptions) == 1 + assert isinstance(ex.exceptions[0], TypeError) + """ + ) + exec(python_code) + + +if __name__ == "__main__": + if os.environ.get("PARALLEL_CI"): + sys.exit(pytest.main(["-n", "auto", "--boxed", "-vs", __file__])) + else: + sys.exit(pytest.main(["-sv", __file__])) diff --git a/python/ray/tests/test_failure.py b/python/ray/tests/test_failure.py index 072005becfc2a..cf4fb5adde583 100644 --- a/python/ray/tests/test_failure.py +++ b/python/ray/tests/test_failure.py @@ -322,6 +322,26 @@ def foo(): assert isinstance(ex, RayTaskError) +def test_baseexception_task(ray_start_regular): + @ray.remote + def task(): + raise BaseException("abc") + + with pytest.raises(ray.exceptions.WorkerCrashedError): + ray.get(task.remote()) + + +def test_baseexception_actor(ray_start_regular): + @ray.remote + class Actor: + def f(self): + raise BaseException("abc") + + with pytest.raises(ray.exceptions.ActorDiedError): + a = Actor.remote() + ray.get(a.f.remote()) + + @pytest.mark.skip("This test does not work yet.") @pytest.mark.parametrize("ray_start_object_store_memory", [10**6], indirect=True) def test_put_error1(ray_start_object_store_memory, error_pubsub): diff --git a/python/ray/tests/test_gcs_utils.py b/python/ray/tests/test_gcs_utils.py index 82f34214046ef..c25beac6e598a 100644 --- a/python/ray/tests/test_gcs_utils.py +++ b/python/ray/tests/test_gcs_utils.py @@ -100,6 +100,20 @@ def test_kv_timeout(ray_start_regular): gcs_client.internal_kv_del(b"A", True, b"NS", timeout=2) +def test_kv_transient_network_error(shutdown_only, monkeypatch): + monkeypatch.setenv( + "RAY_testing_rpc_failure", + "ray::rpc::InternalKVGcsService.grpc_client.InternalKVGet=5," + "ray::rpc::InternalKVGcsService.grpc_client.InternalKVPut=5", + ) + ray.init() + gcs_address = ray._private.worker.global_worker.gcs_client.address + gcs_client = ray._raylet.GcsClient(address=gcs_address, nums_reconnect_retry=0) + + gcs_client.internal_kv_put(b"A", b"Hello", True, b"") + assert gcs_client.internal_kv_get(b"A", b"") == b"Hello" + + @pytest.mark.asyncio async def test_kv_basic_aio(ray_start_regular): gcs_client = gcs_utils.GcsAioClient( diff --git a/python/ray/tests/test_logging_2.py b/python/ray/tests/test_logging_2.py index eb32201d589b8..b48b04e44a593 100644 --- a/python/ray/tests/test_logging_2.py +++ b/python/ray/tests/test_logging_2.py @@ -387,6 +387,29 @@ def print_message(self): for s in should_not_exist: assert s not in stderr + def test_text_mode_driver_system_log(self, shutdown_only): + script = """ +import ray +ray.init( + logging_config=ray.LoggingConfig(encoding="TEXT") +) +""" + stderr = run_string_as_driver(script) + should_exist = "timestamp_ns=" + assert should_exist in stderr + + +def test_structured_logging_with_working_dir(tmp_path, shutdown_only): + working_dir = tmp_path / "test-working-dir" + working_dir.mkdir() + runtime_env = { + "working_dir": str(working_dir), + } + ray.init( + runtime_env=runtime_env, + logging_config=ray.LoggingConfig(encoding="TEXT"), + ) + class TestSetupLogRecordFactory: @pytest.fixture @@ -424,6 +447,30 @@ def existing_factory(*args, **kwargs): assert record.__dict__["existing_factory"] +def test_text_mode_no_prefix(shutdown_only): + """ + If logging_config is set, remove the prefix that contains + the actor or task's name and their PIDs. + """ + script = """ +import ray +import logging +ray.init( + logging_config=ray.LoggingConfig(encoding="TEXT") +) +@ray.remote +class MyActor: + def print_message(self): + logger = logging.getLogger(__name__) + logger.info("This is a Ray actor") +my_actor = MyActor.remote() +ray.get(my_actor.print_message.remote()) +""" + stderr = run_string_as_driver(script) + assert "This is a Ray actor" in stderr + assert "(MyActor pid=" not in stderr + + if __name__ == "__main__": if os.environ.get("PARALLEL_CI"): sys.exit(pytest.main(["-n", "auto", "--boxed", "-vs", __file__])) diff --git a/python/ray/tests/test_metrics_agent.py b/python/ray/tests/test_metrics_agent.py index adc6a74869271..990e917921b75 100644 --- a/python/ray/tests/test_metrics_agent.py +++ b/python/ray/tests/test_metrics_agent.py @@ -59,7 +59,9 @@ "ray_internal_num_spilled_tasks", # "ray_unintentional_worker_failures_total", # "ray_node_failure_total", - "ray_grpc_server_req_process_time_ms", + "ray_grpc_server_req_process_time_ms_sum", + "ray_grpc_server_req_process_time_ms_bucket", + "ray_grpc_server_req_process_time_ms_count", "ray_grpc_server_req_new_total", "ray_grpc_server_req_handling_total", "ray_grpc_server_req_finished_total", @@ -332,7 +334,9 @@ def test_cases(): # Make sure the gRPC stats are not reported from workers. We disabled # it there because it has too high cardinality. grpc_metrics = [ - "ray_grpc_server_req_process_time_ms", + "ray_grpc_server_req_process_time_ms_sum", + "ray_grpc_server_req_process_time_ms_bucket", + "ray_grpc_server_req_process_time_ms_count", "ray_grpc_server_req_new_total", "ray_grpc_server_req_handling_total", "ray_grpc_server_req_finished_total", diff --git a/python/ray/tests/test_node_manager.py b/python/ray/tests/test_node_manager.py index be746002b6dac..98e7dc27f6087 100644 --- a/python/ray/tests/test_node_manager.py +++ b/python/ray/tests/test_node_manager.py @@ -14,6 +14,13 @@ from ray._private.utils import get_num_cpus import time import sys +from ray._private.runtime_env.context import RuntimeEnvContext +from ray._private.runtime_env.plugin import RuntimeEnvPlugin +from typing import List, Optional +import logging +import tempfile +import collections +import shutil # This tests the queue transitions for infeasible tasks. This has been an issue @@ -396,6 +403,128 @@ def f(): assert used_worker_pids == worker_pids +MyPlugin = "HangOnSecondWorkerPlugin" +MY_PLUGIN_CLASS_PATH = "ray.tests.test_node_manager.HangOnSecondWorkerPlugin" +PLUGIN_TIMEOUT = 10 + + +class HangOnSecondWorkerPlugin(RuntimeEnvPlugin): + """ + The first worker will start up normally, but all subsequent workers will hang at + start up indefinitely. How it works: Ray RuntimeEnvAgent caches the modified context + so we can't do it in modify_context. Instead, we use a bash command to read a file + and hang forever. We don't have a good file lock mechanism in bash (flock is not + installed by default in macos), so we also serialize the worker startup. + """ + + name = MyPlugin + + def __init__(self): + # Each URI has a temp dir, a counter file, and a hang.sh script. + self.uris = collections.defaultdict(dict) + + def get_uris(self, runtime_env: "RuntimeEnv") -> List[str]: # noqa: F821 + return [runtime_env[self.name]] + + async def create( + self, + uri: Optional[str], + runtime_env, + context: RuntimeEnvContext, + logger: logging.Logger, + ) -> float: + d = self.uris[uri] + d["temp_dir"] = tempfile.mkdtemp() + logger.info(f"caching temp dir {d['temp_dir']} for uri {uri}") + d["counter_file"] = os.path.join(d["temp_dir"], "script_run_count") + with open(d["counter_file"], "w+") as f: + f.write("0") + d["hang_sh"] = os.path.join(d["temp_dir"], "hang.sh") + with open(d["hang_sh"], "w+") as f: + f.write( + f"""#!/bin/bash + +counter_file="{d['counter_file']}" + +count=$(cat "$counter_file") + +if [ "$count" -eq "0" ]; then + echo "1" > "$counter_file" + echo "first time run" + exit 0 +elif [ "$count" -eq "1" ]; then + echo "2" > "$counter_file" + echo "second time run, sleeping..." + sleep 1000 +fi +""" + ) + os.chmod(d["hang_sh"], 0o755) + return 0.1 + + def modify_context( + self, + uris: List[str], + runtime_env: "RuntimeEnv", # noqa: F821 + ctx: RuntimeEnvContext, + logger: logging.Logger, + ) -> None: + logger.info(f"Starting worker: {uris}, {runtime_env}") + if self.name not in runtime_env: + return + assert len(uris) == 1 + uri = uris[0] + hang_sh = self.uris[uri]["hang_sh"] + ctx.command_prefix += ["bash", hang_sh, "&&"] + + def delete_uri(self, uri: str, logger: logging.Logger) -> float: + temp_dir = self.uris[uri]["temp_dir"] + shutil.rmtree(temp_dir) + del self.uris[uri] + logger.info(f"temp_dir removed: {temp_dir}") + + +@pytest.fixture +def serialize_worker_startup(monkeypatch): + """Only one worker starts up each time, since our bash script is not process-safe""" + monkeypatch.setenv("RAY_worker_maximum_startup_concurrency", "1") + yield + + +@pytest.mark.parametrize( + "set_runtime_env_plugins", + [ + '[{"class":"' + MY_PLUGIN_CLASS_PATH + '"}]', + ], + indirect=True, +) +def test_can_reuse_released_workers( + serialize_worker_startup, set_runtime_env_plugins, ray_start_cluster +): + """ + Uses a runtime env plugin to make sure only 1 worker can start and all subsequent + workers will hang in runtime start up forever. We issue 10 tasks and test that + all the following tasks can still be scheduled on the first worker released from the + first task, i.e. tasks are not binded to the workers that they requested to start. + """ + cluster = ray_start_cluster + cluster.add_node(num_cpus=2) + ray.init(address=cluster.address) + + @ray.remote(runtime_env={"env_vars": {"HELLO": "WORLD"}, MyPlugin: "key"}) + def f(): + # Sleep for a while to make sure other tasks also request workers. + time.sleep(1) + print(f"pid={os.getpid()}, env HELLO={os.environ.get('HELLO')}") + return os.getpid() + + objs = [f.remote() for i in range(10)] + + pids = ray.get(objs) + for pid in pids: + assert pid == pids[0] + + if __name__ == "__main__": import sys diff --git a/python/ray/tests/test_runtime_env_packaging.py b/python/ray/tests/test_runtime_env_packaging.py index 3c10826ad0361..c2a82318f5a4a 100644 --- a/python/ray/tests/test_runtime_env_packaging.py +++ b/python/ray/tests/test_runtime_env_packaging.py @@ -29,6 +29,7 @@ download_and_unpack_package, get_local_dir_from_uri, get_top_level_dir_from_compressed_package, + get_uri_for_file, get_uri_for_directory, get_uri_for_package, is_whl_uri, @@ -64,6 +65,14 @@ def random_string(size: int = 10): return "".join(random.choice(string.ascii_uppercase) for _ in range(size)) +@pytest.fixture +def random_file(tmp_path) -> Path: + p = tmp_path / (random_string(10) + ".py") + with p.open("w") as f: + f.write(random_string(100)) + yield p + + @pytest.fixture def random_dir(tmp_path) -> Path: subdir = tmp_path / "subdir" @@ -135,6 +144,38 @@ def random_zip_file_with_top_level_dir(tmp_path): yield str(path / ARCHIVE_NAME) +class TestGetURIForFile: + def test_invalid_file(self): + with pytest.raises(ValueError): + get_uri_for_file("/does/not/exist.py") + + with pytest.raises(ValueError): + get_uri_for_file("does/not/exist.py") + + def test_determinism(self, random_file): + # Check that it's deterministic for same data. + uris = {get_uri_for_file(str(random_file)) for _ in range(10)} + assert len(uris) == 1 + + # Append one line, should be different now. + with open(random_file, "a") as f: + f.write(random_string()) + + assert {get_uri_for_file(str(random_file))} != uris + + def test_relative_paths(self, random_file): + # Check that relative or absolute paths result in the same URI. + p = Path(random_file) + relative_uri = get_uri_for_file(os.path.relpath(p)) + absolute_uri = get_uri_for_file(str(p.resolve())) + assert relative_uri == absolute_uri + + def test_uri_hash_length(self, random_file): + uri = get_uri_for_file(str(random_file)) + hex_hash = uri.split("_")[-1][: -len(".zip")] + assert len(hex_hash) == 16 + + class TestGetURIForDirectory: def test_invalid_directory(self): with pytest.raises(ValueError): diff --git a/python/ray/tests/test_runtime_env_working_dir.py b/python/ray/tests/test_runtime_env_working_dir.py index f145eea151f89..e667b0c712b10 100644 --- a/python/ray/tests/test_runtime_env_working_dir.py +++ b/python/ray/tests/test_runtime_env_working_dir.py @@ -128,6 +128,7 @@ def call_ray_init(): runtime_env={ "py_modules": [ str(Path(tmp_working_dir) / "test_module"), + str(Path(tmp_working_dir) / "file_module.py"), Path(os.path.dirname(__file__)) / "pip_install_test-0.5-py3-none-any.whl", ] @@ -140,6 +141,7 @@ def call_ray_init(): "working_dir": tmp_working_dir, "py_modules": [ str(Path(tmp_working_dir) / "test_module"), + str(Path(tmp_working_dir) / "file_module.py"), Path(os.path.dirname(__file__)) / "pip_install_test-0.5-py3-none-any.whl", ], @@ -163,15 +165,16 @@ def reinit(): @ray.remote def test_import(): import test_module + import file_module assert TEST_IMPORT_DIR in os.environ.get("PYTHONPATH", "") - return test_module.one() + return test_module.one(), file_module.hello() if option == "failure": with pytest.raises(ImportError): ray.get(test_import.remote()) else: - assert ray.get(test_import.remote()) == 1 + assert ray.get(test_import.remote()) == (1, "hello") if option in {"py_modules", "working_dir_and_py_modules"}: @@ -205,9 +208,10 @@ def test_read(): class Actor: def test_import(self): import test_module + import file_module assert TEST_IMPORT_DIR in os.environ.get("PYTHONPATH", "") - return test_module.one() + return test_module.one(), file_module.hello() def test_read(self): assert TEST_IMPORT_DIR in os.environ.get("PYTHONPATH", "") @@ -216,11 +220,11 @@ def test_read(self): a = Actor.remote() if option == "failure": with pytest.raises(ImportError): - assert ray.get(a.test_import.remote()) == 1 + assert ray.get(a.test_import.remote()) == (1, "hello") with pytest.raises(FileNotFoundError): assert ray.get(a.test_read.remote()) == "world" elif option in {"working_dir_and_py_modules", "working_dir"}: - assert ray.get(a.test_import.remote()) == 1 + assert ray.get(a.test_import.remote()) == (1, "hello") assert ray.get(a.test_read.remote()) == "world" @@ -243,7 +247,10 @@ def call_ray_init(): ray.init( address, runtime_env={ - "py_modules": [os.path.join(tmp_working_dir, "test_module")] + "py_modules": [ + os.path.join(tmp_working_dir, "test_module"), + os.path.join(tmp_working_dir, "file_module.py"), + ] }, ) @@ -262,31 +269,32 @@ def reinit(): # Import in the driver. sys.path.insert(0, tmp_working_dir) import test_module + import file_module @ray.remote def test_import(): - return test_module.one() + return test_module.one(), file_module.hello() if option == "failure": with pytest.raises(Exception): ray.get(test_import.remote()) else: - assert ray.get(test_import.remote()) == 1 + assert ray.get(test_import.remote()) == (1, "hello") reinit() @ray.remote class Actor: def test_import(self): - return test_module.one() + return test_module.one(), file_module.hello() if option == "failure": with pytest.raises(Exception): a = Actor.remote() - assert ray.get(a.test_import.remote()) == 1 + assert ray.get(a.test_import.remote()) == (1, "hello") else: a = Actor.remote() - assert ray.get(a.test_import.remote()) == 1 + assert ray.get(a.test_import.remote()) == (1, "hello") def test_empty_working_dir(start_cluster): diff --git a/python/ray/tests/test_state_api.py b/python/ray/tests/test_state_api.py index ca5be44aee0c3..908faefcd9fd1 100644 --- a/python/ray/tests/test_state_api.py +++ b/python/ray/tests/test_state_api.py @@ -13,7 +13,7 @@ from ray._private.state_api_test_utils import get_state_api_manager from ray.util.state import get_job from ray.dashboard.modules.job.pydantic_models import JobDetails -from ray.util.state.common import Humanify +from ray.util.state.common import Humanify, PredicateType from ray._private.gcs_utils import GcsAioClient import yaml from click.testing import CliRunner @@ -328,7 +328,7 @@ def generate_runtime_env_info(runtime_env, creation_time=None, success=True): def create_api_options( timeout: int = DEFAULT_RPC_TIMEOUT, limit: int = DEFAULT_LIMIT, - filters: List[Tuple[str, SupportedFilterType]] = None, + filters: List[Tuple[str, PredicateType, SupportedFilterType]] = None, detail: bool = False, exclude_driver: bool = True, ): diff --git a/python/ray/tests/tls/README b/python/ray/tests/tls/README new file mode 100644 index 0000000000000..0115bebbc7a10 --- /dev/null +++ b/python/ray/tests/tls/README @@ -0,0 +1,27 @@ +These files are generated with the following command: + +mkdir -p {str(tmp_path)}/tls +openssl genrsa -out {str(tmp_path)}/tls/ca.key 4096 +openssl req \ + -x509 -new -nodes -sha256 \ + -key {str(tmp_path)}/tls/ca.key \ + -days 3650 \ + -subj '/O=Redis Test/CN=Certificate Authority' \ + -out {str(tmp_path)}/tls/ca.crt +openssl genrsa -out {str(tmp_path)}/tls/redis.key 2048 +openssl req \ + -new -sha256 \ + -key {str(tmp_path)}/tls/redis.key \ + -subj '/O=Redis Test/CN=Server' | \ + openssl x509 \ + -req -sha256 \ + -CA {str(tmp_path)}/tls/ca.crt \ + -CAkey {str(tmp_path)}/tls/ca.key \ + -CAserial {str(tmp_path)}/tls/ca.txt \ + -CAcreateserial \ + -days 3650 \ + -out {str(tmp_path)}/tls/redis.crt +openssl dhparam -out {str(tmp_path)}/tls/redis.dh 2048 + + +See https://github.com/ray-project/ray/pull/40378/ for more details \ No newline at end of file diff --git a/python/ray/tests/tls/ca.txt b/python/ray/tests/tls/ca.txt index 55fca9ab4c0f4..3c937ac1ca773 100644 --- a/python/ray/tests/tls/ca.txt +++ b/python/ray/tests/tls/ca.txt @@ -1 +1 @@ -75703BF2CC43AFFC5692C7B72687A196C4040599 +75703BF2CC43AFFC5692C7B72687A196C404059A diff --git a/python/ray/tests/tls/redis.crt b/python/ray/tests/tls/redis.crt index 19baa8870fc92..53777f13eddf7 100644 --- a/python/ray/tests/tls/redis.crt +++ b/python/ray/tests/tls/redis.crt @@ -1,7 +1,7 @@ -----BEGIN CERTIFICATE----- -MIID4jCCAcoCFHVwO/LMQ6/8VpLHtyaHoZbEBAWZMA0GCSqGSIb3DQEBCwUAMDUx +MIID4jCCAcoCFHVwO/LMQ6/8VpLHtyaHoZbEBAWaMA0GCSqGSIb3DQEBCwUAMDUx EzARBgNVBAoMClJlZGlzIFRlc3QxHjAcBgNVBAMMFUNlcnRpZmljYXRlIEF1dGhv -cml0eTAeFw0yMzEwMTUwMjExNDNaFw0yNDEwMTQwMjExNDNaMCYxEzARBgNVBAoM +cml0eTAeFw0yNDEwMTUwNDM0MzlaFw0zNDEwMTMwNDM0MzlaMCYxEzARBgNVBAoM ClJlZGlzIFRlc3QxDzANBgNVBAMMBlNlcnZlcjCCASIwDQYJKoZIhvcNAQEBBQAD ggEPADCCAQoCggEBAMLnixXdFw0vS4ewwRfkYEcOWDBxCmW8jj9MUmQ/QZmd01Ic Ybs9AFf+t4p6Z5woggVLBNdjgYNnKCwP94cL4cGNKDIyP2NHl5IJ7KJHsBpeibys @@ -9,15 +9,15 @@ Ybs9AFf+t4p6Z5woggVLBNdjgYNnKCwP94cL4cGNKDIyP2NHl5IJ7KJHsBpeibys Tgqa88FSxTi5PenSpmvykPiK4ZY35ZG/j1hUqHFwmFESCLTQJZzFP5NmXUpumJG3 R2d4dpaYzrAKzHt1qnd8ByI2X/I8USCFbwEqunmsKxmFDZZWyOS/+d96THHEYJe3 ZImFMIXAdvdLou1MYfvO3tamXIyRh+zedr/Cp1kCAwEAATANBgkqhkiG9w0BAQsF -AAOCAgEAnVXfIi4NWv/VxbV/ylLkTLC5zY2bn+oiuiapPilCo4858fbdmiwKk7IC -mCksbY4MN2+nelKzypre8HwKDv+MoLKyGO3mo+m6P9xS0BIvMeYNJOgLld81dgP+ -pLShP3LN+5u26gmGjhABFEpUOolX5m75bRq8BPoA4hAg2CNVPZt7GVyPXWqgl/PA -OKWrtrsQBbzL1yLkZFgeDeWKPwr7qfAPKG3Hr4yhURv2jKzJa7u6F+zeqkNR+pab -SZRj1b6YPJLZDUznQ8eO5XiwtCqCkaMixfLKStTQ8RtL3AtSglitiISZ7MKF9tha -+EEMgKMhYvhhImOuaMydePT8BRqTL495Bd3lzDa9MhooID2Ur0qE/TmwPH5fCo1r -+olZHoTMrMftlWsJQVQrOafsvAM5Df+yqeyjQXdEUY0cLYp/q9Dg99WUl0sct2G0 -PDd165xsnEBCoBqvNdbYJjdKfFk2GaLgjQB2CntbhRgDz+pt6glXDTexD0mfSa30 -dJoB1wI94XerOxQdPur+XDL3i/W84HT9I5tpcw1ywr7QmFMGTRk+OlNhDBZGkHGu -guBQOn/zwbivtAuGYgXsX42FgRmCwSZOp5sth2x8D4zW5mu1QIbFWRDswCJqQXfd -h0oFOJE3nEmTQmSSsCN+jpV3U1pB5/r0CDNFZNCvZwfoW644HdA= +AAOCAgEAQgZ/uixRdNblqPSGhsFxeAxaZonBVg7akKG49Tk/wg/WwgCxvy76l0WS ++/GgWTXRNqM6BZbXggof/6Nh5pPGTJgcJmSORcEDNnv7R1FVrp4H2TPRVwf0g/NP +n3rqtM6WHIzJ9olv/U3J/U9vcyk42cuYscaJGHFMjoPnqq2ISltYb6lIbOxD6stz +oJPOnZNZyAILWMr5DCIZg67z2+ZZo4Y0epzmfdcg/Xlw6bV9xiOeLIENlG1Rz3a3 +Jw+zMAnU41gbRPZC0hZxq4K3qsCYTe/RmGz4YeF7eBXLHEdFQfz4mq2IzlCz6+Gh +BDgBQeMjqZzXsxk0BzHIPbiPD+WOUV4YZyap9D7b5nv4bDlQIAV2do6MuMLB16n1 +xVl98wbNCGpJvFeqzUlf7djQJmmDrz0egj05oEW7UQt9b1s3zswLZK4wAY0ZtTsZ +QBhuN6Ez/pD4ydsluPXQCBQaBtMlXMsMWCsLZ5N0luAhgqEl268PP03cEwwLCKk4 +Wk2k9Spj8ARKFykLxeLvlJSbfVUioKwduXBmcKdwBkJEoSqW51FlHXwPApq2r1uD +scWCRzs0H0t1H/BX2RHhDQrNvXZffBMkwTkXUOz5wLTHiynnKzgcBea0WZyhlOfK +0rl5SgsIeNbxydAL/TGydBpqA+MMmUpelEl6JEvQ8YAuhbNZGO4= -----END CERTIFICATE----- diff --git a/python/ray/util/state/state_cli.py b/python/ray/util/state/state_cli.py index ce40b9310cc03..d191b34b3c2c0 100644 --- a/python/ray/util/state/state_cli.py +++ b/python/ray/util/state/state_cli.py @@ -81,7 +81,6 @@ def _parse_filter(filter: str) -> Tuple[str, PredicateType, SupportedFilterType] filter[predicate_index[0] : predicate_index[1]], filter[predicate_index[1] :], ) - assert predicate == "=" or predicate == "!=" if len(key) == 0 or len(value) == 0: raise ValueError( diff --git a/python/requirements/anyscale-requirements.txt b/python/requirements/anyscale-requirements.txt index 3fef80ca3868a..709e860a9e02d 100644 --- a/python/requirements/anyscale-requirements.txt +++ b/python/requirements/anyscale-requirements.txt @@ -4,7 +4,8 @@ opentelemetry-api opentelemetry-sdk opentelemetry-exporter-otlp google-cloud-storage -grpcio==1.60.0 +grpcio>=1.66.1 +grpcio-tools pyyaml pyopenssl certifi diff --git a/python/requirements_compiled.txt b/python/requirements_compiled.txt index 5303ae2d23196..a1043afc5b51b 100644 --- a/python/requirements_compiled.txt +++ b/python/requirements_compiled.txt @@ -682,13 +682,14 @@ graphql-core==3.2.3 # via moto greenlet==3.0.1 # via sqlalchemy -grpcio==1.60.0 ; sys_platform != "darwin" +grpcio==1.66.2 ; sys_platform != "darwin" # via # -r /ray/ci/../python/requirements.txt # -r /ray/ci/../python/requirements/anyscale-requirements.txt # google-api-core # google-cloud-bigquery # grpcio-status + # grpcio-tools # mlagents-envs # opencensus-proto # opentelemetry-exporter-opencensus @@ -697,6 +698,8 @@ grpcio==1.60.0 ; sys_platform != "darwin" # tensorflow grpcio-status==1.48.2 # via google-api-core +grpcio-tools==1.48.2 + # via -r /ray/ci/../python/requirements/anyscale-requirements.txt gsutil==5.27 # via -r /ray/ci/../python/requirements/docker/ray-docker-requirements.txt gunicorn==20.1.0 @@ -1550,6 +1553,7 @@ protobuf==3.20.3 # google-cloud-bigquery-storage # googleapis-common-protos # grpcio-status + # grpcio-tools # mlagents-envs # mlflow # onnx diff --git a/rllib/BUILD b/rllib/BUILD index 469a5c57a9509..7c2259b84768e 100644 --- a/rllib/BUILD +++ b/rllib/BUILD @@ -2836,107 +2836,6 @@ py_test( args = ["--enable-new-api-stack", "--as-test"] ) - -#@OldAPIStack @HybridAPIStack -py_test( - name = "examples/learners/ppo_tuner_local_cpu_torch", - main = "examples/learners/ppo_tuner.py", - tags = ["team:rllib", "examples"], - size = "medium", - srcs = ["examples/learners/ppo_tuner.py"], - args = ["--framework=torch", "--config=local-cpu"] -) - -#@OldAPIStack @HybridAPIStack -py_test( - name = "examples/learners/ppo_tuner_local_cpu_tf2", - main = "examples/learners/ppo_tuner.py", - tags = ["team:rllib", "examples"], - size = "medium", - srcs = ["examples/learners/ppo_tuner.py"], - args = ["--framework=tf2", "--config=local-cpu"] -) - -#@OldAPIStack @HybridAPIStack -py_test( - name = "examples/learners/ppo_tuner_local_gpu_torch", - main = "examples/learners/ppo_tuner.py", - tags = ["team:rllib", "examples", "gpu"], - size = "medium", - srcs = ["examples/learners/ppo_tuner.py"], - args = ["--framework=torch", "--config=local-gpu"] -) - -#@OldAPIStack @HybridAPIStack -py_test( - name = "examples/learners/ppo_tuner_local_gpu_tf2", - main = "examples/learners/ppo_tuner.py", - tags = ["team:rllib", "examples", "gpu", "exclusive"], - size = "medium", - srcs = ["examples/learners/ppo_tuner.py"], - args = ["--framework=tf2", "--config=local-gpu"] -) - -#@OldAPIStack @HybridAPIStack -py_test( - name = "examples/learners/ppo_tuner_remote_cpu_torch", - main = "examples/learners/ppo_tuner.py", - tags = ["team:rllib", "examples"], - size = "medium", - srcs = ["examples/learners/ppo_tuner.py"], - args = ["--framework=torch", "--config=remote-cpu"] -) - -#@OldAPIStack @HybridAPIStack -py_test( - name = "examples/learners/ppo_tuner_remote_cpu_tf2", - main = "examples/learners/ppo_tuner.py", - tags = ["team:rllib", "examples"], - size = "medium", - srcs = ["examples/learners/ppo_tuner.py"], - args = ["--framework=tf2", "--config=remote-cpu"] -) - -#@OldAPIStack @HybridAPIStack -py_test( - name = "examples/learners/ppo_tuner_remote_gpu_torch", - main = "examples/learners/ppo_tuner.py", - tags = ["team:rllib", "examples", "gpu", "exclusive"], - size = "medium", - srcs = ["examples/learners/ppo_tuner.py"], - args = ["--framework=torch", "--config=remote-gpu"] -) - -#@OldAPIStack @HybridAPIStack -py_test( - name = "examples/learners/ppo_tuner_remote_gpu_tf2", - main = "examples/learners/ppo_tuner.py", - tags = ["team:rllib", "examples", "gpu", "exclusive"], - size = "medium", - srcs = ["examples/learners/ppo_tuner.py"], - args = ["--framework=tf2", "--config=remote-gpu"] -) - -#@OldAPIStack @HybridAPIStack -py_test( - name = "examples/learners/ppo_tuner_multi_gpu_torch", - main = "examples/learners/ppo_tuner.py", - tags = ["team:rllib", "examples", "multi_gpu", "exclusive"], - size = "medium", - srcs = ["examples/learners/ppo_tuner.py"], - args = ["--framework=torch", "--config=multi-gpu-ddp"] -) - -#@OldAPIStack @HybridAPIStack -py_test( - name = "examples/learners/ppo_tuner_multi_gpu_tf2", - main = "examples/learners/ppo_tuner.py", - tags = ["team:rllib", "examples", "multi_gpu", "exclusive"], - size = "medium", - srcs = ["examples/learners/ppo_tuner.py"], - args = ["--framework=tf2", "--config=multi-gpu-ddp"] -) - # subdirectory: multi_agent/ # .................................... py_test( @@ -3256,56 +3155,6 @@ py_test( args = ["--as-test", "--framework=torch", "--stop-reward=-0.012", "--num-cpus=4"] ) -#@OldAPIStack -py_test( - name = "examples/cartpole_lstm_impala_tf2", - main = "examples/cartpole_lstm.py", - tags = ["team:rllib", "exclusive", "examples"], - size = "medium", - srcs = ["examples/cartpole_lstm.py"], - args = ["--run=IMPALA", "--as-test", "--framework=tf2", "--stop-reward=28", "--num-cpus=4"] -) - -#@OldAPIStack -py_test( - name = "examples/cartpole_lstm_impala_torch", - main = "examples/cartpole_lstm.py", - tags = ["team:rllib", "exclusive", "examples"], - size = "medium", - srcs = ["examples/cartpole_lstm.py"], - args = ["--run=IMPALA", "--as-test", "--framework=torch", "--stop-reward=28", "--num-cpus=4"] -) - -#@OldAPIStack -py_test( - name = "examples/cartpole_lstm_ppo_tf2", - main = "examples/cartpole_lstm.py", - tags = ["team:rllib", "exclusive", "examples"], - size = "large", - srcs = ["examples/cartpole_lstm.py"], - args = ["--run=PPO", "--as-test", "--framework=tf2", "--stop-reward=28", "--num-cpus=4"] -) - -#@OldAPIStack -py_test( - name = "examples/cartpole_lstm_ppo_torch", - main = "examples/cartpole_lstm.py", - tags = ["team:rllib", "exclusive", "examples"], - size = "medium", - srcs = ["examples/cartpole_lstm.py"], - args = ["--run=PPO", "--as-test", "--framework=torch", "--stop-reward=28", "--num-cpus=4"] -) - -#@OldAPIStack -py_test( - name = "examples/cartpole_lstm_ppo_torch_with_prev_a_and_r", - main = "examples/cartpole_lstm.py", - tags = ["team:rllib", "exclusive", "examples"], - size = "medium", - srcs = ["examples/cartpole_lstm.py"], - args = ["--run=PPO", "--as-test", "--framework=torch", "--stop-reward=28", "--num-cpus=4", "--use-prev-action", "--use-prev-reward"] -) - #@OldAPIStack py_test( name = "examples/centralized_critic_tf", @@ -3356,30 +3205,6 @@ py_test( args = ["--stop-iters=2"] ) -#@OldAPIStack -py_test( - name = "examples/custom_model_loss_and_metrics_ppo_tf", - main = "examples/custom_model_loss_and_metrics.py", - tags = ["team:rllib", "exclusive", "examples"], - size = "small", - # Include the json data file. - data = ["tests/data/cartpole/small.json"], - srcs = ["examples/custom_model_loss_and_metrics.py"], - args = ["--run=PPO", "--stop-iters=1", "--framework=tf","--input-files=tests/data/cartpole"] -) - -#@OldAPIStack -py_test( - name = "examples/custom_model_loss_and_metrics_ppo_torch", - main = "examples/custom_model_loss_and_metrics.py", - tags = ["team:rllib", "exclusive", "examples"], - size = "small", - # Include the json data file. - data = ["tests/data/cartpole/small.json"], - srcs = ["examples/custom_model_loss_and_metrics.py"], - args = ["--run=PPO", "--framework=torch", "--stop-iters=1", "--input-files=tests/data/cartpole"] -) - py_test( name = "examples/custom_recurrent_rnn_tokenizer_repeat_after_me_tf2", main = "examples/custom_recurrent_rnn_tokenizer.py", diff --git a/rllib/algorithms/bc/torch/bc_torch_rl_module.py b/rllib/algorithms/bc/torch/bc_torch_rl_module.py index a547047d7f417..d06c323b124ef 100644 --- a/rllib/algorithms/bc/torch/bc_torch_rl_module.py +++ b/rllib/algorithms/bc/torch/bc_torch_rl_module.py @@ -11,7 +11,7 @@ class BCTorchRLModule(TorchRLModule): @override(RLModule) def setup(self): # __sphinx_doc_begin__ - # Build models from catalog + # Build models from catalog. self.encoder = self.catalog.build_encoder(framework=self.framework) self.pi = self.catalog.build_pi_head(framework=self.framework) diff --git a/rllib/core/rl_module/multi_rl_module.py b/rllib/core/rl_module/multi_rl_module.py index fb3e34f4339dd..b140d9d13a96c 100644 --- a/rllib/core/rl_module/multi_rl_module.py +++ b/rllib/core/rl_module/multi_rl_module.py @@ -1,5 +1,5 @@ import copy -from dataclasses import dataclass, field +import dataclasses import logging import pprint from typing import ( @@ -288,6 +288,19 @@ def add_module( # has `inference_only=False`. if not module.inference_only: self.inference_only = False + + # Check framework of incoming RLModule against `self.framework`. + if module.framework is not None: + if self.framework is None: + self.framework = module.framework + elif module.framework != self.framework: + raise ValueError( + f"Framework ({module.framework}) of incoming RLModule does NOT " + f"match framework ({self.framework}) of MultiRLModule! If the " + f"added module should not be trained, try setting its framework " + f"to None." + ) + self._rl_modules[module_id] = module # Update our RLModuleSpecs dict, such that - if written to disk - # it'll allow for proper restoring this instance through `.from_checkpoint()`. @@ -553,7 +566,7 @@ def _check_module_exists(self, module_id: ModuleID) -> None: @PublicAPI(stability="alpha") -@dataclass +@dataclasses.dataclass class MultiRLModuleSpec: """A utility spec class to make it constructing MultiRLModules easier. @@ -666,7 +679,11 @@ def build(self, module_id: Optional[ModuleID] = None) -> RLModule: observation_space=self.observation_space, action_space=self.action_space, inference_only=self.inference_only, - model_config=self.model_config, + model_config=( + dataclasses.asdict(self.model_config) + if dataclasses.is_dataclass(self.model_config) + else self.model_config + ), rl_module_specs=self.rl_module_specs, ) # Older custom model might still require the old `MultiRLModuleConfig` under @@ -859,10 +876,10 @@ def get_rl_module_config(self): "module2: [RLModuleSpec], ..}, inference_only=..)", error=False, ) -@dataclass +@dataclasses.dataclass class MultiRLModuleConfig: inference_only: bool = False - modules: Dict[ModuleID, RLModuleSpec] = field(default_factory=dict) + modules: Dict[ModuleID, RLModuleSpec] = dataclasses.field(default_factory=dict) def to_dict(self): return { diff --git a/rllib/core/rl_module/rl_module.py b/rllib/core/rl_module/rl_module.py index f1fb5b337cc54..42aa0a780ed45 100644 --- a/rllib/core/rl_module/rl_module.py +++ b/rllib/core/rl_module/rl_module.py @@ -98,7 +98,7 @@ def build(self) -> "RLModule": observation_space=self.observation_space, action_space=self.action_space, inference_only=self.inference_only, - model_config=self.model_config, + model_config=self._get_model_config(), catalog_class=self.catalog_class, ) # Older custom model might still require the old `RLModuleConfig` under diff --git a/rllib/env/utils/infinite_lookback_buffer.py b/rllib/env/utils/infinite_lookback_buffer.py index 269e3827ca205..cd84a8518097f 100644 --- a/rllib/env/utils/infinite_lookback_buffer.py +++ b/rllib/env/utils/infinite_lookback_buffer.py @@ -8,8 +8,10 @@ from ray.rllib.utils.serialization import gym_space_from_dict, gym_space_to_dict from ray.rllib.utils.spaces.space_utils import ( batch, + from_jsonable_if_needed, get_dummy_batch_for_space, get_base_struct_from_space, + to_jsonable_if_needed, ) @@ -71,12 +73,11 @@ def get_state(self) -> Dict[str, Any]: A dict containing all the data and metadata from the buffer. """ return { - "data": self.data, + "data": to_jsonable_if_needed(self.data, self.space) + if self.space + else self.data, "lookback": self.lookback, "finalized": self.finalized, - "space_struct": gym_space_to_dict(self.space_struct) - if self.space_struct - else self.space_struct, "space": gym_space_to_dict(self.space) if self.space else self.space, } @@ -92,16 +93,16 @@ def from_state(state: Dict[str, Any]) -> None: from the state dict. """ buffer = InfiniteLookbackBuffer() - buffer.data = state["data"] buffer.lookback = state["lookback"] buffer.finalized = state["finalized"] + buffer.space = gym_space_from_dict(state["space"]) if state["space"] else None buffer.space_struct = ( - gym_space_from_dict(state["space_struct"]) - if state["space_struct"] - else state["space_struct"] + get_base_struct_from_space(buffer.space) if buffer.space else None ) - buffer.space = ( - gym_space_from_dict(state["space"]) if state["space"] else state["space"] + buffer.data = ( + from_jsonable_if_needed(state["data"], buffer.space) + if buffer.space + else state["data"] ) return buffer diff --git a/rllib/examples/cartpole_lstm.py b/rllib/examples/cartpole_lstm.py deleted file mode 100644 index c7454161ab06b..0000000000000 --- a/rllib/examples/cartpole_lstm.py +++ /dev/null @@ -1,94 +0,0 @@ -# @OldAPIStack - -# TODO (sven): Move this script to `examples/rl_modules/...` - -import argparse -import os - -from ray.air.constants import TRAINING_ITERATION -from ray.rllib.examples.envs.classes.stateless_cartpole import StatelessCartPole -from ray.rllib.utils.metrics import ( - ENV_RUNNER_RESULTS, - EPISODE_RETURN_MEAN, - NUM_ENV_STEPS_SAMPLED_LIFETIME, -) -from ray.rllib.utils.test_utils import check_learning_achieved -from ray.tune.registry import get_trainable_cls - -parser = argparse.ArgumentParser() -parser.add_argument( - "--run", type=str, default="PPO", help="The RLlib-registered algorithm to use." -) -parser.add_argument("--num-cpus", type=int, default=0) -parser.add_argument( - "--framework", - choices=["tf", "tf2", "torch"], - default="torch", - help="The DL framework specifier.", -) -parser.add_argument("--use-prev-action", action="store_true") -parser.add_argument("--use-prev-reward", action="store_true") -parser.add_argument( - "--as-test", - action="store_true", - help="Whether this script should be run as a test: --stop-reward must " - "be achieved within --stop-timesteps AND --stop-iters.", -) -parser.add_argument( - "--stop-iters", type=int, default=200, help="Number of iterations to train." -) -parser.add_argument( - "--stop-timesteps", type=int, default=100000, help="Number of timesteps to train." -) -parser.add_argument( - "--stop-reward", type=float, default=150.0, help="Reward at which we stop training." -) - -if __name__ == "__main__": - import ray - from ray import air, tune - - args = parser.parse_args() - - ray.init() - - algo_cls = get_trainable_cls(args.run) - config = algo_cls.get_default_config() - - config.environment(env=StatelessCartPole).resources( - num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "0")) - ).framework(args.framework).reporting(min_time_s_per_iteration=0.1).training( - model={ - "use_lstm": True, - "lstm_cell_size": 32, - "lstm_use_prev_action": args.use_prev_action, - "lstm_use_prev_reward": args.use_prev_reward, - } - ) - - if args.run == "PPO": - config.training(num_epochs=5, vf_loss_coeff=0.0001, train_batch_size=512) - config.model["vf_share_layers"] = True - elif args.run == "IMPALA": - config.env_runners(num_env_runners=2) - config.resources(num_gpus=0) - config.training(vf_loss_coeff=0.01) - - stop = { - TRAINING_ITERATION: args.stop_iters, - NUM_ENV_STEPS_SAMPLED_LIFETIME: args.stop_timesteps, - f"{ENV_RUNNER_RESULTS}/{EPISODE_RETURN_MEAN}": args.stop_reward, - } - - tuner = tune.Tuner( - args.run, - param_space=config.to_dict(), - run_config=air.RunConfig( - stop=stop, - ), - ) - results = tuner.fit() - - if args.as_test: - check_learning_achieved(results, args.stop_reward) - ray.shutdown() diff --git a/rllib/examples/custom_model_api.py b/rllib/examples/custom_model_api.py deleted file mode 100644 index e1e6705bbf771..0000000000000 --- a/rllib/examples/custom_model_api.py +++ /dev/null @@ -1,109 +0,0 @@ -# @OldAPIStack -import argparse -from gymnasium.spaces import Box, Discrete -import numpy as np - -from ray.rllib.examples._old_api_stack.models.custom_model_api import ( - DuelingQModel, - TorchDuelingQModel, - ContActionQModel, - TorchContActionQModel, -) -from ray.rllib.models.catalog import ModelCatalog, MODEL_DEFAULTS -from ray.rllib.policy.sample_batch import SampleBatch -from ray.rllib.utils.framework import try_import_tf, try_import_torch - -tf1, tf, tfv = try_import_tf() -torch, _ = try_import_torch() - -parser = argparse.ArgumentParser() -parser.add_argument( - "--framework", - choices=["tf", "tf2", "torch"], - default="torch", - help="The DL framework specifier.", -) - -if __name__ == "__main__": - args = parser.parse_args() - - # Test API wrapper for dueling Q-head. - - obs_space = Box(-1.0, 1.0, (3,)) - action_space = Discrete(3) - - # Run in eager mode for value checking and debugging. - tf1.enable_eager_execution() - - # __sphinx_doc_model_construct_1_begin__ - my_dueling_model = ModelCatalog.get_model_v2( - obs_space=obs_space, - action_space=action_space, - num_outputs=action_space.n, - model_config=MODEL_DEFAULTS, - framework=args.framework, - # Providing the `model_interface` arg will make the factory - # wrap the chosen default model with our new model API class - # (DuelingQModel). This way, both `forward` and `get_q_values` - # are available in the returned class. - model_interface=DuelingQModel - if args.framework != "torch" - else TorchDuelingQModel, - name="dueling_q_model", - ) - # __sphinx_doc_model_construct_1_end__ - - batch_size = 10 - input_ = np.array([obs_space.sample() for _ in range(batch_size)]) - # Note that for PyTorch, you will have to provide torch tensors here. - if args.framework == "torch": - input_ = torch.from_numpy(input_) - - input_dict = SampleBatch(obs=input_, _is_training=False) - out, state_outs = my_dueling_model(input_dict=input_dict) - assert out.shape == (10, 256) - # Pass `out` into `get_q_values` - q_values = my_dueling_model.get_q_values(out) - assert q_values.shape == (10, action_space.n) - - # Test API wrapper for single value Q-head from obs/action input. - - obs_space = Box(-1.0, 1.0, (3,)) - action_space = Box(-1.0, -1.0, (2,)) - - # __sphinx_doc_model_construct_2_begin__ - my_cont_action_q_model = ModelCatalog.get_model_v2( - obs_space=obs_space, - action_space=action_space, - num_outputs=2, - model_config=MODEL_DEFAULTS, - framework=args.framework, - # Providing the `model_interface` arg will make the factory - # wrap the chosen default model with our new model API class - # (DuelingQModel). This way, both `forward` and `get_q_values` - # are available in the returned class. - model_interface=ContActionQModel - if args.framework != "torch" - else TorchContActionQModel, - name="cont_action_q_model", - ) - # __sphinx_doc_model_construct_2_end__ - - batch_size = 10 - input_ = np.array([obs_space.sample() for _ in range(batch_size)]) - - # Note that for PyTorch, you will have to provide torch tensors here. - if args.framework == "torch": - input_ = torch.from_numpy(input_) - - input_dict = SampleBatch(obs=input_, _is_training=False) - # Note that for PyTorch, you will have to provide torch tensors here. - out, state_outs = my_cont_action_q_model(input_dict=input_dict) - assert out.shape == (10, 256) - # Pass `out` and an action into `my_cont_action_q_model` - action = np.array([action_space.sample() for _ in range(batch_size)]) - if args.framework == "torch": - action = torch.from_numpy(action) - - q_value = my_cont_action_q_model.get_single_q_value(out, action) - assert q_value.shape == (10, 1) diff --git a/rllib/examples/custom_model_loss_and_metrics.py b/rllib/examples/custom_model_loss_and_metrics.py deleted file mode 100644 index ccb3d8e1acd07..0000000000000 --- a/rllib/examples/custom_model_loss_and_metrics.py +++ /dev/null @@ -1,117 +0,0 @@ -# @OldAPIStack - -# Users should just inherit the Learner and extend the loss_fn. -# TODO (sven): Move this example script to `examples/learners/...` - -"""Example of using custom_loss() with an imitation learning loss under the Policy -and ModelV2 API. - -The default input file is too small to learn a good policy, but you can -generate new experiences for IL training as follows: - -To generate experiences: -$ ./train.py --run=PG --config='{"output": "/tmp/cartpole"}' --env=CartPole-v1 - -To train on experiences with joint PG + IL loss: -$ python custom_loss.py --input-files=/tmp/cartpole -""" - -import argparse -from pathlib import Path -import os - -import ray -from ray import air, tune -from ray.air.constants import TRAINING_ITERATION -from ray.rllib.core import DEFAULT_MODULE_ID -from ray.rllib.examples._old_api_stack.models.custom_loss_model import ( - CustomLossModel, - TorchCustomLossModel, -) -from ray.rllib.models import ModelCatalog -from ray.rllib.utils.framework import try_import_tf -from ray.rllib.utils.metrics.learner_info import LEARNER_INFO, LEARNER_STATS_KEY -from ray.tune.registry import get_trainable_cls - -tf1, tf, tfv = try_import_tf() - -parser = argparse.ArgumentParser() -parser.add_argument( - "--run", type=str, default="PPO", help="The RLlib-registered algorithm to use." -) -parser.add_argument( - "--framework", - choices=["tf", "tf2", "torch"], - default="torch", - help="The DL framework specifier.", -) -parser.add_argument("--stop-iters", type=int, default=200) -parser.add_argument( - "--input-files", - type=str, - default=os.path.join( - os.path.dirname(os.path.abspath(__file__)), "../tests/data/cartpole/small.json" - ), -) - -if __name__ == "__main__": - ray.init() - args = parser.parse_args() - - # Bazel makes it hard to find files specified in `args` (and `data`). - # Look for them here. - if not os.path.exists(args.input_files): - # This script runs in the ray/rllib/examples dir. - rllib_dir = Path(__file__).parent.parent - input_dir = rllib_dir.absolute().joinpath(args.input_files) - args.input_files = str(input_dir) - - ModelCatalog.register_custom_model( - "custom_loss", - TorchCustomLossModel if args.framework == "torch" else CustomLossModel, - ) - - config = ( - get_trainable_cls(args.run) - .get_default_config() - .environment("CartPole-v1") - .framework(args.framework) - .env_runners(num_env_runners=0) - .training( - model={ - "custom_model": "custom_loss", - "custom_model_config": { - "input_files": args.input_files, - }, - }, - ) - # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0. - .resources(num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "0"))) - ) - - stop = {TRAINING_ITERATION: args.stop_iters} - - tuner = tune.Tuner( - args.run, - param_space=config, - run_config=air.RunConfig(stop=stop, verbose=1), - ) - results = tuner.fit() - info = results.get_best_result().metrics["info"] - - # Torch metrics structure. - if args.framework == "torch": - assert LEARNER_STATS_KEY in info[LEARNER_INFO][DEFAULT_MODULE_ID] - assert "model" in info[LEARNER_INFO][DEFAULT_MODULE_ID] - assert "custom_metrics" in info[LEARNER_INFO][DEFAULT_MODULE_ID] - - # TODO: (sven) Make sure the metrics structure gets unified between - # tf and torch. Tf should work like current torch: - # info: - # learner: - # [policy_id] - # learner_stats: [return values of policy's `stats_fn`] - # model: [return values of ModelV2's `metrics` method] - # custom_metrics: [return values of callback: `on_learn_on_batch`] - else: - assert "model" in info[LEARNER_INFO][DEFAULT_MODULE_ID][LEARNER_STATS_KEY] diff --git a/rllib/examples/learners/ppo_tuner.py b/rllib/examples/learners/ppo_tuner.py deleted file mode 100644 index a27e292b9efa7..0000000000000 --- a/rllib/examples/learners/ppo_tuner.py +++ /dev/null @@ -1,61 +0,0 @@ -import argparse - -import ray -from ray import air, tune -from ray.air.constants import TRAINING_ITERATION -from ray.rllib.algorithms.ppo import PPOConfig - -LEARNER_CONFIG = { - "remote-cpu": {"num_learners": 1}, - "remote-gpu": {"num_learners": 1, "num_gpus_per_learner": 1}, - "multi-gpu-ddp": { - "num_learners": 2, - "num_gpus_per_learner": 1, - }, - "local-cpu": {}, - "local-gpu": {"num_gpus_per_learner": 1}, -} - - -def _parse_args(): - - parser = argparse.ArgumentParser() - parser.add_argument( - "--config", - type=str, - default="local-cpu", - ) - - parser.add_argument( - "--framework", - choices=["tf2", "torch"], # tf will be deprecated with the new Learner stack - default="torch", - ) - - return parser.parse_args() - - -if __name__ == "__main__": - args = _parse_args() - - ray.init() - - config = ( - PPOConfig() - .framework(args.framework) - .environment("CartPole-v1") - .learners(**LEARNER_CONFIG[args.config]) - ) - - print("Testing with learner config: ", LEARNER_CONFIG[args.config]) - print("Testing with framework: ", args.framework) - print("-" * 80) - tuner = tune.Tuner( - "PPO", - param_space=config.to_dict(), - run_config=air.RunConfig( - stop={TRAINING_ITERATION: 1}, - failure_config=air.FailureConfig(fail_fast="raise"), - ), - ) - tuner.fit() diff --git a/src/ray/common/ray_config_def.h b/src/ray/common/ray_config_def.h index bcc960158a078..674895d4b3892 100644 --- a/src/ray/common/ray_config_def.h +++ b/src/ray/common/ray_config_def.h @@ -835,6 +835,11 @@ RAY_CONFIG(std::string, REDIS_SERVER_NAME, "") // it will apply to all methods. RAY_CONFIG(std::string, testing_asio_delay_us, "") +/// To use this, simply do +/// export +/// RAY_testing_rpc_failure="method1=max_num_failures,method2=max_num_failures" +RAY_CONFIG(std::string, testing_rpc_failure, "") + /// The following are configs for the health check. They are borrowed /// from k8s health probe (shorturl.at/jmTY3) /// The delay to send the first health check. diff --git a/src/ray/core_worker/core_worker_process.cc b/src/ray/core_worker/core_worker_process.cc index 191788e7e0458..30042635dee7c 100644 --- a/src/ray/core_worker/core_worker_process.cc +++ b/src/ray/core_worker/core_worker_process.cc @@ -238,6 +238,7 @@ void CoreWorkerProcessImpl::InitializeSystemConfig() { RayConfig::instance().initialize(promise.get_future().get()); ray::asio::testing::init(); + ray::rpc::testing::init(); } void CoreWorkerProcessImpl::RunWorkerTaskExecutionLoop() { diff --git a/src/ray/gcs/gcs_server/gcs_actor_manager.cc b/src/ray/gcs/gcs_server/gcs_actor_manager.cc index de581a5e9405a..0c30514c1e32b 100644 --- a/src/ray/gcs/gcs_server/gcs_actor_manager.cc +++ b/src/ray/gcs/gcs_server/gcs_actor_manager.cc @@ -342,8 +342,8 @@ GcsActorManager::GcsActorManager( actor_gc_delay_(RayConfig::instance().gcs_actor_table_min_duration_ms()) { RAY_CHECK(worker_client_factory_); RAY_CHECK(destroy_owned_placement_group_if_needed_); - actor_state_counter_.reset( - new CounterMap>()); + actor_state_counter_ = std::make_shared< + CounterMap>>(); actor_state_counter_->SetOnChangeCallback( [this](const std::pair key) mutable { int64_t num_actors = actor_state_counter_->Get(key); diff --git a/src/ray/gcs/gcs_server/gcs_job_manager.cc b/src/ray/gcs/gcs_server/gcs_job_manager.cc index ce0599c3cc4bd..f68a764f600cc 100644 --- a/src/ray/gcs/gcs_server/gcs_job_manager.cc +++ b/src/ray/gcs/gcs_server/gcs_job_manager.cc @@ -15,17 +15,21 @@ #include "ray/gcs/gcs_server/gcs_job_manager.h" #include "ray/gcs/pb_util.h" +#include "ray/stats/metric.h" namespace ray { namespace gcs { void GcsJobManager::Initialize(const GcsInitData &gcs_init_data) { - for (auto &pair : gcs_init_data.Jobs()) { - const auto &job_id = pair.first; - const auto &job_table_data = pair.second; + for (const auto &[job_id, job_table_data] : gcs_init_data.Jobs()) { cached_job_configs_[job_id] = std::make_shared(job_table_data.config()); function_manager_.AddJobReference(job_id); + + // Recover [running_job_ids_] from storage. + if (!job_table_data.is_dead()) { + running_job_ids_.insert(job_id); + } } } @@ -82,28 +86,38 @@ void GcsJobManager::HandleAddJob(rpc::AddJobRequest request, auto time = current_sys_time_ms(); mutable_job_table_data.set_start_time(time); mutable_job_table_data.set_timestamp(time); - JobID job_id = JobID::FromBinary(mutable_job_table_data.job_id()); + const JobID job_id = JobID::FromBinary(mutable_job_table_data.job_id()); RAY_LOG(INFO) << "Adding job, job id = " << job_id << ", driver pid = " << mutable_job_table_data.driver_pid(); - auto on_done = [this, job_id, mutable_job_table_data, reply, send_reply_callback]( - const Status &status) { + auto on_done = [this, + job_id, + job_table_data = mutable_job_table_data, + reply, + send_reply_callback = + std::move(send_reply_callback)](const Status &status) { + RAY_CHECK(thread_checker_.IsOnSameThread()); + if (!status.ok()) { RAY_LOG(ERROR) << "Failed to add job, job id = " << job_id - << ", driver pid = " << mutable_job_table_data.driver_pid(); + << ", driver pid = " << job_table_data.driver_pid(); } else { - RAY_CHECK_OK(gcs_publisher_->PublishJob(job_id, mutable_job_table_data, nullptr)); - if (mutable_job_table_data.config().has_runtime_env_info()) { - runtime_env_manager_.AddURIReference( - job_id.Hex(), mutable_job_table_data.config().runtime_env_info()); + RAY_CHECK_OK(gcs_publisher_->PublishJob(job_id, job_table_data, /*done=*/nullptr)); + if (job_table_data.config().has_runtime_env_info()) { + runtime_env_manager_.AddURIReference(job_id.Hex(), + job_table_data.config().runtime_env_info()); } function_manager_.AddJobReference(job_id); RAY_LOG(INFO) << "Finished adding job, job id = " << job_id - << ", driver pid = " << mutable_job_table_data.driver_pid(); + << ", driver pid = " << job_table_data.driver_pid(); cached_job_configs_[job_id] = - std::make_shared(mutable_job_table_data.config()); + std::make_shared(job_table_data.config()); + + // Intentionally not checking return value, since the function could be invoked for + // multiple times and requires idempotency (i.e. due to retry). + running_job_ids_.insert(job_id); } - WriteDriverJobExportEvent(mutable_job_table_data); + WriteDriverJobExportEvent(job_table_data); GCS_RPC_SEND_REPLY(send_reply_callback, reply, status); }; @@ -122,7 +136,10 @@ void GcsJobManager::MarkJobAsFinished(rpc::JobTableData job_table_data, job_table_data.set_timestamp(time); job_table_data.set_end_time(time); job_table_data.set_is_dead(true); - auto on_done = [this, job_id, job_table_data, done_callback](const Status &status) { + auto on_done = [this, job_id, job_table_data, done_callback = std::move(done_callback)]( + const Status &status) { + RAY_CHECK(thread_checker_.IsOnSameThread()); + if (!status.ok()) { RAY_LOG(ERROR) << "Failed to mark job state, job id = " << job_id; } else { @@ -133,6 +150,13 @@ void GcsJobManager::MarkJobAsFinished(rpc::JobTableData job_table_data, } function_manager_.RemoveJobReference(job_id); WriteDriverJobExportEvent(job_table_data); + + // Update running job status. + auto iter = running_job_ids_.find(job_id); + RAY_CHECK(iter != running_job_ids_.end()); + running_job_ids_.erase(iter); + ++finished_jobs_count_; + done_callback(status); }; @@ -147,21 +171,30 @@ void GcsJobManager::HandleMarkJobFinished(rpc::MarkJobFinishedRequest request, rpc::SendReplyCallback send_reply_callback) { const JobID job_id = JobID::FromBinary(request.job_id()); - auto send_reply = [send_reply_callback, reply](Status status) { + auto send_reply = [send_reply_callback = std::move(send_reply_callback), + reply](Status status) { GCS_RPC_SEND_REPLY(send_reply_callback, reply, status); }; Status status = gcs_table_storage_->JobTable().Get( job_id, - [this, job_id, send_reply](Status status, + [this, job_id, send_reply](const Status &status, const std::optional &result) { + RAY_CHECK(thread_checker_.IsOnSameThread()); + if (status.ok() && result) { MarkJobAsFinished(*result, send_reply); - } else { + return; + } + + if (!result.has_value()) { RAY_LOG(ERROR) << "Tried to mark job " << job_id << " as finished, but there was no record of it starting!"; - send_reply(status); + } else if (!status.ok()) { + RAY_LOG(ERROR) << "Fails to mark job " << job_id << " as finished due to " + << status; } + send_reply(status); }); if (!status.ok()) { send_reply(status); @@ -239,6 +272,8 @@ void GcsJobManager::HandleGetAllJobInfo(rpc::GetAllJobInfoRequest request, }; auto on_done = [this, filter_ok, request, reply, send_reply_callback, limit]( const absl::flat_hash_map &&result) { + RAY_CHECK(thread_checker_.IsOnSameThread()); + // Internal KV keys for jobs that were submitted via the Ray Job API. std::vector job_api_data_keys; @@ -420,11 +455,13 @@ void GcsJobManager::OnNodeDead(const NodeID &node_id) { << "Node failed, mark all jobs from this node as finished"; auto on_done = [this, node_id](const absl::flat_hash_map &result) { + RAY_CHECK(thread_checker_.IsOnSameThread()); + // If job is not dead and from driver in current node, then mark it as finished for (auto &data : result) { if (!data.second.is_dead() && NodeID::FromBinary(data.second.driver_address().raylet_id()) == node_id) { - RAY_LOG(DEBUG) << "Marking job: " << data.first << " as finished"; + RAY_LOG(DEBUG).WithField(data.first) << "Marking job as finished"; MarkJobAsFinished(data.second, [data](Status status) { if (!status.ok()) { RAY_LOG(WARNING) << "Failed to mark job as finished. Status: " << status; @@ -438,5 +475,10 @@ void GcsJobManager::OnNodeDead(const NodeID &node_id) { RAY_CHECK_OK(gcs_table_storage_->JobTable().GetAll(on_done)); } +void GcsJobManager::RecordMetrics() { + ray::stats::STATS_running_jobs.Record(running_job_ids_.size()); + ray::stats::STATS_finished_jobs.Record(finished_jobs_count_); +} + } // namespace gcs } // namespace ray diff --git a/src/ray/gcs/gcs_server/gcs_job_manager.h b/src/ray/gcs/gcs_server/gcs_job_manager.h index 891e7ea20d25d..95f43c7e27ad2 100644 --- a/src/ray/gcs/gcs_server/gcs_job_manager.h +++ b/src/ray/gcs/gcs_server/gcs_job_manager.h @@ -14,6 +14,14 @@ #pragma once +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "ray/common/runtime_env_manager.h" #include "ray/gcs/gcs_server/gcs_function_manager.h" #include "ray/gcs/gcs_server/gcs_init_data.h" @@ -23,6 +31,7 @@ #include "ray/rpc/worker/core_worker_client.h" #include "ray/rpc/worker/core_worker_client_pool.h" #include "ray/util/event.h" +#include "ray/util/thread_checker.h" namespace ray { namespace gcs { @@ -88,7 +97,27 @@ class GcsJobManager : public rpc::JobInfoHandler { void WriteDriverJobExportEvent(rpc::JobTableData job_data) const; + /// Record metrics. + /// For job manager, (1) running jobs count gauge and (2) new finished jobs (whether + /// succeed or fail) will be reported periodically. + void RecordMetrics(); + private: + void ClearJobInfos(const rpc::JobTableData &job_data); + + void MarkJobAsFinished(rpc::JobTableData job_table_data, + std::function done_callback); + + // Used to validate invariants for threading; for example, all callbacks are executed on + // the same thread. + ThreadChecker thread_checker_; + + // Running Job IDs, used to report metrics. + absl::flat_hash_set running_job_ids_; + + // Number of finished jobs since start of this GCS Server, used to report metrics. + int64_t finished_jobs_count_ = 0; + std::shared_ptr gcs_table_storage_; std::shared_ptr gcs_publisher_; @@ -104,11 +133,6 @@ class GcsJobManager : public rpc::JobInfoHandler { /// The cached core worker clients which are used to communicate with workers. rpc::CoreWorkerClientPool core_worker_clients_; - - void ClearJobInfos(const rpc::JobTableData &job_data); - - void MarkJobAsFinished(rpc::JobTableData job_table_data, - std::function done_callback); }; } // namespace gcs diff --git a/src/ray/gcs/gcs_server/gcs_server.cc b/src/ray/gcs/gcs_server/gcs_server.cc index 174785681859b..805f2f521ed6e 100644 --- a/src/ray/gcs/gcs_server/gcs_server.cc +++ b/src/ray/gcs/gcs_server/gcs_server.cc @@ -789,6 +789,7 @@ void GcsServer::RecordMetrics() const { gcs_actor_manager_->RecordMetrics(); gcs_placement_group_manager_->RecordMetrics(); gcs_task_manager_->RecordMetrics(); + gcs_job_manager_->RecordMetrics(); execute_after( main_service_, [this] { RecordMetrics(); }, diff --git a/src/ray/gcs/gcs_server/gcs_server_main.cc b/src/ray/gcs/gcs_server/gcs_server_main.cc index 137efbaf9dd5d..18d7b83d896e4 100644 --- a/src/ray/gcs/gcs_server/gcs_server_main.cc +++ b/src/ray/gcs/gcs_server/gcs_server_main.cc @@ -70,6 +70,7 @@ int main(int argc, char *argv[]) { RayConfig::instance().initialize(config_list); ray::asio::testing::init(); + ray::rpc::testing::init(); // IO Service for main loop. instrumented_io_context main_service; diff --git a/src/ray/gcs/gcs_server/gcs_table_storage.h b/src/ray/gcs/gcs_server/gcs_table_storage.h index b05cffc6beb3b..16133f2901389 100644 --- a/src/ray/gcs/gcs_server/gcs_table_storage.h +++ b/src/ray/gcs/gcs_server/gcs_table_storage.h @@ -111,8 +111,8 @@ class GcsTableWithJobId : public GcsTable { /// \param key The key that will be written to the table. The job id can be obtained /// from the key. /// \param value The value of the key that will be written to the table. - /// \param callback Callback that will be called after write finishes. - /// \return Status + /// \param callback Callback that will be called after write finishes, whether it + /// succeeds or not. \return Status for issuing the asynchronous write operation. Status Put(const Key &key, const Data &value, const StatusCallback &callback) override; /// Get all the data of the specified job id from the table asynchronously. diff --git a/src/ray/protobuf/common.proto b/src/ray/protobuf/common.proto index dbd2fee683ee9..fb20006d57b96 100644 --- a/src/ray/protobuf/common.proto +++ b/src/ray/protobuf/common.proto @@ -710,6 +710,7 @@ message ActorCreationTaskSpec { // The dynamic options used in the worker command when starting a worker process for // an actor creation task. If the list isn't empty, the options will be used to replace // the placeholder string `RAY_WORKER_DYNAMIC_OPTION_PLACEHOLDER` in the worker command. + // Used by Java workers for JVM options. repeated string dynamic_worker_options = 5; // The max number of concurrent calls for default concurrency group of this actor. int32 max_concurrency = 6; diff --git a/src/ray/raylet/main.cc b/src/ray/raylet/main.cc index 8cc0c1f08ef1a..2b30f9068b6b6 100644 --- a/src/ray/raylet/main.cc +++ b/src/ray/raylet/main.cc @@ -275,6 +275,7 @@ int main(int argc, char *argv[]) { RAY_CHECK(stored_raylet_config.has_value()); RayConfig::instance().initialize(*stored_raylet_config); ray::asio::testing::init(); + ray::rpc::testing::init(); // Core worker tries to kill child processes when it exits. But they can't do // it perfectly: if the core worker is killed by SIGKILL, the child processes diff --git a/src/ray/raylet/worker.cc b/src/ray/raylet/worker.cc index 6770be6bd70ec..82c7476b17fcd 100644 --- a/src/ray/raylet/worker.cc +++ b/src/ray/raylet/worker.cc @@ -15,6 +15,7 @@ #include "ray/raylet/worker.h" #include +#include #include "ray/raylet/format/node_manager_generated.h" #include "ray/raylet/raylet.h" @@ -42,7 +43,7 @@ Worker::Worker(const JobID &job_id, ip_address_(ip_address), assigned_port_(-1), port_(-1), - connection_(connection), + connection_(std::move(connection)), assigned_job_id_(job_id), runtime_env_hash_(runtime_env_hash), bundle_id_(std::make_pair(PlacementGroupID::Nil(), -1)), @@ -129,7 +130,12 @@ void Worker::Connect(std::shared_ptr rpc_client) } } -void Worker::AssignTaskId(const TaskID &task_id) { assigned_task_id_ = task_id; } +void Worker::AssignTaskId(const TaskID &task_id) { + assigned_task_id_ = task_id; + if (!task_id.IsNil()) { + task_assign_time_ = absl::Now(); + } +} const TaskID &Worker::GetAssignedTaskId() const { return assigned_task_id_; } diff --git a/src/ray/raylet/worker.h b/src/ray/raylet/worker.h index f59acb827ab15..9166eea619fca 100644 --- a/src/ray/raylet/worker.h +++ b/src/ray/raylet/worker.h @@ -233,7 +233,6 @@ class Worker : public WorkerInterface { RAY_CHECK(!task_spec.IsActorTask()); SetIsActorWorker(task_spec.IsActorCreationTask()); assigned_task_ = assigned_task; - task_assign_time_ = absl::Now(); root_detached_actor_id_ = assigned_task.GetTaskSpecification().RootDetachedActorId(); } diff --git a/src/ray/raylet/worker_pool.cc b/src/ray/raylet/worker_pool.cc index de417b23693d5..8293be8e29c32 100644 --- a/src/ray/raylet/worker_pool.cc +++ b/src/ray/raylet/worker_pool.cc @@ -62,6 +62,20 @@ bool RemoveWorker( const std::shared_ptr &worker) { return worker_pool.erase(worker) > 0; } + +// If both `ask` and `have` are set, they must match. If either of them is not set, it +// is considered a match. +bool OptionalMatches(const std::optional &ask, const std::optional &have) { + return !ask.has_value() || !have.has_value() || ask.value() == have.value(); +} + +// Similar to OptionalMatches, but for JobID or ActorID. +// TODO(ryw): use std::optional. +template +bool IdMatches(const IDType &ask, const IDType &have) { + return ask.IsNil() || have.IsNil() || ask == have; +} + } // namespace namespace ray { @@ -180,11 +194,10 @@ void WorkerPool::SetRuntimeEnvAgentClient( if (!runtime_env_agent_client) { RAY_LOG(FATAL) << "SetRuntimeEnvAgentClient requires non empty pointer"; } - runtime_env_agent_client_ = runtime_env_agent_client; + runtime_env_agent_client_ = std::move(runtime_env_agent_client); } -void WorkerPool::PopWorkerCallbackAsync(const TaskSpecification &task_spec, - const PopWorkerCallback &callback, +void WorkerPool::PopWorkerCallbackAsync(PopWorkerCallback callback, std::shared_ptr worker, PopWorkerStatus status) { // This method shouldn't be invoked when runtime env creation has failed because @@ -193,34 +206,17 @@ void WorkerPool::PopWorkerCallbackAsync(const TaskSpecification &task_spec, RAY_CHECK(status != PopWorkerStatus::RuntimeEnvCreationFailed); // Call back this function asynchronously to make sure executed in different stack. io_service_->post( - [this, task_spec, callback, worker, status]() { - PopWorkerCallbackInternal(task_spec, callback, worker, status); + [this, callback = std::move(callback), worker = std::move(worker), status]() { + PopWorkerCallbackInternal(callback, worker, status); }, "WorkerPool.PopWorkerCallback"); } -void WorkerPool::PopWorkerCallbackInternal(const TaskSpecification &task_spec, - const PopWorkerCallback &callback, +void WorkerPool::PopWorkerCallbackInternal(const PopWorkerCallback &callback, std::shared_ptr worker, PopWorkerStatus status) { RAY_CHECK(callback); - auto used = false; - if (worker && finished_jobs_.contains(task_spec.JobId()) && - task_spec.RootDetachedActorId().IsNil()) { - // When a job finishes, node manager will kill leased workers one time - // and worker pool will kill idle workers periodically. - // The current worker is already removed from the idle workers - // but hasn't been added to the leased workers since the callback is not called yet. - // We shouldn't add this worker to the leased workers since killing leased workers - // for this finished job may already happen and won't happen again (this is one time) - // so it will cause a process leak. - // Instead we fail the PopWorker and add the worker back to the idle workers so it can - // be killed later. - RAY_CHECK(status == PopWorkerStatus::OK); - callback(nullptr, PopWorkerStatus::JobFinished, ""); - } else { - used = callback(worker, status, /*runtime_env_setup_error_message*/ ""); - } + auto used = callback(worker, status, /*runtime_env_setup_error_message=*/""); if (worker && !used) { // The invalid worker not used, restore it to worker pool. PushWorker(worker); @@ -511,8 +507,7 @@ std::tuple WorkerPool::StartWorkerProcess( if (!IsIOWorkerType(worker_type)) { AdjustWorkerOomScore(proc.GetId()); } - MonitorStartingWorkerProcess( - proc, worker_startup_token_counter_, language, worker_type); + MonitorStartingWorkerProcess(worker_startup_token_counter_, language, worker_type); AddWorkerProcess(state, worker_type, proc, start, runtime_env_info, dynamic_options); StartupToken worker_startup_token = worker_startup_token_counter_; update_worker_startup_token_counter(); @@ -544,8 +539,7 @@ void WorkerPool::AdjustWorkerOomScore(pid_t pid) const { #endif } -void WorkerPool::MonitorStartingWorkerProcess(const Process &proc, - StartupToken proc_startup_token, +void WorkerPool::MonitorStartingWorkerProcess(StartupToken proc_startup_token, const Language &language, const rpc::WorkerType worker_type) { auto timer = std::make_shared( @@ -553,7 +547,7 @@ void WorkerPool::MonitorStartingWorkerProcess(const Process &proc, boost::posix_time::seconds( RayConfig::instance().worker_register_timeout_seconds())); // Capture timer in lambda to copy it once, so that it can avoid destructing timer. - timer->async_wait([timer, language, proc = proc, proc_startup_token, worker_type, this]( + timer->async_wait([timer, language, proc_startup_token, worker_type, this]( const boost::system::error_code e) mutable { // check the error code. auto &state = this->GetStateForLanguage(language); @@ -562,26 +556,17 @@ void WorkerPool::MonitorStartingWorkerProcess(const Process &proc, auto it = state.worker_processes.find(proc_startup_token); if (it != state.worker_processes.end() && it->second.is_pending_registration) { RAY_LOG(ERROR) - << "Some workers of the worker process(" << proc.GetId() + << "Some workers of the worker process(" << it->second.proc.GetId() << ") have not registered within the timeout. " - << (proc.IsAlive() + << (it->second.proc.IsAlive() ? "The process is still alive, probably it's hanging during start." : "The process is dead, probably it crashed during start."); - if (proc.IsAlive()) { - proc.Kill(); + if (it->second.proc.IsAlive()) { + it->second.proc.Kill(); } - PopWorkerStatus status = PopWorkerStatus::WorkerPendingRegistration; process_failed_pending_registration_++; - bool found; - bool used; - InvokePopWorkerCallbackForProcess(state.starting_workers_to_tasks, - proc_startup_token, - nullptr, - status, - &found, - &used); DeleteRuntimeEnvIfPossible(it->second.runtime_env_info.serialized_runtime_env()); RemoveWorkerProcess(state, proc_startup_token); if (IsIOWorkerType(worker_type)) { @@ -592,13 +577,33 @@ void WorkerPool::MonitorStartingWorkerProcess(const Process &proc, // We may have places to start more workers now. TryStartIOWorkers(language); if (worker_type == rpc::WorkerType::WORKER) { - TryPendingPopWorkerRequests(language); + TryPendingStartRequests(language); } starting_worker_timeout_callback_(); } }); } +void WorkerPool::MonitorPopWorkerRequestForRegistration( + std::shared_ptr pop_worker_request) { + auto timer = std::make_shared( + *io_service_, + boost::posix_time::seconds( + RayConfig::instance().worker_register_timeout_seconds())); + // Capture timer in lambda to copy it once, so that it can avoid destructing timer. + timer->async_wait([timer, pop_worker_request = std::move(pop_worker_request), this]( + const boost::system::error_code e) mutable { + auto &state = GetStateForLanguage(pop_worker_request->language); + auto &requests = state.pending_registration_requests; + auto it = std::find(requests.begin(), requests.end(), pop_worker_request); + if (it != requests.end()) { + // Fail the task... + PopWorkerStatus status = PopWorkerStatus::WorkerPendingRegistration; + PopWorkerCallbackAsync(pop_worker_request->callback, nullptr, status); + } + }); +} + Process WorkerPool::StartProcess(const std::vector &worker_command_args, const ProcessEnvironment &env) { if (RAY_LOG_ENABLED(DEBUG)) { @@ -978,73 +983,65 @@ void WorkerPool::PopDeleteWorker( } } -void WorkerPool::InvokePopWorkerCallbackForProcess( - absl::flat_hash_map - &starting_workers_to_tasks, - StartupToken startup_token, - const std::shared_ptr &worker, - const PopWorkerStatus &status, - bool *found, - bool *worker_used) { - *found = false; - *worker_used = false; - auto it = starting_workers_to_tasks.find(startup_token); - if (it != starting_workers_to_tasks.end()) { - *found = true; - const auto &callback = it->second.callback; - RAY_CHECK(callback); - // This method shouldn't be invoked when runtime env creation has failed because - // when runtime env is failed to be created, they are all - // invoking the callback immediately. - RAY_CHECK(status != PopWorkerStatus::RuntimeEnvCreationFailed); - if (worker && finished_jobs_.contains(it->second.task_spec.JobId()) && - it->second.task_spec.RootDetachedActorId().IsNil()) { - // If the job has finished, we should fail the PopWorker callback - // and add the worker back to the idle workers so it can be killed later. - // This doesn't apply to detached actor and its descendants - // since they can outlive the job. - RAY_CHECK(status == PopWorkerStatus::OK); - callback(nullptr, PopWorkerStatus::JobFinished, ""); - } else { - *worker_used = callback(worker, status, /*runtime_env_setup_error_message*/ ""); - } - starting_workers_to_tasks.erase(it); - } -} - void WorkerPool::PushWorker(const std::shared_ptr &worker) { // Since the worker is now idle, unset its assigned task ID. RAY_CHECK(worker->GetAssignedTaskId().IsNil()) << "Idle workers cannot have an assigned task ID"; + + // Find a task that this worker can fit. If there's none, put it in the idle pool. + // First find in pending_registration_requests, then in pending_start_requests. + std::shared_ptr pop_worker_request = nullptr; auto &state = GetStateForLanguage(worker->GetLanguage()); - bool found; - bool used; - InvokePopWorkerCallbackForProcess(state.starting_workers_to_tasks, - worker->GetStartupToken(), - worker, - PopWorkerStatus::OK, - &found, - &used); - RAY_LOG(DEBUG) << "PushWorker " << worker->WorkerId() << " used: " << used; - if (!used) { - // Put the worker to the idle pool. + { + auto it = std::find_if( + state.pending_registration_requests.begin(), + state.pending_registration_requests.end(), + [this, &worker](const std::shared_ptr &pop_worker_request) { + return WorkerFitsForTask(*worker, *pop_worker_request) == + WorkerUnfitForTaskReason::NONE; + }); + if (it != state.pending_registration_requests.end()) { + pop_worker_request = *it; + state.pending_registration_requests.erase(it); + } + } + if (!pop_worker_request) { + auto it = std::find_if( + state.pending_start_requests.begin(), + state.pending_start_requests.end(), + [this, &worker](const std::shared_ptr &pop_worker_request) { + return WorkerFitsForTask(*worker, *pop_worker_request) == + WorkerUnfitForTaskReason::NONE; + }); + if (it != state.pending_start_requests.end()) { + pop_worker_request = *it; + state.pending_start_requests.erase(it); + } + } + + if (pop_worker_request) { + bool used = pop_worker_request->callback(worker, PopWorkerStatus::OK, ""); + if (!used) { + // Retry PushWorker. Maybe it can be used by other tasks. + // Can we have tail call optimization for this? :) + return PushWorker(worker); + } + } else { state.idle.insert(worker); auto now = get_time_(); - if (found) { - // If the worker was just started, then we should consider it first when - // choosing which idle workers to kill because it is cold. - idle_of_all_languages_.push_front(std::make_pair(worker, now)); + if (worker->GetAssignedTaskTime() == absl::Time()) { + // If the worker never held any tasks, then we should consider it first when + // choosing which idle workers to kill because it is not warmed up and is slower + // than those workers who served tasks before. + // See https://github.com/ray-project/ray/pull/36766 + idle_of_all_languages_.emplace_front(worker, now); } else { idle_of_all_languages_.emplace_back(worker, now); } - } else if (!found) { - RAY_LOG(INFO) << "Worker not returned to the idle pool after being used. This may " - "cause a worker leak, worker id:" - << worker->WorkerId(); } // We either have an idle worker or a slot to start a new worker. if (worker->GetWorkerType() == rpc::WorkerType::WORKER) { - TryPendingPopWorkerRequests(worker->GetLanguage()); + TryPendingStartRequests(worker->GetLanguage()); } } @@ -1162,102 +1159,164 @@ void WorkerPool::KillIdleWorker(std::shared_ptr idle_worker, } WorkerUnfitForTaskReason WorkerPool::WorkerFitsForTask( - const WorkerInterface &worker, const TaskSpecification &task_spec) const { + const WorkerInterface &worker, const PopWorkerRequest &pop_worker_request) const { if (worker.IsDead()) { return WorkerUnfitForTaskReason::OTHERS; } - if (worker.GetLanguage() != task_spec.GetLanguage()) { + // These workers are exiting. So skip them. + if (pending_exit_idle_workers_.contains(worker.WorkerId())) { + return WorkerUnfitForTaskReason::OTHERS; + } + if (worker.GetLanguage() != pop_worker_request.language) { return WorkerUnfitForTaskReason::OTHERS; } - // Don't allow worker reuse across jobs or root detached actors. Reuse worker with - // unassigned job_id and root detached actor id is OK. - JobID job_id = worker.GetAssignedJobId(); - if (!job_id.IsNil() && job_id != task_spec.JobId()) { - return WorkerUnfitForTaskReason::ROOT_MISMATCH; + if (worker.GetWorkerType() != pop_worker_request.worker_type) { + return WorkerUnfitForTaskReason::OTHERS; } - ActorID root_detached_actor_id = worker.GetRootDetachedActorId(); - if (!root_detached_actor_id.IsNil() && - root_detached_actor_id != task_spec.RootDetachedActorId()) { + + if (!IdMatches(pop_worker_request.root_detached_actor_id, + worker.GetRootDetachedActorId())) { return WorkerUnfitForTaskReason::ROOT_MISMATCH; } - auto is_gpu = worker.GetIsGpu(); - if (is_gpu.has_value()) { - bool task_is_gpu = - task_spec.GetRequiredResources().Get(scheduling::ResourceID::GPU()) > 0; - if (is_gpu.value() != task_is_gpu) { - return WorkerUnfitForTaskReason::OTHERS; + // Only compare job id for actors not rooted to a detached actor. + if (pop_worker_request.root_detached_actor_id.IsNil()) { + if (!IdMatches(pop_worker_request.job_id, worker.GetAssignedJobId())) { + return WorkerUnfitForTaskReason::ROOT_MISMATCH; } } - auto is_actor_worker = worker.GetIsActorWorker(); - if (is_actor_worker.has_value() && - is_actor_worker.value() != task_spec.IsActorCreationTask()) { + // If the request asks for a is_gpu, and the worker is assigned a different is_gpu, + // then skip it. + if (!OptionalMatches(pop_worker_request.is_gpu, worker.GetIsGpu())) { + return WorkerUnfitForTaskReason::OTHERS; + } + // If the request asks for a is_actor_worker, and the worker is assigned a different + // is_actor_worker, then skip it. + if (!OptionalMatches(pop_worker_request.is_actor_worker, worker.GetIsActorWorker())) { return WorkerUnfitForTaskReason::OTHERS; } // TODO(clarng): consider re-using worker that has runtime envionrment // if the task doesn't require one. - if (worker.GetRuntimeEnvHash() != task_spec.GetRuntimeEnvHash()) { + if (worker.GetRuntimeEnvHash() != pop_worker_request.runtime_env_hash) { return WorkerUnfitForTaskReason::RUNTIME_ENV_MISMATCH; } // Skip if the dynamic_options doesn't match. if (LookupWorkerDynamicOptions(worker.GetStartupToken()) != - task_spec.DynamicWorkerOptionsOrEmpty()) { + pop_worker_request.dynamic_options) { return WorkerUnfitForTaskReason::DYNAMIC_OPTIONS_MISMATCH; } - // These workers are exiting. So skip them. - if (pending_exit_idle_workers_.contains(worker.WorkerId())) { - return WorkerUnfitForTaskReason::OTHERS; - } return WorkerUnfitForTaskReason::NONE; } -void WorkerPool::PopWorker(const TaskSpecification &task_spec, - const PopWorkerCallback &callback) { - RAY_LOG(DEBUG) << "Pop worker for task " << task_spec.TaskId() << " task name " - << task_spec.FunctionDescriptor()->ToString(); - auto &state = GetStateForLanguage(task_spec.GetLanguage()); +void WorkerPool::StartNewWorker( + const std::shared_ptr &pop_worker_request) { + auto start_worker_process_fn = [this]( + std::shared_ptr pop_worker_request, + const std::string &serialized_runtime_env_context) { + auto &state = GetStateForLanguage(pop_worker_request->language); + const std::string &serialized_runtime_env = + pop_worker_request->runtime_env_info.serialized_runtime_env(); - std::shared_ptr worker = nullptr; - auto start_worker_process_fn = [this](const TaskSpecification &task_spec, - State &state, - const std::string &serialized_runtime_env_context, - const PopWorkerCallback &callback) { PopWorkerStatus status = PopWorkerStatus::OK; - auto [proc, startup_token] = - StartWorkerProcess(task_spec.GetLanguage(), - rpc::WorkerType::WORKER, - task_spec.JobId(), - &status, - task_spec.DynamicWorkerOptionsOrEmpty(), - task_spec.GetRuntimeEnvHash(), - serialized_runtime_env_context, - task_spec.RuntimeEnvInfo()); + auto [proc, startup_token] = StartWorkerProcess(pop_worker_request->language, + pop_worker_request->worker_type, + pop_worker_request->job_id, + &status, + pop_worker_request->dynamic_options, + pop_worker_request->runtime_env_hash, + serialized_runtime_env_context, + pop_worker_request->runtime_env_info); if (status == PopWorkerStatus::OK) { RAY_CHECK(proc.IsValid()); WarnAboutSize(); - auto task_info = TaskWaitingForWorkerInfo{task_spec, callback}; - state.starting_workers_to_tasks[startup_token] = std::move(task_info); + state.pending_registration_requests.emplace_back(pop_worker_request); + MonitorPopWorkerRequestForRegistration(pop_worker_request); } else if (status == PopWorkerStatus::TooManyStartingWorkerProcesses) { // TODO(jjyao) As an optimization, we don't need to delete the runtime env // but reuse it the next time we retry the request. - DeleteRuntimeEnvIfPossible(task_spec.SerializedRuntimeEnv()); - state.pending_pop_worker_requests.emplace_back( - PopWorkerRequest{task_spec, callback}); + DeleteRuntimeEnvIfPossible(serialized_runtime_env); + state.pending_start_requests.emplace_back(std::move(pop_worker_request)); } else { - DeleteRuntimeEnvIfPossible(task_spec.SerializedRuntimeEnv()); - PopWorkerCallbackAsync(task_spec, callback, nullptr, status); + DeleteRuntimeEnvIfPossible(serialized_runtime_env); + PopWorkerCallbackAsync(std::move(pop_worker_request->callback), nullptr, status); } }; + const std::string &serialized_runtime_env = + pop_worker_request->runtime_env_info.serialized_runtime_env(); + + if (!IsRuntimeEnvEmpty(serialized_runtime_env)) { + // create runtime env. + GetOrCreateRuntimeEnv( + serialized_runtime_env, + pop_worker_request->runtime_env_info.runtime_env_config(), + pop_worker_request->job_id, + [this, start_worker_process_fn, pop_worker_request]( + bool successful, + const std::string &serialized_runtime_env_context, + const std::string &setup_error_message) { + if (successful) { + start_worker_process_fn(pop_worker_request, serialized_runtime_env_context); + } else { + process_failed_runtime_env_setup_failed_++; + pop_worker_request->callback( + nullptr, + PopWorkerStatus::RuntimeEnvCreationFailed, + /*runtime_env_setup_error_message*/ setup_error_message); + } + }); + } else { + start_worker_process_fn(pop_worker_request, ""); + } +} + +void WorkerPool::PopWorker(const TaskSpecification &task_spec, + const PopWorkerCallback &callback) { + RAY_LOG(DEBUG) << "Pop worker for task " << task_spec.TaskId() << " task name " + << task_spec.FunctionDescriptor()->ToString(); // Code path of actor task. RAY_CHECK(!task_spec.IsActorTask()) << "Direct call shouldn't reach here."; + auto pop_worker_request = std::make_shared( + task_spec.GetLanguage(), + rpc::WorkerType::WORKER, + task_spec.JobId(), + task_spec.RootDetachedActorId(), + /*is_gpu=*/task_spec.GetRequiredResources().Get(scheduling::ResourceID::GPU()) > 0, + /*is_actor_worker=*/task_spec.IsActorCreationTask(), + task_spec.RuntimeEnvInfo(), + task_spec.GetRuntimeEnvHash(), + task_spec.DynamicWorkerOptionsOrEmpty(), + [this, task_spec, callback]( + const std::shared_ptr &worker, + PopWorkerStatus status, + const std::string &runtime_env_setup_error_message) -> bool { + // We got a worker suitable for the task. Now let's check if the task is still + // executable. + if (worker && finished_jobs_.contains(task_spec.JobId()) && + task_spec.RootDetachedActorId().IsNil()) { + // When a job finishes, node manager will kill leased workers one time + // and worker pool will kill idle workers periodically. + // The current worker is already removed from the idle workers + // but hasn't been added to the leased workers since the callback is not called + // yet. We shouldn't add this worker to the leased workers since killing leased + // workers for this finished job may already happen and won't happen again (this + // is one time) so it will cause a process leak. Instead we fail the PopWorker + // and add the worker back to the idle workers so it can be killed later. + RAY_CHECK(status == PopWorkerStatus::OK); + callback(nullptr, PopWorkerStatus::JobFinished, ""); + // Not used + return false; + } + return callback(worker, status, runtime_env_setup_error_message); + }); + absl::flat_hash_map skip_reason_count; auto worker_fits_for_task_fn = - [this, &task_spec, &skip_reason_count]( + [this, &pop_worker_request, &skip_reason_count]( const std::pair, int64_t> &pair) -> bool { const auto &worker = pair.first; - WorkerUnfitForTaskReason reason = WorkerFitsForTask(*worker, task_spec); + WorkerUnfitForTaskReason reason = WorkerFitsForTask(*worker, *pop_worker_request); if (reason == WorkerUnfitForTaskReason::NONE) { return true; } @@ -1271,7 +1330,8 @@ void WorkerPool::PopWorker(const TaskSpecification &task_spec, } return false; }; - + auto &state = GetStateForLanguage(task_spec.GetLanguage()); + std::shared_ptr worker = nullptr; auto good_worker_it = std::find_if(idle_of_all_languages_.rbegin(), idle_of_all_languages_.rend(), worker_fits_for_task_fn); @@ -1284,45 +1344,19 @@ void WorkerPool::PopWorker(const TaskSpecification &task_spec, idle_of_all_languages_.erase(lit); } + // If there's an idle worker that fits the task, use it. + // Else, start a new worker. if (worker == nullptr) { RAY_LOG(DEBUG) << "No cached worker, cached workers skipped due to " << debug_string(skip_reason_count); - if (task_spec.HasRuntimeEnv()) { - // create runtime env. - RAY_LOG(DEBUG) << "GetOrCreateRuntimeEnv for task " << task_spec.TaskId(); - GetOrCreateRuntimeEnv( - task_spec.SerializedRuntimeEnv(), - task_spec.RuntimeEnvConfig(), - task_spec.JobId(), - [this, start_worker_process_fn, callback, &state, task_spec]( - bool successful, - const std::string &serialized_runtime_env_context, - const std::string &setup_error_message) { - if (successful) { - start_worker_process_fn( - task_spec, state, serialized_runtime_env_context, callback); - } else { - process_failed_runtime_env_setup_failed_++; - callback(nullptr, - PopWorkerStatus::RuntimeEnvCreationFailed, - /*runtime_env_setup_error_message*/ setup_error_message); - RAY_LOG(WARNING) << "Create runtime env failed for task " - << task_spec.TaskId() - << " and couldn't create the worker."; - } - }); - } else { - start_worker_process_fn(task_spec, state, "", callback); - } - } - - if (worker) { + StartNewWorker(pop_worker_request); + } else { RAY_CHECK(worker->GetAssignedJobId().IsNil() || worker->GetAssignedJobId() == task_spec.JobId()); RAY_LOG(DEBUG) << "Re-using worker " << worker->WorkerId() << " for task " << task_spec.DebugString(); stats::NumWorkersStartedFromCache.Record(1); - PopWorkerCallbackAsync(task_spec, callback, worker); + PopWorkerCallbackAsync(pop_worker_request->callback, worker, PopWorkerStatus::OK); } } @@ -1391,7 +1425,7 @@ void WorkerPool::DisconnectWorker(const std::shared_ptr &worker // This may add new workers to state.worker_processes // and invalidate the iterator, do not use `it` // after this call. - TryPendingPopWorkerRequests(worker->GetLanguage()); + TryPendingStartRequests(worker->GetLanguage()); } } @@ -1525,16 +1559,16 @@ void WorkerPool::TryStartIOWorkers(const Language &language) { TryStartIOWorkers(language, rpc::WorkerType::SPILL_WORKER); } -void WorkerPool::TryPendingPopWorkerRequests(const Language &language) { +void WorkerPool::TryPendingStartRequests(const Language &language) { auto &state = GetStateForLanguage(language); - if (state.pending_pop_worker_requests.empty()) { + if (state.pending_start_requests.empty()) { return; } - std::deque pending_pop_worker_requests; - state.pending_pop_worker_requests.swap(pending_pop_worker_requests); - for (const auto &pop_worker_request : pending_pop_worker_requests) { - PopWorker(pop_worker_request.task_spec, pop_worker_request.callback); + std::deque> pending_start_requests; + state.pending_start_requests.swap(pending_start_requests); + for (const auto &request : pending_start_requests) { + StartNewWorker(request); } } @@ -1585,6 +1619,11 @@ std::string WorkerPool::DebugString() const { << " workers: " << entry.second.registered_workers.size(); result << "\n- num " << Language_Name(entry.first) << " drivers: " << entry.second.registered_drivers.size(); + result << "\n- num " << Language_Name(entry.first) + << " pending start requests: " << entry.second.pending_start_requests.size(); + result << "\n- num " << Language_Name(entry.first) + << " pending registration requests: " + << entry.second.pending_registration_requests.size(); result << "\n- num object spill callbacks queued: " << entry.second.spill_io_worker_state.pending_io_tasks.size(); result << "\n- num object restore queued: " diff --git a/src/ray/raylet/worker_pool.h b/src/ray/raylet/worker_pool.h index 58ddc18870656..6d71290ca832f 100644 --- a/src/ray/raylet/worker_pool.h +++ b/src/ray/raylet/worker_pool.h @@ -73,7 +73,7 @@ enum PopWorkerStatus { /// \return true if the worker was used. Otherwise, return false /// and the worker will be returned to the worker pool. using PopWorkerCallback = - std::function worker, + std::function &worker, PopWorkerStatus status, const std::string &runtime_env_setup_error_message)>; @@ -437,11 +437,11 @@ class WorkerPool : public WorkerPoolInterface, public IOWorkerPoolInterface { /// we didn't start a process. std::tuple StartWorkerProcess( const Language &language, - const rpc::WorkerType worker_type, + rpc::WorkerType worker_type, const JobID &job_id, PopWorkerStatus *status /*output*/, const std::vector &dynamic_options = {}, - const int runtime_env_hash = 0, + int runtime_env_hash = 0, const std::string &serialized_runtime_env_context = "{}", const rpc::RuntimeEnvInfo &runtime_env_info = rpc::RuntimeEnvInfo()); @@ -460,8 +460,7 @@ class WorkerPool : public WorkerPoolInterface, public IOWorkerPoolInterface { virtual void WarnAboutSize(); /// Make this synchronized function for unit test. - void PopWorkerCallbackInternal(const TaskSpecification &task_spec, - const PopWorkerCallback &callback, + void PopWorkerCallbackInternal(const PopWorkerCallback &callback, std::shared_ptr worker, PopWorkerStatus status); @@ -504,19 +503,44 @@ class WorkerPool : public WorkerPoolInterface, public IOWorkerPoolInterface { std::vector dynamic_options; }; - struct TaskWaitingForWorkerInfo { - /// The spec of task. - TaskSpecification task_spec; - /// The callback function which should be called when worker registered. - PopWorkerCallback callback; - }; - - /// Represents a PopWorker call. struct PopWorkerRequest { - TaskSpecification task_spec; + rpc::Language language; + rpc::WorkerType worker_type; + JobID job_id; // can be Nil + ActorID root_detached_actor_id; // can be Nil + std::optional is_gpu; + std::optional is_actor_worker; + rpc::RuntimeEnvInfo runtime_env_info; + int runtime_env_hash; + std::vector dynamic_options; + PopWorkerCallback callback; + + PopWorkerRequest(rpc::Language lang, + rpc::WorkerType worker_type, + JobID job, + ActorID root_actor_id, + std::optional gpu, + std::optional actor_worker, + rpc::RuntimeEnvInfo runtime_env_info, + int runtime_hash, + std::vector options, + PopWorkerCallback callback) + : language(lang), + worker_type(worker_type), + job_id(job), + root_detached_actor_id(root_actor_id), + is_gpu(gpu), + is_actor_worker(actor_worker), + runtime_env_info(std::move(runtime_env_info)), + runtime_env_hash(runtime_hash), + dynamic_options(std::move(options)), + callback(std::move(callback)) {} }; + // Starts a new worker that fulfills `pop_worker_request`. + void StartNewWorker(const std::shared_ptr &pop_worker_request); + /// An internal data structure that maintains the pool state per language. struct State { /// The commands and arguments used to start the worker process @@ -538,10 +562,13 @@ class WorkerPool : public WorkerPoolInterface, public IOWorkerPoolInterface { /// the extra information of the process. Note that the shim process PID is the /// same with worker process PID, except worker process in container. absl::flat_hash_map worker_processes; - /// A map for looking up the task by the startup token of starting worker process. - absl::flat_hash_map starting_workers_to_tasks; - /// Pop worker requests that are pending due to maximum_startup_concurrency_. - std::deque pending_pop_worker_requests; + /// FIFO queue of pending requests with workers STARTED but pending registration. + /// If a request stays in this status for >= worker_register_timeout_seconds, we'll + /// fail the request and kill the worker process. + std::deque> pending_registration_requests; + /// FIFO queue of pending requests with workers NOT STARTED due to + /// maximum_startup_concurrency_. + std::deque> pending_start_requests; /// We'll push a warning to the user every time a multiple of this many /// worker processes has been started. int multiple_for_warning; @@ -569,10 +596,17 @@ class WorkerPool : public WorkerPoolInterface, public IOWorkerPoolInterface { /// (due to worker process crash or any other reasons), remove them /// from `worker_processes`. Otherwise if we'll mistakenly /// think there are unregistered workers, and won't start new workers. - void MonitorStartingWorkerProcess(const Process &proc, - StartupToken proc_startup_token, + void MonitorStartingWorkerProcess(StartupToken proc_startup_token, const Language &language, - const rpc::WorkerType worker_type); + rpc::WorkerType worker_type); + + /// Start a timer to monitor the starting worker process. + /// Called when a worker process is started and waiting for registration for the + /// request. If the registration is not finished within the timeout, we'll failed the + /// request. Note we don't do anything to the worker process itself, as it's timed out + /// by MonitorStartingWorkerProcess. + void MonitorPopWorkerRequestForRegistration( + std::shared_ptr pop_worker_request); /// Get the next unallocated port in the free ports list. If a port range isn't /// configured, returns 0. @@ -598,10 +632,10 @@ class WorkerPool : public WorkerPoolInterface, public IOWorkerPoolInterface { /// worker. void TryStartIOWorkers(const Language &language, const rpc::WorkerType &worker_type); - /// Try to fulfill pending PopWorker requests. + /// Try to fulfill pending_start_requests by trying to start more workers. /// This happens when we have more room to start workers or an idle worker is pushed. /// \param language The language of the PopWorker requests. - void TryPendingPopWorkerRequests(const Language &language); + void TryPendingStartRequests(const Language &language); /// Get either restore or spill worker state from state based on worker_type. /// @@ -626,29 +660,9 @@ class WorkerPool : public WorkerPoolInterface, public IOWorkerPoolInterface { /// Call the `PopWorkerCallback` function asynchronously to make sure executed in /// different stack. - virtual void PopWorkerCallbackAsync(const TaskSpecification &task_spec, - const PopWorkerCallback &callback, + virtual void PopWorkerCallbackAsync(PopWorkerCallback callback, std::shared_ptr worker, - PopWorkerStatus status = PopWorkerStatus::OK); - - /// Try to find a task that is associated with the given worker process from the given - /// queue. If found, invoke its PopWorkerCallback. - /// \param workers_to_tasks The queue of tasks which waiting for workers. - /// \param startup_token The startup token representing the worker. - /// \param worker A new idle worker. If the worker is empty, we could also callback - /// to the task. - /// \param status The pop worker status which will be forwarded to - /// `PopWorkerCallback`. - /// \param found Whether the related task found or not. - /// \param worker_used Whether the worker is used by the task, only valid when found is - /// true. - void InvokePopWorkerCallbackForProcess( - absl::flat_hash_map &workers_to_tasks, - StartupToken startup_token, - const std::shared_ptr &worker, - const PopWorkerStatus &status, - bool *found /* output */, - bool *worker_used /* output */); + PopWorkerStatus status); /// We manage all runtime env resources locally by the two methods: /// `GetOrCreateRuntimeEnv` and `DeleteRuntimeEnvIfPossible`. @@ -706,18 +720,18 @@ class WorkerPool : public WorkerPoolInterface, public IOWorkerPoolInterface { std::pair, ProcessEnvironment> BuildProcessCommandArgs( const Language &language, rpc::JobConfig *job_config, - const rpc::WorkerType worker_type, + rpc::WorkerType worker_type, const JobID &job_id, const std::vector &dynamic_options, - const int runtime_env_hash, + int runtime_env_hash, const std::string &serialized_runtime_env_context, const WorkerPool::State &state) const; void ExecuteOnPrestartWorkersStarted(std::function callback); // If this worker can serve the task. - WorkerUnfitForTaskReason WorkerFitsForTask(const WorkerInterface &worker, - const TaskSpecification &task_spec) const; + WorkerUnfitForTaskReason WorkerFitsForTask( + const WorkerInterface &worker, const PopWorkerRequest &pop_worker_request) const; /// For Process class for managing subprocesses (e.g. reaping zombies). instrumented_io_context *io_service_; diff --git a/src/ray/raylet/worker_pool_test.cc b/src/ray/raylet/worker_pool_test.cc index 8cfd03201e34c..d945384b72774 100644 --- a/src/ray/raylet/worker_pool_test.cc +++ b/src/ray/raylet/worker_pool_test.cc @@ -44,7 +44,7 @@ std::vector LANGUAGES = {Language::PYTHON, Language::JAVA}; class MockWorkerClient : public rpc::CoreWorkerClientInterface { public: - MockWorkerClient() {} + MockWorkerClient() = default; void Exit(const rpc::ExitRequest &request, const rpc::ClientCallback &callback) { @@ -84,7 +84,7 @@ class MockWorkerClient : public rpc::CoreWorkerClientInterface { static std::unordered_map runtime_env_reference; -static int GetReferenceCount(const std::string serialized_runtime_env) { +static int GetReferenceCount(const std::string &serialized_runtime_env) { auto it = runtime_env_reference.find(serialized_runtime_env); return it == runtime_env_reference.end() ? 0 : it->second; } @@ -105,7 +105,7 @@ class MockRuntimeEnvAgentClient : public RuntimeEnvAgentClient { } else { runtime_env_reference[serialized_runtime_env] += 1; } - callback(true, "{\"dummy\":\"dummy\"}", ""); + callback(true, R"({"dummy":"dummy"})", ""); } }; @@ -159,17 +159,16 @@ class WorkerPoolMock : public WorkerPool { using WorkerPool::PopWorkerCallbackInternal; // Mock `PopWorkerCallbackAsync` to synchronized function. - void PopWorkerCallbackAsync(const TaskSpecification &task_spec, - const PopWorkerCallback &callback, + void PopWorkerCallbackAsync(PopWorkerCallback callback, std::shared_ptr worker, PopWorkerStatus status = PopWorkerStatus::OK) override { - PopWorkerCallbackInternal(task_spec, callback, worker, status); + PopWorkerCallbackInternal(callback, worker, status); } Process StartProcess(const std::vector &worker_command_args, const ProcessEnvironment &env) override { // Use a bogus process ID that won't conflict with those in the system - pid_t pid = static_cast(PID_MAX_LIMIT + 1 + worker_commands_by_proc_.size()); + auto pid = static_cast(PID_MAX_LIMIT + 1 + worker_commands_by_proc_.size()); last_worker_process_ = Process::FromPid(pid); worker_commands_by_proc_[last_worker_process_] = worker_command_args; startup_tokens_by_proc_[last_worker_process_] = @@ -195,10 +194,18 @@ class WorkerPoolMock : public WorkerPool { return total; } - int NumPendingPopWorkerRequests() const { + int NumPendingStartRequests() const { int total = 0; for (auto &entry : states_by_lang_) { - total += entry.second.pending_pop_worker_requests.size(); + total += entry.second.pending_start_requests.size(); + } + return total; + } + + int NumPendingRegistrationRequests() const { + int total = 0; + for (auto &entry : states_by_lang_) { + total += entry.second.pending_registration_requests.size(); } return total; } @@ -297,7 +304,7 @@ class WorkerPoolMock : public WorkerPool { // Create workers for processes and push them to worker pool. // \param[in] timeout_worker_number Don't register some workers to simulate worker // registration timeout. - void PushWorkers(int timeout_worker_number = 0) { + void PushWorkers(int timeout_worker_number, JobID job_id) { auto processes = GetProcesses(); for (auto it = processes.begin(); it != processes.end(); ++it) { auto pushed_it = pushedProcesses_.find(it->first); @@ -326,7 +333,7 @@ class WorkerPoolMock : public WorkerPool { auto worker = CreateWorker( it->first, is_java ? Language::JAVA : Language::PYTHON, - JOB_ID, + job_id, rpc::WorkerType::WORKER, runtime_env_hash, startup_tokens_by_proc_[it->first], @@ -372,7 +379,7 @@ class WorkerPoolMock : public WorkerPool { return true; }); if (push_workers) { - PushWorkers(timeout_worker_number); + PushWorkers(timeout_worker_number, task_spec.JobId()); } promise.get_future().get(); return popped_worker; @@ -839,7 +846,8 @@ TEST_F(WorkerPoolDriverRegisteredTest, MaximumStartupConcurrency) { started_processes.push_back(last_process); } ASSERT_EQ(MAXIMUM_STARTUP_CONCURRENCY, worker_pool_->NumWorkersStarting()); - ASSERT_EQ(0, worker_pool_->NumPendingPopWorkerRequests()); + ASSERT_EQ(0, worker_pool_->NumPendingStartRequests()); + ASSERT_EQ(MAXIMUM_STARTUP_CONCURRENCY, worker_pool_->NumPendingRegistrationRequests()); // Can't start a new worker process at this point. worker_pool_->PopWorker( @@ -848,7 +856,8 @@ TEST_F(WorkerPoolDriverRegisteredTest, MaximumStartupConcurrency) { PopWorkerStatus status, const std::string &runtime_env_setup_error_message) -> bool { return true; }); ASSERT_EQ(MAXIMUM_STARTUP_CONCURRENCY, worker_pool_->NumWorkersStarting()); - ASSERT_EQ(1, worker_pool_->NumPendingPopWorkerRequests()); + ASSERT_EQ(1, worker_pool_->NumPendingStartRequests()); + ASSERT_EQ(MAXIMUM_STARTUP_CONCURRENCY, worker_pool_->NumPendingRegistrationRequests()); std::vector> workers; // Call `RegisterWorker` to emulate worker registration. @@ -860,7 +869,10 @@ TEST_F(WorkerPoolDriverRegisteredTest, MaximumStartupConcurrency) { })); // Calling `RegisterWorker` won't affect the counter of starting worker processes. ASSERT_EQ(MAXIMUM_STARTUP_CONCURRENCY, worker_pool_->NumWorkersStarting()); - ASSERT_EQ(1, worker_pool_->NumPendingPopWorkerRequests()); + ASSERT_EQ(1, worker_pool_->NumPendingStartRequests()); + ASSERT_EQ(MAXIMUM_STARTUP_CONCURRENCY, + worker_pool_->NumPendingRegistrationRequests()); + workers.push_back(worker); } @@ -872,7 +884,8 @@ TEST_F(WorkerPoolDriverRegisteredTest, MaximumStartupConcurrency) { PopWorkerStatus status, const std::string &runtime_env_setup_error_message) -> bool { return true; }); ASSERT_EQ(MAXIMUM_STARTUP_CONCURRENCY, worker_pool_->NumWorkersStarting()); - ASSERT_EQ(2, worker_pool_->NumPendingPopWorkerRequests()); + ASSERT_EQ(2, worker_pool_->NumPendingStartRequests()); + ASSERT_EQ(MAXIMUM_STARTUP_CONCURRENCY, worker_pool_->NumPendingRegistrationRequests()); // Call `OnWorkerStarted` to emulate worker port announcement. worker_pool_->OnWorkerStarted(workers[0]); @@ -881,7 +894,8 @@ TEST_F(WorkerPoolDriverRegisteredTest, MaximumStartupConcurrency) { // One pending pop worker request now can be fulfilled. ASSERT_EQ(MAXIMUM_STARTUP_CONCURRENCY, worker_pool_->NumWorkersStarting()); ASSERT_EQ(MAXIMUM_STARTUP_CONCURRENCY + 1, worker_pool_->GetProcessSize()); - ASSERT_EQ(1, worker_pool_->NumPendingPopWorkerRequests()); + ASSERT_EQ(1, worker_pool_->NumPendingStartRequests()); + ASSERT_EQ(MAXIMUM_STARTUP_CONCURRENCY, worker_pool_->NumPendingRegistrationRequests()); // Can't start a new worker process at this point. worker_pool_->PopWorker( @@ -891,22 +905,28 @@ TEST_F(WorkerPoolDriverRegisteredTest, MaximumStartupConcurrency) { const std::string &runtime_env_setup_error_message) -> bool { return true; }); ASSERT_EQ(MAXIMUM_STARTUP_CONCURRENCY, worker_pool_->NumWorkersStarting()); ASSERT_EQ(MAXIMUM_STARTUP_CONCURRENCY + 1, worker_pool_->GetProcessSize()); - ASSERT_EQ(2, worker_pool_->NumPendingPopWorkerRequests()); + ASSERT_EQ(2, worker_pool_->NumPendingStartRequests()); + ASSERT_EQ(MAXIMUM_STARTUP_CONCURRENCY, worker_pool_->NumPendingRegistrationRequests()); // Return a worker. worker_pool_->PushWorker(workers[0]); - // One more pending pop worker request can be fulfilled. + // The pushed worker fulfills a pending registration request, not a pending start + // request. ASSERT_EQ(MAXIMUM_STARTUP_CONCURRENCY, worker_pool_->NumWorkersStarting()); ASSERT_EQ(MAXIMUM_STARTUP_CONCURRENCY + 1, worker_pool_->GetProcessSize()); - ASSERT_EQ(1, worker_pool_->NumPendingPopWorkerRequests()); + ASSERT_EQ(2, worker_pool_->NumPendingStartRequests()); + ASSERT_EQ(MAXIMUM_STARTUP_CONCURRENCY - 1, + worker_pool_->NumPendingRegistrationRequests()); + ASSERT_EQ(0, worker_pool_->GetIdleWorkerSize()); // Disconnect a worker. worker_pool_->DisconnectWorker(workers[1], rpc::WorkerExitType::SYSTEM_ERROR); - // One more pending pop worker request can be fulfilled. + // We have 1 more slot to start a new worker process. ASSERT_EQ(MAXIMUM_STARTUP_CONCURRENCY, worker_pool_->NumWorkersStarting()); ASSERT_EQ(MAXIMUM_STARTUP_CONCURRENCY + 2, worker_pool_->GetProcessSize()); - ASSERT_EQ(0, worker_pool_->NumPendingPopWorkerRequests()); + ASSERT_EQ(1, worker_pool_->NumPendingStartRequests()); + ASSERT_EQ(MAXIMUM_STARTUP_CONCURRENCY, worker_pool_->NumPendingRegistrationRequests()); ASSERT_EQ(0, worker_pool_->GetIdleWorkerSize()); worker_pool_->ClearProcesses(); @@ -1189,6 +1209,11 @@ TEST_F(WorkerPoolDriverRegisteredTest, TestWorkerCapping) { auto task_spec = ExampleTaskSpec(/*actor_id=*/ActorID::Nil(), Language::PYTHON, job_id); auto worker = worker_pool_->PopWorkerSync(task_spec, false); + // Simulate running the task and finish. This is to set task_assign_time_. + RayTask task(task_spec); + worker->SetAssignedTask(task); + worker->AssignTaskId(TaskID::Nil()); + popped_workers.push_back(worker); ASSERT_TRUE(worker); ASSERT_EQ(worker->GetAssignedJobId(), job_id); @@ -1217,7 +1242,8 @@ TEST_F(WorkerPoolDriverRegisteredTest, TestWorkerCapping) { // Since the idle workers are killed in FIFO, we can assume the first entry in the idle // workers will be killed. auto mock_rpc_client_it = mock_worker_rpc_clients_.find(popped_workers[0]->WorkerId()); - ASSERT_EQ(mock_rpc_client_it->second->exit_count, 1); + ASSERT_EQ(mock_rpc_client_it->second->exit_count, 1) + << " expected pid " << popped_workers[0]->GetProcess().GetId(); ASSERT_EQ(mock_rpc_client_it->second->last_exit_forced, false); mock_rpc_client_it->second->ExitReplySucceed(); worker_pool_->TryKillingIdleWorkers(); @@ -1420,19 +1446,22 @@ TEST_F(WorkerPoolDriverRegisteredTest, TestJobFinishedForPopWorker) { task_spec = ExampleTaskSpec(/*actor_id=*/ActorID::Nil(), Language::PYTHON, job_id); pop_worker_status = PopWorkerStatus::OK; // This will start a new worker. + std::promise promise; worker_pool_->PopWorker( task_spec, [&](const std::shared_ptr worker, PopWorkerStatus status, const std::string &runtime_env_setup_error_message) -> bool { pop_worker_status = status; + promise.set_value(true); return false; }); auto process = worker_pool_->LastStartedWorkerProcess(); RAY_CHECK(process.IsValid()); ASSERT_EQ(1, worker_pool_->NumWorkersStarting()); - worker = worker_pool_->CreateWorker(Process()); + // Starts a worker for JOB_ID2. + worker = worker_pool_->CreateWorker(Process(), Language::PYTHON, job_id); worker->SetStartupToken(worker_pool_->GetStartupToken(process)); RAY_CHECK_OK(worker_pool_->RegisterWorker( worker, process.GetId(), worker_pool_->GetStartupToken(process), [](Status, int) { @@ -1446,8 +1475,10 @@ TEST_F(WorkerPoolDriverRegisteredTest, TestJobFinishedForPopWorker) { // Finish the job. worker_pool_->HandleJobFinished(job_id); - // This will trigger the PopWorker callback. + // This will trigger the PopWorker callback in async. worker_pool_->PushWorker(worker); + promise.get_future().get(); + ASSERT_EQ(pop_worker_status, PopWorkerStatus::JobFinished); ASSERT_EQ(worker_pool_->GetIdleWorkerSize(), 1); @@ -1807,7 +1838,7 @@ TEST_F(WorkerPoolDriverRegisteredTest, WorkerNoLeaks) { // No idle workers because no workers pushed. ASSERT_EQ(worker_pool_->GetIdleWorkerSize(), 0); // push workers. - worker_pool_->PushWorkers(); + worker_pool_->PushWorkers(0, task_spec.JobId()); // The worker has been pushed but not dispatched. ASSERT_EQ(worker_pool_->GetIdleWorkerSize(), 1); // Pop a worker and don't dispatch. @@ -2012,7 +2043,7 @@ TEST_F(WorkerPoolDriverRegisteredTest, WorkerReuseForPrestartedWorker) { const auto task_spec = ExampleTaskSpec(); worker_pool_->PrestartDefaultCpuWorkers(ray::Language::PYTHON, 1); - worker_pool_->PushWorkers(); + worker_pool_->PushWorkers(0, task_spec.JobId()); // One worker process has been prestarted. ASSERT_EQ(worker_pool_->GetProcessSize(), 1); ASSERT_EQ(worker_pool_->GetIdleWorkerSize(), 1); diff --git a/src/ray/rpc/grpc_client.h b/src/ray/rpc/grpc_client.h index 9037613968ceb..d86475fbdf6b5 100644 --- a/src/ray/rpc/grpc_client.h +++ b/src/ray/rpc/grpc_client.h @@ -23,6 +23,7 @@ #include "ray/common/status.h" #include "ray/rpc/client_call.h" #include "ray/rpc/common.h" +#include "ray/rpc/rpc_chaos.h" namespace ray { namespace rpc { @@ -148,15 +149,43 @@ class GrpcClient { const ClientCallback &callback, std::string call_name = "UNKNOWN_RPC", int64_t method_timeout_ms = -1) { - auto call = client_call_manager_.CreateCall( - *stub_, - prepare_async_function, - request, - callback, - std::move(call_name), - method_timeout_ms); - RAY_CHECK(call != nullptr); - call_method_invoked_ = true; + testing::RpcFailure failure = testing::get_rpc_failure(call_name); + if (failure == testing::RpcFailure::Request) { + // Simulate the case where the PRC fails before server receives + // the request. + RAY_LOG(INFO) << "Inject RPC request failure for " << call_name; + client_call_manager_.GetMainService().post( + [callback]() { + callback(Status::RpcError("Unavailable", grpc::StatusCode::UNAVAILABLE), + Reply()); + }, + "RpcChaos"); + } else if (failure == testing::RpcFailure::Response) { + // Simulate the case where the RPC fails after server sends + // the response. + RAY_LOG(INFO) << "Inject RPC response failure for " << call_name; + client_call_manager_.CreateCall( + *stub_, + prepare_async_function, + request, + [callback](const Status &status, Reply &&reply) { + callback(Status::RpcError("Unavailable", grpc::StatusCode::UNAVAILABLE), + Reply()); + }, + std::move(call_name), + method_timeout_ms); + } else { + auto call = client_call_manager_.CreateCall( + *stub_, + prepare_async_function, + request, + callback, + std::move(call_name), + method_timeout_ms); + RAY_CHECK(call != nullptr); + } + + call_method_invoked_.store(true); } std::shared_ptr Channel() const { return channel_; } @@ -167,7 +196,8 @@ class GrpcClient { /// Also see https://grpc.github.io/grpc/core/md_doc_connectivity-semantics-and-api.html /// for channel connectivity state machine. bool IsChannelIdleAfterRPCs() const { - return (channel_->GetState(false) == GRPC_CHANNEL_IDLE) && call_method_invoked_; + return (channel_->GetState(false) == GRPC_CHANNEL_IDLE) && + call_method_invoked_.load(); } private: @@ -179,7 +209,7 @@ class GrpcClient { /// The channel of the stub. std::shared_ptr channel_; /// Whether CallMethod is invoked. - bool call_method_invoked_ = false; + std::atomic call_method_invoked_ = false; }; } // namespace rpc diff --git a/src/ray/rpc/rpc_chaos.cc b/src/ray/rpc/rpc_chaos.cc new file mode 100644 index 0000000000000..373e3a9be60f0 --- /dev/null +++ b/src/ray/rpc/rpc_chaos.cc @@ -0,0 +1,109 @@ +// Copyright 2024 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ray/rpc/rpc_chaos.h" + +#include +#include + +#include "absl/synchronization/mutex.h" +#include "ray/common/ray_config.h" + +namespace ray { +namespace rpc { +namespace testing { +namespace { + +/* + RpcFailureManager is a simple chaos testing framework. Before starting ray, users + should set up os environment to use this feature for testing purposes. + To use this, simply do + export RAY_testing_rpc_failure="method1=3,method2=5" + Key is the RPC call name and value is the max number of failures to inject. +*/ +class RpcFailureManager { + public: + RpcFailureManager() { Init(); } + + void Init() { + absl::MutexLock lock(&mu_); + + failable_methods_.clear(); + + if (!RayConfig::instance().testing_rpc_failure().empty()) { + for (const auto &item : + absl::StrSplit(RayConfig::instance().testing_rpc_failure(), ",")) { + std::vector parts = absl::StrSplit(item, "="); + RAY_CHECK_EQ(parts.size(), 2UL); + failable_methods_.emplace(parts[0], std::atoi(parts[1].c_str())); + } + + std::random_device rd; + auto seed = rd(); + RAY_LOG(INFO) << "Setting RpcFailureManager seed to " << seed; + gen_.seed(seed); + } + } + + RpcFailure GetRpcFailure(const std::string &name) { + absl::MutexLock lock(&mu_); + + if (failable_methods_.find(name) == failable_methods_.end()) { + return RpcFailure::None; + } + + uint64_t &num_remaining_failures = failable_methods_.at(name); + if (num_remaining_failures == 0) { + return RpcFailure::None; + } + + std::uniform_int_distribution dist(0, 3); + int rand = dist(gen_); + if (rand == 0) { + // 25% chance + num_remaining_failures--; + return RpcFailure::Request; + } else if (rand == 1) { + // 25% chance + num_remaining_failures--; + return RpcFailure::Response; + } else { + // 50% chance + return RpcFailure::None; + } + } + + private: + absl::Mutex mu_; + std::mt19937 gen_; + // call name -> # remaining failures + std::unordered_map failable_methods_ ABSL_GUARDED_BY(&mu_); +}; + +static RpcFailureManager _rpc_failure_manager; + +} // namespace + +RpcFailure get_rpc_failure(const std::string &name) { + if (RayConfig::instance().testing_rpc_failure().empty()) { + return RpcFailure::None; + } + return _rpc_failure_manager.GetRpcFailure(name); +} + +void init() { _rpc_failure_manager.Init(); } + +} // namespace testing +} // namespace rpc +} // namespace ray diff --git a/src/ray/rpc/rpc_chaos.h b/src/ray/rpc/rpc_chaos.h new file mode 100644 index 0000000000000..cb0e614eead9f --- /dev/null +++ b/src/ray/rpc/rpc_chaos.h @@ -0,0 +1,37 @@ +// Copyright 2024 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +namespace ray { +namespace rpc { +namespace testing { + +enum class RpcFailure { + None, + // Failure before server receives the request + Request, + // Failure after server sends the response + Response, +}; + +RpcFailure get_rpc_failure(const std::string &name); + +void init(); + +} // namespace testing +} // namespace rpc +} // namespace ray diff --git a/src/ray/rpc/test/rpc_chaos_test.cc b/src/ray/rpc/test/rpc_chaos_test.cc new file mode 100644 index 0000000000000..75bced2592537 --- /dev/null +++ b/src/ray/rpc/test/rpc_chaos_test.cc @@ -0,0 +1,34 @@ +// Copyright 2024 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ray/rpc/rpc_chaos.h" + +#include + +#include "gtest/gtest.h" +#include "ray/common/ray_config.h" + +TEST(RpcChaosTest, Basic) { + RayConfig::instance().testing_rpc_failure() = "method1=0,method2=1"; + ray::rpc::testing::init(); + ASSERT_EQ(ray::rpc::testing::get_rpc_failure("unknown"), + ray::rpc::testing::RpcFailure::None); + ASSERT_EQ(ray::rpc::testing::get_rpc_failure("method1"), + ray::rpc::testing::RpcFailure::None); + // At most one failure. + ASSERT_FALSE(ray::rpc::testing::get_rpc_failure("method2") != + ray::rpc::testing::RpcFailure::None && + ray::rpc::testing::get_rpc_failure("method2") != + ray::rpc::testing::RpcFailure::None); +} diff --git a/src/ray/stats/metric_defs.cc b/src/ray/stats/metric_defs.cc index 114e6c07434d5..5d393acdce8d5 100644 --- a/src/ray/stats/metric_defs.cc +++ b/src/ray/stats/metric_defs.cc @@ -67,6 +67,21 @@ DEFINE_stats(actors, (), ray::stats::GAUGE); +/// Job related stats. +DEFINE_stats(running_jobs, + "Number of jobs currently running.", + /*tags=*/(), + /*buckets=*/(), + ray::stats::GAUGE); + +DEFINE_stats(finished_jobs, + "Number of jobs finished.", + // TODO(hjiang): Consider adding task completion status, for example, failed, + // completed in tags. + /*tags=*/(), + /*buckets=*/(), + ray::stats::COUNT); + /// Logical resource usage reported by raylets. DEFINE_stats(resources, // TODO(sang): Support placement_group_reserved_available | used @@ -146,8 +161,8 @@ DEFINE_stats(operation_active_count, DEFINE_stats(grpc_server_req_process_time_ms, "Request latency in grpc server", ("Method"), - (), - ray::stats::GAUGE); + ({0.1, 1, 10, 100, 1000, 10000}, ), + ray::stats::HISTOGRAM); DEFINE_stats(grpc_server_req_new, "New request number in grpc server", ("Method"), diff --git a/src/ray/stats/metric_defs.h b/src/ray/stats/metric_defs.h index 44d77b8171594..d76c64e7f42f0 100644 --- a/src/ray/stats/metric_defs.h +++ b/src/ray/stats/metric_defs.h @@ -48,6 +48,10 @@ DECLARE_stats(tasks); /// Actor stats, broken down by state. DECLARE_stats(actors); +/// Job stats. +DECLARE_stats(running_jobs); +DECLARE_stats(finished_jobs); + /// Placement group stats, broken down by state. DECLARE_stats(placement_groups); diff --git a/src/ray/util/BUILD b/src/ray/util/BUILD index b17521a4cdd68..bd35c3874218f 100644 --- a/src/ray/util/BUILD +++ b/src/ray/util/BUILD @@ -41,3 +41,10 @@ cc_library( "@nlohmann_json", ], ) + +cc_library( + name = "thread_checker", + hdrs = ["thread_checker.h"], + srcs = ["thread_checker.cc"], + visibility = ["//visibility:public"], +) diff --git a/src/ray/util/container_util.h b/src/ray/util/container_util.h index 1686c1459e250..d46da4d580d63 100644 --- a/src/ray/util/container_util.h +++ b/src/ray/util/container_util.h @@ -14,17 +14,20 @@ #pragma once +#include #include #include #include #include #include #include +#include #include #include #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "absl/container/inlined_vector.h" #include "ray/util/logging.h" namespace ray { @@ -67,6 +70,13 @@ std::ostream &operator<<(std::ostream &os, DebugStringWrapper wrapper) { return os << wrapper.obj_; } +// TODO(hjiang): Implement debug string for `std::variant`. +template <> +inline std::ostream &operator<<(std::ostream &os, + DebugStringWrapper wrapper) { + return os << "(nullopt)"; +} + template std::ostream &operator<<(std::ostream &os, DebugStringWrapper> pair) { return os << "(" << debug_string(pair.obj_.first) << ", " @@ -93,6 +103,10 @@ std::ostream &operator<<(std::ostream &os, DebugStringWrapper> return os; } +template +std::ostream &operator<<(std::ostream &os, DebugStringWrapper> c) { + return c.StringifyContainer(os); +} template std::ostream &operator<<(std::ostream &os, DebugStringWrapper> c) { return c.StringifyContainer(os); @@ -120,6 +134,19 @@ std::ostream &operator<<(std::ostream &os, DebugStringWrapper> c) { return c.StringifyContainer(os); } +template +std::ostream &operator<<(std::ostream &os, + DebugStringWrapper> c) { + return c.StringifyContainer(os); +} + +template +std::ostream &operator<<(std::ostream &os, DebugStringWrapper> c) { + if (!c.obj_.has_value()) { + return os << debug_string(std::nullopt); + } + return os << debug_string(c.obj_.value()); +} template const typename C::mapped_type &map_find_or_die(const C &c, diff --git a/src/ray/util/counter_map.h b/src/ray/util/counter_map.h index 6c7acbe046b8f..acd442b0e2c71 100644 --- a/src/ray/util/counter_map.h +++ b/src/ray/util/counter_map.h @@ -15,6 +15,7 @@ #pragma once #include +#include #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" @@ -35,7 +36,7 @@ template class CounterMap { public: - CounterMap(){}; + CounterMap() = default; CounterMap(const CounterMap &other) = delete; @@ -45,7 +46,7 @@ class CounterMap { /// Changes are buffered until `FlushOnChangeCallbacks()` is called to enable /// batching for performance reasons. void SetOnChangeCallback(std::function on_change) { - on_change_ = on_change; + on_change_ = std::move(on_change); } /// Flush any pending on change callbacks. diff --git a/src/ray/util/tests/BUILD b/src/ray/util/tests/BUILD index 29cf17f706f7d..096b090f18542 100644 --- a/src/ray/util/tests/BUILD +++ b/src/ray/util/tests/BUILD @@ -1,12 +1,25 @@ load("@rules_cc//cc:defs.bzl", "cc_test") load("//bazel:ray.bzl", "COPTS") +cc_test( + name = "thread_checker_test", + srcs = ["thread_checker_test.cc"], + copts = COPTS, + size = "small", + tags = ["team:core"], + deps = [ + "//src/ray/util:thread_checker", + "@com_google_googletest//:gtest_main", + ], +) + cc_test( name = "container_util_test", size = "small", srcs = ["container_util_test.cc"], copts = COPTS, tags = ["team:core"], + linkstatic = True, deps = [ "//src/ray/util", "@com_google_absl//absl/container:flat_hash_map", diff --git a/src/ray/util/tests/container_util_test.cc b/src/ray/util/tests/container_util_test.cc index 0a34e91647a7b..590246fd2afd0 100644 --- a/src/ray/util/tests/container_util_test.cc +++ b/src/ray/util/tests/container_util_test.cc @@ -14,11 +14,13 @@ #include "ray/util/container_util.h" +#include + +#include #include +#include #include -#include "gtest/gtest.h" - namespace ray { template @@ -29,8 +31,18 @@ std::string debug_string_to_string(const T &t) { } TEST(ContainerUtilTest, TestDebugString) { + // Numerical values. ASSERT_EQ(debug_string_to_string(static_cast(2)), "2"); + + // String values. + ASSERT_EQ(debug_string_to_string(std::string_view{"hello"}), "hello"); + ASSERT_EQ(debug_string_to_string(std::string{"hello"}), "hello"); + + // Non-associative containers. ASSERT_EQ(debug_string_to_string(std::vector{1, 2}), "[1, 2]"); + ASSERT_EQ(debug_string_to_string(std::array{1, 2, 3}), "[1, 2, 3]"); + + // Associative containers. ASSERT_EQ(debug_string_to_string(std::set{1, 2}), "[1, 2]"); ASSERT_EQ(debug_string_to_string(std::unordered_set{2}), "[2]"); ASSERT_EQ(debug_string_to_string(absl::flat_hash_set{1}), "[1]"); @@ -54,6 +66,11 @@ TEST(ContainerUtilTest, TestDebugString) { ASSERT_EQ(debug_string_to_string(std::pair{3, "value"}), "(3, value)"); + // Optional. + ASSERT_EQ(debug_string_to_string(std::nullopt), "(nullopt)"); + ASSERT_EQ(debug_string_to_string(std::optional{}), "(nullopt)"); + ASSERT_EQ(debug_string_to_string(std::optional{"hello"}), "hello"); + // Composable: tuples of pairs of maps and vectors. ASSERT_EQ(debug_string_to_string( std::tuple>, std::map>{ diff --git a/src/ray/util/tests/thread_checker_test.cc b/src/ray/util/tests/thread_checker_test.cc new file mode 100644 index 0000000000000..08bd8588ee7fa --- /dev/null +++ b/src/ray/util/tests/thread_checker_test.cc @@ -0,0 +1,34 @@ +// Copyright 2024 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/ray/util/thread_checker.h" + +#include + +#include + +namespace ray { + +TEST(ThreadCheckerTest, BasicTest) { + ThreadChecker thread_checker; + // Pass at initialization. + ASSERT_TRUE(thread_checker.IsOnSameThread()); + // Pass when invoked at the same thread. + ASSERT_TRUE(thread_checker.IsOnSameThread()); + + auto thd = std::thread([&]() { ASSERT_FALSE(thread_checker.IsOnSameThread()); }); + thd.join(); +} + +} // namespace ray diff --git a/src/ray/util/thread_checker.cc b/src/ray/util/thread_checker.cc new file mode 100644 index 0000000000000..73a0072c75752 --- /dev/null +++ b/src/ray/util/thread_checker.cc @@ -0,0 +1,26 @@ +// Copyright 2024 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/ray/util/thread_checker.h" + +namespace ray { + +bool ThreadChecker::IsOnSameThread() { + const auto cur_id = std::this_thread::get_id(); + std::thread::id uninitialized_id; + return thread_id_.compare_exchange_strong(uninitialized_id, cur_id) || + (uninitialized_id == cur_id); +} + +} // namespace ray diff --git a/src/ray/util/thread_checker.h b/src/ray/util/thread_checker.h new file mode 100644 index 0000000000000..622624859b753 --- /dev/null +++ b/src/ray/util/thread_checker.h @@ -0,0 +1,43 @@ +// Copyright 2024 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Used to sanity check threading issues by checking current thread id. +// +// Example usage: +// ThreadChecker thread_checker{}; +// +// // Initialize on the thread at first usage. +// RAY_CHECK(thread_checker.ok()); +// +// // Check it's on the same thread. +// RAY_CHECK(thread_checker.ok()); + +#pragma once + +#include +#include + +namespace ray { + +class ThreadChecker { + public: + // Return true at initialization, or current invocation happens on the same thread as + // initialization. + bool IsOnSameThread(); + + private: + std::atomic thread_id_{}; +}; + +} // namespace ray