From adbc1141ae5bc45390bc49a96fee9638f2e33326 Mon Sep 17 00:00:00 2001 From: Chris Rawles Date: Wed, 15 May 2024 14:19:42 -0700 Subject: [PATCH] Add an a11y wrapper, which allows a11y tree retrieval. PiperOrigin-RevId: 634073719 --- android_env/components/a11y/__init__.py | 15 + android_env/components/a11y/a11y_events.py | 118 +++ .../components/a11y/a11y_events_test.py | 173 ++++ android_env/components/a11y/a11y_forests.py | 128 +++ .../components/a11y/a11y_forests_test.py | 237 +++++ android_env/components/a11y/a11y_servicer.py | 199 ++++ .../components/a11y/a11y_servicer_test.py | 224 +++++ android_env/wrappers/a11y_grpc_wrapper.py | 505 +++++++++++ .../wrappers/a11y_grpc_wrapper_test.py | 851 ++++++++++++++++++ 9 files changed, 2450 insertions(+) create mode 100644 android_env/components/a11y/__init__.py create mode 100644 android_env/components/a11y/a11y_events.py create mode 100644 android_env/components/a11y/a11y_events_test.py create mode 100644 android_env/components/a11y/a11y_forests.py create mode 100644 android_env/components/a11y/a11y_forests_test.py create mode 100644 android_env/components/a11y/a11y_servicer.py create mode 100644 android_env/components/a11y/a11y_servicer_test.py create mode 100644 android_env/wrappers/a11y_grpc_wrapper.py create mode 100644 android_env/wrappers/a11y_grpc_wrapper_test.py diff --git a/android_env/components/a11y/__init__.py b/android_env/components/a11y/__init__.py new file mode 100644 index 0000000..2f66bf7 --- /dev/null +++ b/android_env/components/a11y/__init__.py @@ -0,0 +1,15 @@ +# coding=utf-8 +# Copyright 2024 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/android_env/components/a11y/a11y_events.py b/android_env/components/a11y/a11y_events.py new file mode 100644 index 0000000..2df3e85 --- /dev/null +++ b/android_env/components/a11y/a11y_events.py @@ -0,0 +1,118 @@ +# coding=utf-8 +# Copyright 2024 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tools for accessing accessibility events.""" + +from collections.abc import Mapping +from typing import Any + +from absl import logging +from android_env.proto.a11y import a11y_pb2 +import numpy as np + +from google.protobuf import any_pb2 + + +_A11Y_EVENT_KEY = 'full_event' + + +def package_events_to_task_extras( + events: list[a11y_pb2.EventRequest], +) -> Mapping[str, np.ndarray]: + if not events: + return {} + events = np.stack(events, axis=0) + return {_A11Y_EVENT_KEY: events} + + +def extract_events_from_task_extras( + task_extras: Mapping[str, Any] | None = None, +) -> list[Mapping[str, str]]: + """Inspects task_extras and extracts all accessibility events detected. + + Args: + task_extras: Task extras forwarded by AndroidEnv. If 'full_event' is not a + key in task_extras, then this function returns an empty string. Otherwise, + full_event is expected to be list to be a numpy array with one dimension, + and contains a list of dictionary describing accessibility events that are + present in the given task extras. e.g. 'event_type: + TYPE_WINDOW_CONTENT_CHANGED // event_package_name: + com.google.android.deskclock // source_class_name: + android.widget.ImageView'. + + Returns: + List of all events detected + """ + if task_extras is None or _A11Y_EVENT_KEY not in task_extras: + return [] + + if ( + not isinstance(task_extras[_A11Y_EVENT_KEY], np.ndarray) + or task_extras[_A11Y_EVENT_KEY].ndim != 1 + ): + raise ValueError( + f'{_A11Y_EVENT_KEY} task extra should be a numpy array with one' + ' dimension.' + ) + + if task_extras[_A11Y_EVENT_KEY].size == 0: + return [] + + events = [] + for e in task_extras[_A11Y_EVENT_KEY]: + if isinstance(e, a11y_pb2.EventRequest): + events.append(dict(e.event)) + elif isinstance(e, dict): + events.append(e) + logging.warning( + 'The event should come only from the a11y_grpc_wrapper. ' + 'Please verify that the upacking operation has not been ' + 'called twice. See here for full task_extras: %s', + task_extras, + ) + elif isinstance(e, any_pb2.Any): + ev = a11y_pb2.EventRequest() + new_any = any_pb2.Any() + new_any.CopyFrom(e) + new_any.Unpack(ev) + events.append(dict(ev.event)) + + else: + raise TypeError( + f'Unexpected event type: {type(e)}. See here for full ' + f'task_extras: {task_extras}.' + ) + + return events + + +def keep_latest_event_only(task_extras: dict[str, Any]): + """Removes all a11y events except the last one observed.""" + if task_extras is None or 'full_event' not in task_extras: + return + + if ( + not isinstance(task_extras[_A11Y_EVENT_KEY], np.ndarray) + or task_extras[_A11Y_EVENT_KEY].ndim != 1 + ): + raise ValueError( + f'{_A11Y_EVENT_KEY} task extra should be a numpy array with one' + ' dimension.' + ) + + if task_extras[_A11Y_EVENT_KEY].size == 0: + return [] + + task_extras[_A11Y_EVENT_KEY] = task_extras[_A11Y_EVENT_KEY][-1:] diff --git a/android_env/components/a11y/a11y_events_test.py b/android_env/components/a11y/a11y_events_test.py new file mode 100644 index 0000000..3b0d036 --- /dev/null +++ b/android_env/components/a11y/a11y_events_test.py @@ -0,0 +1,173 @@ +# coding=utf-8 +# Copyright 2024 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for a11y_events.""" + +from absl.testing import absltest +from absl.testing import parameterized +from android_env.components.a11y import a11y_events +from android_env.proto.a11y import a11y_pb2 +import numpy as np + +from google.protobuf import any_pb2 + + +def _event_request(d: dict[str, str]) -> a11y_pb2.EventRequest: + event_request = a11y_pb2.EventRequest() + for k, v in d.items(): + event_request.event[k] = v + return event_request + + +def _event_request_as_any(d: dict[str, str]) -> any_pb2.Any: + event_request = _event_request(d) + response = any_pb2.Any() + response.Pack(event_request) + return response + + +class A11yEventsTest(parameterized.TestCase): + + @parameterized.parameters( + dict(task_extras={}), + dict( + task_extras={'no_full_event': [{'1': '1'}, {'2': '2'}, {'3': '3'}]}, + ), + dict( + task_extras={'full_event': np.array([])}, + ), + dict( + task_extras={}, + ), + ) + def test_no_events_in_task_extras(self, task_extras): + events = a11y_events.extract_events_from_task_extras(task_extras) + self.assertEmpty(events) + + @parameterized.parameters( + dict( + task_extras={'full_event': [{'1': '1'}, {'2': '2'}]}, + expected_events=[{'1': '1'}, {'2': '2'}], + ), + dict( + task_extras={'full_event': [{}]}, + expected_events=[{}], + ), + dict( + task_extras={ + 'full_event_wrong_key': [1, 2, 3], + 'full_event': [{'1': '1'}, {'2': '2'}, {'3': '3'}], + }, + expected_events=[{'1': '1'}, {'2': '2'}, {'3': '3'}], + ), + ) + def test_task_extras(self, task_extras, expected_events): + event_requests = [_event_request(e) for e in task_extras['full_event']] + task_extras['full_event'] = np.stack(event_requests, axis=0) + events = a11y_events.extract_events_from_task_extras(task_extras) + self.assertEqual(len(events), len(expected_events)) + for i, event in enumerate(expected_events): + self.assertEqual(len(event), len(expected_events[i])) + for k, v in event.items(): + self.assertIn(k, expected_events[i]) + self.assertEqual(v, expected_events[i][k]) + + def test_events_key_has_dict_event_requrests(self): + event_requests = [ + _event_request({'1': '1'}), + {'2': '2'}, + _event_request({'3': '3'}), + ] + expected_events = [ + {'1': '1'}, + {'2': '2'}, + {'3': '3'}, + ] + task_extras = {'full_event': np.stack(event_requests, axis=0)} + events = a11y_events.extract_events_from_task_extras(task_extras) + self.assertEqual(len(events), len(expected_events)) + for i, event in enumerate(expected_events): + self.assertEqual(len(event), len(expected_events[i])) + for k, v in event.items(): + self.assertIn(k, expected_events[i]) + self.assertEqual(v, expected_events[i][k]) + + def test_events_key_has__event_requrests_packed_as_any(self): + event_requests = [ + _event_request_as_any({'1': '1'}), + {'2': '2'}, + _event_request_as_any({'3': '3'}), + ] + expected_events = [ + {'1': '1'}, + {'2': '2'}, + {'3': '3'}, + ] + task_extras = {'full_event': np.stack(event_requests, axis=0)} + events = a11y_events.extract_events_from_task_extras(task_extras) + self.assertEqual(len(events), len(expected_events)) + for i, event in enumerate(expected_events): + self.assertEqual(len(event), len(expected_events[i])) + for k, v in event.items(): + self.assertIn(k, expected_events[i]) + self.assertEqual(v, expected_events[i][k]) + + def test_events_key_has_non_event_requrests(self): + event_requests = [ + _event_request({'1': '1'}), + 3, # Not an even and not a dict. + _event_request({'3': '3'}), + ] + task_extras = {'full_event': np.stack(event_requests, axis=0)} + with self.assertRaises(TypeError): + _ = a11y_events.extract_events_from_task_extras(task_extras) + + @parameterized.parameters( + dict(task_extras={}, expected_extras={}), + dict( + task_extras={ + 'no_full_event': 42, + }, + expected_extras={ + 'no_full_event': 42, + }, + ), + dict( + task_extras={'full_event': np.array([1, 2]), 'no_full_event': 43}, + expected_extras={'full_event': np.array([2]), 'no_full_event': 43}, + ), + dict( + task_extras={'full_event': np.array([1, 2, 3])}, + expected_extras={'full_event': np.array([3])}, + ), + dict( + task_extras={'full_event': np.array([]), 'no_full_event': 44}, + expected_extras={'full_event': np.array([]), 'no_full_event': 44}, + ), + ) + def test_keep_latest_only(self, task_extras, expected_extras): + a11y_events.keep_latest_event_only(task_extras) + self.assertEqual(len(task_extras), len(expected_extras)) + for k, v in task_extras.items(): + self.assertIn(k, expected_extras) + if k == 'full_event': + np.testing.assert_array_equal(v, expected_extras['full_event']) + else: + self.assertEqual(v, expected_extras[k]) + pass + + +if __name__ == '__main__': + absltest.main() diff --git a/android_env/components/a11y/a11y_forests.py b/android_env/components/a11y/a11y_forests.py new file mode 100644 index 0000000..1cd8ef2 --- /dev/null +++ b/android_env/components/a11y/a11y_forests.py @@ -0,0 +1,128 @@ +# coding=utf-8 +# Copyright 2024 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tools for accessing accessibility events.""" + +from collections.abc import Mapping +from typing import Any + +from android_env.proto.a11y import android_accessibility_forest_pb2 +import numpy as np + +from google.protobuf import any_pb2 + + +_A11Y_FORESTS_KEY = 'accessibility_tree' + + +def package_forests_to_task_extras( + forests: list[android_accessibility_forest_pb2.AndroidAccessibilityForest], +) -> Mapping[str, np.ndarray]: + if not forests: + return {} + forests = np.stack(forests, axis=0) + return {_A11Y_FORESTS_KEY: forests} + + +def task_extras_has_forests(task_extras: Mapping[str, Any]) -> bool: + """Checks that the task_extras has any a11y forest information.""" + if _A11Y_FORESTS_KEY not in task_extras: + return False + + payload = task_extras[_A11Y_FORESTS_KEY] + if not isinstance(payload, np.ndarray) or payload.ndim != 1: + raise ValueError( + f'{_A11Y_FORESTS_KEY} task extra should be a numpy array with one' + f' dimension. payload: {payload}' + ) + + if payload.size == 0: + return False + + if any(isinstance(f, any_pb2.Any) for f in payload): + # Forests were packed as Any. + return True + + return any( + isinstance(f, android_accessibility_forest_pb2.AndroidAccessibilityForest) + for f in payload + ) + + +def convert_to_forest( + forest: android_accessibility_forest_pb2.AndroidAccessibilityForest + | any_pb2.Any + | None, +) -> android_accessibility_forest_pb2.AndroidAccessibilityForest | None: + """Takes an object and attempts to convert it to a forest.""" + if forest is None: + return None + + if isinstance(forest, any_pb2.Any): + output = android_accessibility_forest_pb2.AndroidAccessibilityForest() + new_any = any_pb2.Any() + new_any.CopyFrom(forest) + new_any.Unpack(output) + return output + elif isinstance( + forest, android_accessibility_forest_pb2.AndroidAccessibilityForest + ): + return forest + else: + return None + + +def extract_forests_from_task_extras( + task_extras: Mapping[str, Any] | None = None, +) -> list[android_accessibility_forest_pb2.AndroidAccessibilityForest]: + """Inspects task_extras and extracts all accessibility forests detected. + + Args: + task_extras: Task extras forwarded by AndroidEnv. If 'full_event' is not a + key in task_extras, then this function returns an empty string. Otherwise, + full_event is expected to be list to be a numpy array with one dimension, + and contains a list of dictionary describing accessibility forests that + are present in the given task extras. + + Returns: + List of all forests detected + """ + if task_extras is None or not task_extras_has_forests(task_extras): + return [] + + forests = [] + for f in task_extras[_A11Y_FORESTS_KEY]: + f = convert_to_forest(f) + if f is not None: + forests.append(f) + return forests + + +def keep_latest_forest_only(task_extras: dict[str, Any]): + """Removes all a11y forests except the last one observed.""" + if _A11Y_FORESTS_KEY not in task_extras.keys(): + return + + payload = task_extras[_A11Y_FORESTS_KEY] + if not isinstance(payload, np.ndarray) or payload.ndim != 1: + raise ValueError( + f'{_A11Y_FORESTS_KEY} task extra should be a numpy array with one' + f' dimension. payload: {payload}' + ) + + if payload.size == 0: + return + + task_extras[_A11Y_FORESTS_KEY] = payload[-1:] diff --git a/android_env/components/a11y/a11y_forests_test.py b/android_env/components/a11y/a11y_forests_test.py new file mode 100644 index 0000000..1a4d477 --- /dev/null +++ b/android_env/components/a11y/a11y_forests_test.py @@ -0,0 +1,237 @@ +# coding=utf-8 +# Copyright 2024 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for a11y_forests.""" + +from absl.testing import absltest +from absl.testing import parameterized +from android_env.components.a11y import a11y_forests +from android_env.proto.a11y import android_accessibility_forest_pb2 +import numpy as np + +from google.protobuf import any_pb2 + + +def _pack_any(proto_message) -> any_pb2.Any: + response = any_pb2.Any() + response.Pack(proto_message) + return response + + +def _empty_forest() -> ( + android_accessibility_forest_pb2.AndroidAccessibilityForest +): + return android_accessibility_forest_pb2.AndroidAccessibilityForest() + + +def _one_empty_window_forest() -> ( + android_accessibility_forest_pb2.AndroidAccessibilityForest +): + forest = android_accessibility_forest_pb2.AndroidAccessibilityForest() + forest.windows.add() + return forest + + +def _two_window_forest() -> ( + android_accessibility_forest_pb2.AndroidAccessibilityForest +): + forest = android_accessibility_forest_pb2.AndroidAccessibilityForest() + window = forest.windows.add() + window.tree.nodes.add( + class_name='foo', is_clickable=True, hint_text='Foo hint' + ) + forest.windows.add() + return forest + + +class A11YForestsTest(parameterized.TestCase): + + @parameterized.parameters( + dict(task_extras={}, expected_forests=[], convert_to_np=[]), + dict( + task_extras={'accessibility_tree': []}, + convert_to_np=['accessibility_tree'], + expected_forests=[], + ), + dict( + task_extras={ + 'not_accessibility_tree': [ + _empty_forest(), + _one_empty_window_forest(), + _two_window_forest(), + ], + }, + convert_to_np=['not_accessibility_tree'], + expected_forests=[], + ), + dict( + task_extras={ + 'accessibility_tree': [ + _empty_forest(), + {'not_a_forest_key': 'nor_a_forest_value'}, + _two_window_forest(), + ] + }, + convert_to_np=['accessibility_tree'], + expected_forests=[_empty_forest(), _two_window_forest()], + ), + dict( + task_extras={ + 'accessibility_tree': [ + {'not_a_forest_key': 'nor_a_forest_value'}, + 3, + 4, + {'not_a_forest_key': _empty_forest()}, + ], + }, + convert_to_np=['accessibility_tree'], + expected_forests=[], + ), + dict( + task_extras={'accessibility_tree': []}, + convert_to_np=['accessibility_tree'], + expected_forests=[], + ), + dict( + task_extras={ + 'accessibility_tree_wrong_key': [1, 2, 3], + 'accessibility_tree': [ + _empty_forest(), + None, + None, + _one_empty_window_forest(), + _two_window_forest(), + ], + }, + convert_to_np=['accessibility_tree', 'accessibility_tree_wrong_key'], + expected_forests=[ + _empty_forest(), + _one_empty_window_forest(), + _two_window_forest(), + ], + ), + dict( + task_extras={ + 'accessibility_tree_wrong_key': [1, 2, 3], + 'accessibility_tree': [ + None, + _pack_any(_empty_forest()), + _pack_any(_one_empty_window_forest()), + _pack_any(_two_window_forest()), + ], + }, + convert_to_np=['accessibility_tree', 'accessibility_tree_wrong_key'], + expected_forests=[ + _empty_forest(), + _one_empty_window_forest(), + _two_window_forest(), + ], + ), + dict( + task_extras={ + 'accessibility_tree': [ + _pack_any(_empty_forest()), + {'not_a_forest_key': 'nor_a_forest_value'}, + None, + _two_window_forest(), + None, + ] + }, + convert_to_np=['accessibility_tree'], + expected_forests=[_empty_forest(), _two_window_forest()], + ), + ) + def test_task_extras(self, task_extras, expected_forests, convert_to_np): + for k in convert_to_np: + if task_extras[k]: + task_extras[k] = np.stack(task_extras[k], axis=0) + else: + task_extras[k] = np.array([]) + forests = a11y_forests.extract_forests_from_task_extras(task_extras) + self.assertEqual(len(forests), len(expected_forests)) + for idx, f in enumerate(forests): + self.assertEqual(f, expected_forests[idx]) + + @parameterized.parameters( + dict(task_extras={}, expected_extras={}), + dict( + task_extras={ + 'no_accessibility_tree': 42, + }, + expected_extras={ + 'no_accessibility_tree': 42, + }, + ), + dict( + task_extras={'accessibility_tree': []}, + expected_extras={'accessibility_tree': []}, + ), + dict( + task_extras={ + 'accessibility_tree': [ + _empty_forest(), + _one_empty_window_forest(), + ], + 'no_accessibility_tree': 43, + }, + expected_extras={ + 'accessibility_tree': [_one_empty_window_forest()], + 'no_accessibility_tree': 43, + }, + ), + dict( + task_extras={ + 'accessibility_tree': [ + _empty_forest(), + _one_empty_window_forest(), + _two_window_forest(), + ] + }, + expected_extras={'accessibility_tree': [_two_window_forest()]}, + ), + dict( + task_extras={ + 'accessibility_tree': [], + 'no_accessibility_tree': 44, + }, + expected_extras={ + 'accessibility_tree': [], + 'no_accessibility_tree': 44, + }, + ), + ) + def test_keep_latest_only(self, task_extras, expected_extras): + if 'accessibility_tree' in task_extras: + if task_extras['accessibility_tree']: + task_extras['accessibility_tree'] = np.stack( + task_extras['accessibility_tree'], axis=0 + ) + else: + task_extras['accessibility_tree'] = np.array([]) + + a11y_forests.keep_latest_forest_only(task_extras) + self.assertSameElements(task_extras.keys(), expected_extras.keys()) + for k in task_extras.keys(): + if k == 'accessibility_tree': + self.assertEqual(len(task_extras[k]), len(expected_extras[k])) + for idx, f in enumerate(task_extras[k]): + self.assertEqual(f, expected_extras[k][idx]) + else: + self.assertEqual(task_extras[k], expected_extras[k]) + pass + + +if __name__ == '__main__': + absltest.main() diff --git a/android_env/components/a11y/a11y_servicer.py b/android_env/components/a11y/a11y_servicer.py new file mode 100644 index 0000000..82e9e25 --- /dev/null +++ b/android_env/components/a11y/a11y_servicer.py @@ -0,0 +1,199 @@ +# coding=utf-8 +# Copyright 2024 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Accessibility Servicer implementation.""" + +import asyncio +from collections.abc import AsyncIterator, Generator, Iterable +import threading + +from absl import logging +from android_env.proto.a11y import a11y_pb2 +from android_env.proto.a11y import a11y_pb2_grpc +from android_env.proto.a11y import android_accessibility_forest_pb2 +import grpc + + +class A11yServicer(a11y_pb2_grpc.A11yServiceServicer): + """Services the A11yService requests.""" + + def __init__(self, latest_forest_only: bool = False): + self._received_forests: list[ + android_accessibility_forest_pb2.AndroidAccessibilityForest + ] = [] + self._received_events: list[a11y_pb2.EventRequest] = [] + self._lock_forests = threading.Lock() + self._lock_events = threading.Lock() + self._latest_forest_only = latest_forest_only + self._paused = True + + # A11y Forest bookkeeping. + self._get_forest = asyncio.Event() # Whether to request a forest. + self._forest_ready = asyncio.Event() # Whether the forest is ready. + self._latest_forest: ( + android_accessibility_forest_pb2.AndroidAccessibilityForest | None + ) = None + + def SendForest( + self, + request: android_accessibility_forest_pb2.AndroidAccessibilityForest, + context: grpc.ServicerContext, + ) -> a11y_pb2.ForestResponse: + self._process_forest(request) + return a11y_pb2.ForestResponse() + + def SendEvent( + self, + request: a11y_pb2.EventRequest, + context: grpc.ServicerContext, + ) -> a11y_pb2.EventResponse: + self._process_event(request) + return a11y_pb2.EventResponse() + + async def Bidi( + self, + request_iterator: AsyncIterator[a11y_pb2.ClientToServer], + context: grpc.aio.ServicerContext, + ) -> AsyncIterator[a11y_pb2.ServerToClient]: + """Processes incoming ClientToServer requests.""" + + logging.info('Starting A11yServicer.Bidi()') + + # Send a dummy message to unblock clients in their loop. + yield a11y_pb2.ServerToClient() + + # This block defines two coroutines: + # + # * `read_client_requests()` + # * `check_forest()` + # + # They cooperate with each other and both populate a queue `q` which is + # consumed in a loop below, which actually yields requests which are sent to + # the client. The processing finishes when the clients "closes" the + # connection, which causes `read_client_requests()` to put a special value, + # `STOP_ITERATION`, in the queue. + + # Queue for communicating from coroutines to `Bidi()`. + q = asyncio.Queue() + + should_run = True + + async def read_client_requests(): + """Coroutine for reading client requests.""" + + nonlocal should_run + async for request in request_iterator: + field_name = request.WhichOneof('payload') + match field_name: + case 'event': + self._process_event(request.event) + case 'forest': + self._latest_forest = request.forest + self._forest_ready.set() + self._get_forest.clear() # Reset the `Event`. + case _: + logging.error('Unknown field %r', field_name) + await q.put(a11y_pb2.ServerToClient()) + + # Send a special value to stop processing this `Bidi` connection. + await q.put('STOP_ITERATION') + should_run = False + + async def check_forest(): + """Coroutine for sending "get forest" requests.""" + + nonlocal should_run + while should_run: + await self._get_forest.wait() + await q.put(a11y_pb2.ServerToClient(get_forest={})) + + tasks = asyncio.gather(read_client_requests(), check_forest()) + + while should_run: + v = await q.get() + if v == 'STOP_ITERATION': + break + else: + yield v + + await tasks + + logging.info('Finishing A11yServicer.Bidi()') + + async def get_forest( + self, + ) -> android_accessibility_forest_pb2.AndroidAccessibilityForest | None: + """Issues a request to get the a11y forest from the client.""" + + self._get_forest.set() # Unblocks coroutine to send a request. + await self._forest_ready.wait() # Wait for forest to be ready. + self._forest_ready.clear() # Reset the `Event`. + return self._latest_forest + + def gather_forests( + self, + ) -> list[android_accessibility_forest_pb2.AndroidAccessibilityForest]: + forests = [] + with self._lock_forests: + forests = self._received_forests + self._received_forests = [] + return forests + + def gather_events(self) -> list[a11y_pb2.EventRequest]: + events = [] + with self._lock_events: + events = self._received_events + self._received_events = [] + return events + + def pause_and_clear(self) -> None: + """Temporarily stop receiving events/forests and clear the queue. + + Used when resetting the environment; in this case: + - all events/forests that have been received since last timestep are things + that happened in the last episode after its `LAST` timestep (so we should + ignore them, done by clearing the lists). + - we're about to receive a bunch of events/forests just as a result of + resetting the environment. We don't want to count these either; thus we + temporarily stop receiving new ones. + """ + self._paused = True + with self._lock_forests: + self._received_forests = [] + with self._lock_events: + self._received_events = [] + + def resume(self) -> None: + """Start receiving events/forests (e.g., after a reset).""" + self._paused = False + + def _process_event(self, event: a11y_pb2.EventRequest) -> None: + """Adds the given event to the internal buffer of events.""" + + if not self._paused: + with self._lock_events: + self._received_events.append(event) + + def _process_forest( + self, forest: android_accessibility_forest_pb2.AndroidAccessibilityForest + ) -> None: + """Adds the given forest to the internal buffer of forests.""" + + if not self._paused: + with self._lock_forests: + if self._latest_forest_only: + self._received_forests = [forest] + else: + self._received_forests.append(forest) diff --git a/android_env/components/a11y/a11y_servicer_test.py b/android_env/components/a11y/a11y_servicer_test.py new file mode 100644 index 0000000..bb1d876 --- /dev/null +++ b/android_env/components/a11y/a11y_servicer_test.py @@ -0,0 +1,224 @@ +# coding=utf-8 +# Copyright 2024 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for a11y_servicer.""" + +import asyncio +from collections.abc import AsyncIterator, Iterable +from typing import TypeVar +from unittest import IsolatedAsyncioTestCase, mock + +from absl.testing import absltest +from absl.testing import parameterized +from android_env.components.a11y import a11y_servicer +from android_env.proto.a11y import a11y_pb2 +from android_env.proto.a11y import android_accessibility_forest_pb2 +import grpc + + +_T = TypeVar('_T') + + +async def _aiter(xs: Iterable[_T]) -> AsyncIterator[_T]: + """Utility to make an AsyncIterator from Iterable.""" + + for x in xs: + yield x + + +def one_window_one_node_forest() -> ( + android_accessibility_forest_pb2.AndroidAccessibilityForest +): + forest = android_accessibility_forest_pb2.AndroidAccessibilityForest() + window = forest.windows.add() + node = window.tree.nodes.add() + node.class_name = 'foo' + node.is_clickable = True + node.hint_text = 'Foo hint' + return forest + + +def one_window_two_nodes_forest() -> ( + android_accessibility_forest_pb2.AndroidAccessibilityForest +): + forest = android_accessibility_forest_pb2.AndroidAccessibilityForest() + window = forest.windows.add() + node = window.tree.nodes.add() + node.class_name = 'bar' + node.is_clickable = True + node.hint_text = 'Bar hint' + node = window.tree.nodes.add() + node.class_name = 'bar' + node.is_clickable = False + node.hint_text = 'Bar hint 2' + return forest + + +def empty_dict() -> dict[str, str]: + return {} + + +def single_item_dict_with_special_chars() -> dict[str, str]: + return {'foo': 'bar\r\t\nbaz'} + + +class A11yServicerTest(parameterized.TestCase, IsolatedAsyncioTestCase): + + def test_servicer_sendforest(self): + mock_context = mock.create_autospec(grpc.ServicerContext, instance=True) + servicer = a11y_servicer.A11yServicer() + servicer.resume() + response = servicer.SendForest(one_window_one_node_forest(), mock_context) + self.assertEqual(response.error, '') + response = servicer.SendForest(one_window_two_nodes_forest(), mock_context) + self.assertEqual(response.error, '') + forests = servicer.gather_forests() + self.assertLen(forests, 2) + self.assertEqual(forests[0], one_window_one_node_forest()) + self.assertEqual(forests[1], one_window_two_nodes_forest()) + + async def test_servicer_bidi_forests(self): + """Checks that the bidirectional interface accepts forests.""" + + # Arrange. + mock_context = mock.create_autospec(grpc.ServicerContext, instance=True) + servicer = a11y_servicer.A11yServicer() + + # Act. + servicer.resume() + responses = [ + x + async for x in servicer.Bidi( + _aiter([ + a11y_pb2.ClientToServer( + event=a11y_pb2.EventRequest( + event=single_item_dict_with_special_chars() + ) + ), + a11y_pb2.ClientToServer(forest=one_window_two_nodes_forest()), + ]), + mock_context, + ) + ] + forest = await servicer.get_forest() + + # Assert. + self.assertEqual(responses[0], a11y_pb2.ServerToClient()) + self.assertEqual(responses[1], a11y_pb2.ServerToClient()) + self.assertIsNotNone(forest) + self.assertEqual(forest, one_window_two_nodes_forest()) + + def test_servicer_sendforest_latest_only(self): + mock_context = mock.create_autospec(grpc.ServicerContext, instance=True) + servicer = a11y_servicer.A11yServicer(latest_forest_only=True) + servicer.resume() + response = servicer.SendForest(one_window_one_node_forest(), mock_context) + self.assertEqual(response.error, '') + response = servicer.SendForest(one_window_two_nodes_forest(), mock_context) + self.assertEqual(response.error, '') + forests = servicer.gather_forests() + self.assertLen(forests, 1) + self.assertEqual(forests[0], one_window_two_nodes_forest()) + + def test_servicer_sendevent(self): + mock_context = mock.create_autospec(grpc.ServicerContext, instance=True) + servicer = a11y_servicer.A11yServicer() + servicer.resume() + response = servicer.SendEvent( + a11y_pb2.EventRequest(event=empty_dict()), mock_context + ) + self.assertEqual(response.error, '') + response = servicer.SendEvent( + a11y_pb2.EventRequest(event=single_item_dict_with_special_chars()), + mock_context, + ) + self.assertEqual(response.error, '') + events = servicer.gather_events() + self.assertLen(events, 2) + self.assertEqual(events[0].event, empty_dict()) + self.assertEqual(events[1].event, single_item_dict_with_special_chars()) + + async def test_servicer_bidi_events(self): + """Checks that the bidirectional interface accepts events.""" + + # Arrange. + mock_context = mock.create_autospec(grpc.ServicerContext, instance=True) + servicer = a11y_servicer.A11yServicer() + + # Act. + servicer.resume() + responses = [ + x + async for x in servicer.Bidi( + _aiter([ + a11y_pb2.ClientToServer( + event=a11y_pb2.EventRequest(event=empty_dict()) + ), + a11y_pb2.ClientToServer( + event=a11y_pb2.EventRequest( + event=single_item_dict_with_special_chars() + ) + ), + ]), + mock_context, + ) + ] + events = servicer.gather_events() + + # Assert. + self.assertEqual(responses[0], a11y_pb2.ServerToClient()) + self.assertEqual(responses[1], a11y_pb2.ServerToClient()) + self.assertLen(events, 2) + self.assertEqual(events[0].event, empty_dict()) + self.assertEqual(events[1].event, single_item_dict_with_special_chars()) + + def test_servicer_pause_and_clear_pauses(self): + mock_context = mock.create_autospec(grpc.ServicerContext, instance=True) + servicer = a11y_servicer.A11yServicer() + servicer.resume() + servicer.pause_and_clear() + response = servicer.SendEvent( + a11y_pb2.EventRequest(event=empty_dict()), mock_context + ) + self.assertEqual(response.error, '') + response = servicer.SendForest(one_window_one_node_forest(), mock_context) + self.assertEqual(response.error, '') + events = servicer.gather_events() + self.assertEmpty(events) + forests = servicer.gather_forests() + self.assertEmpty(forests) + + def test_servicer_pause_and_clear_clears(self): + mock_context = mock.create_autospec(grpc.ServicerContext, instance=True) + servicer = a11y_servicer.A11yServicer() + servicer.resume() + response = servicer.SendEvent( + a11y_pb2.EventRequest(event=empty_dict()), mock_context + ) + self.assertEqual(response.error, '') + response = servicer.SendForest(one_window_one_node_forest(), mock_context) + self.assertEqual( + response.error, + '', + ) + servicer.pause_and_clear() + events = servicer.gather_events() + self.assertEmpty(events) + forests = servicer.gather_forests() + self.assertEmpty(forests) + + +if __name__ == '__main__': + absltest.main() diff --git a/android_env/wrappers/a11y_grpc_wrapper.py b/android_env/wrappers/a11y_grpc_wrapper.py new file mode 100644 index 0000000..9cb2678 --- /dev/null +++ b/android_env/wrappers/a11y_grpc_wrapper.py @@ -0,0 +1,505 @@ +# coding=utf-8 +# Copyright 2024 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Wraps AndroidEnv to retrieve accessibility messages from gRPC.""" + +from concurrent import futures +import time +from typing import Any + +import urllib + +from absl import logging +from android_env import env_interface +from android_env.components import action_type as android_action_type_lib +from android_env.components.a11y import a11y_events +from android_env.components.a11y import a11y_forests +from android_env.components.a11y import a11y_servicer +from android_env.proto import adb_pb2 +from android_env.proto.a11y import a11y_pb2_grpc +from android_env.wrappers import base_wrapper +import dm_env +import grpc +import numpy as np +import portpicker + + +def _get_accessibility_forwarder_apk() -> bytes: + logging.info('Downloading accessibility forwarder apk....') + with urllib.request.urlopen( + 'https://storage.googleapis.com/android_env-tasks/2024.05.13-accessibility_forwarder.apk' + ) as response: + return response.read() + + +class EnableNetworkingError(ValueError): + pass + + +class A11yGrpcWrapper(base_wrapper.BaseWrapper): + """Wrapper which receives A11y events and forests over gRPC. + + A11y forest protobufs and event dicts are sent from the Android emulator via + gRPC from the `AccessibilityForwarder` (for use in developing reward + functions, etc). This wrapper constructs a server which receives these + messages and channels them into `task_extras`. + + The downside of forwarding this information through gRPC is that no messages + will be sent if networking is turned off (e.g., if the AVD is in airplane + mode). To mitigate this problem, the `AccessibilityForwarder` logs an error + message if it fails to contact the server. This wrapper monitors the logs for + such error messages, and attempts (in another thread, to not block environment + transitions) to reconnect the AVD to the network. If this fails to fix the + problem, this wrapper ends the episode. + + This wrapper is implemented to be robust to multiple upstream callers of + `task_extras`, and to ensure they each receive the same extras at every + timestep. Thus, the logic is the following: + * New a11y events/forests are fetched during `reset` and `step`, *not* during + `task_extras()` calls. + * If no one has called `task_extras()` since the last `step` or `reset`, the + extras are accumulated (so that no extras are missed because someone called + `step()` twice without calling `task_extras()`). + * If someone *has* called `task_extras()` since last step, the newly fetched + extras replace the old extras. + """ + + def __init__( + self, + env: env_interface.AndroidEnvInterface, + disable_other_network_traffic: bool = False, + install_a11y_forwarding: bool = False, + start_a11y_service: bool = True, + enable_a11y_tree_info: bool = False, + add_latest_a11y_info_to_obs: bool = False, + a11y_info_timeout: float | None = None, + max_enable_networking_attempts: int = 10, + latest_a11y_info_only: bool = False, + ): + """Initializes wrapper. + + Args: + env: Environment to wrap. + disable_other_network_traffic: When True, all network traffic, other than + the connection to the servicer, is disabled. NOTE: This requires root + access on the device (i.e. it uses the `su` command). An + `AdbControllerError` exception will be raised if the underlying command + fails. + install_a11y_forwarding: If True, the wrapper handles the installation of + all packages required for the servicer to collect a11y information. + start_a11y_service: If True, starts the a11y forwarding services. NOTE: + The packages must be installed beforehand, e.g., using the + install_a11y_forwarding flag. + enable_a11y_tree_info: When False, this wrapper collects only a11y events + and not a11y tree. + add_latest_a11y_info_to_obs: When True, the latest observed a11y forest is + added to the observation. + a11y_info_timeout: When larger than zero and add_latest_a11y_info_to_obs + is set to True, the wrapper will wait the corresponding amount of time, + measured in seconds, to collect the latest a11y forest. + max_enable_networking_attempts: When the a11y gRPC service fails to + provide a11y information, we attempt this many times to re-enable the + networking. If all these attempts fail, fetching task_extras will raise + an EnableNetworkingError. + latest_a11y_info_only: When True, the a11y servicer is setup to save only + the latest tree it has received from the Android app. + """ + self._env = env + if install_a11y_forwarding: + self._install_a11y_forwarding_apk() + time.sleep(10.0) + if start_a11y_service: + self._start_a11y_services() + time.sleep(3.0) + if enable_a11y_tree_info: + self._enable_a11y_tree_logs() + self._relaunch_count = 0 + self._server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) + self._servicer = a11y_servicer.A11yServicer( + latest_forest_only=latest_a11y_info_only + ) + a11y_pb2_grpc.add_A11yServiceServicer_to_server( + self._servicer, self._server + ) + server_credentials = grpc.local_server_credentials() + self._port = portpicker.pick_unused_port() + logging.info('Using port %s', self._port) + uri_address = f'[::]:{self._port}' + self._server.add_secure_port(uri_address, server_credentials) + logging.info('Starting server') + self._server.start() + logging.info('Server now running.') + + # TODO(b/293157187): Ocasionally, the A11yForwarder will fail to provide + # a11y information, which is communicated through task_extras in the form of + # an exception. When this is caused by the agent disabling the network, the + # wrapper will attempt to re-enable networking. If re-enabling the network + # does not solve the issue, there might be other issues in the gRPC + # connection that needs to be resolved. + self._max_enable_networking_attempts = max_enable_networking_attempts + self._reset_enable_networking_attempts() + + self._disable_other_network_traffic = disable_other_network_traffic + self._should_accumulate = False + self._accumulated_extras = None + self._add_latest_a11y_info_to_obs = add_latest_a11y_info_to_obs + self._a11y_info_timeout = a11y_info_timeout + self._parent_action_spec = self._env.action_spec() + if self._a11y_info_timeout is not None and self._a11y_info_timeout > 0.0: + if 'action_type' not in self._parent_action_spec.keys(): + raise ValueError( + 'action_type not in the parent action spec: ' + f'{self._parent_action_spec}. This is a strong requirement when ' + f'a11y_info_timeout = {a11y_info_timeout} > 0' + ) + + def _start_a11y_services(self) -> None: + """Starts the accessibility forwarder services. + + Raises: + RuntimeError: If accessibility service is not started. + """ + start_service_request = adb_pb2.AdbRequest( + settings=adb_pb2.AdbRequest.SettingsRequest( + name_space=adb_pb2.AdbRequest.SettingsRequest.Namespace.SECURE, + put=adb_pb2.AdbRequest.SettingsRequest.Put( + key='enabled_accessibility_services', + value=( + 'com.google.androidenv.accessibilityforwarder/com.google.' + 'androidenv.accessibilityforwarder.AccessibilityForwarder' + ), + ), + ) + ) + start_service_response = self._env.execute_adb_call(start_service_request) + if start_service_response.status != adb_pb2.AdbResponse.Status.OK: + raise RuntimeError( + 'Could not start accessibility forwarder ' + 'service: ' + f'{start_service_response}.' + ) + + def _install_a11y_forwarding_apk(self) -> None: + """Enables accessibility information forwarding.""" + a11y_fwd_apk = _get_accessibility_forwarder_apk() + # Install and setup the Accesssibility Forwarder. + install_request = adb_pb2.AdbRequest( + install_apk=adb_pb2.AdbRequest.InstallApk( + blob=adb_pb2.AdbRequest.InstallApk.Blob(contents=a11y_fwd_apk), + ) + ) + install_response = self._env.execute_adb_call(install_request) + if install_response.status != adb_pb2.AdbResponse.Status.OK: + raise ValueError( + f'Could not install accessibility_forwarder.apk: {install_response}.' + ) + + def _enable_a11y_tree_logs(self) -> None: + enable_tree_logs_request = adb_pb2.AdbRequest( + send_broadcast=adb_pb2.AdbRequest.SendBroadcast( + action=( + 'accessibility_forwarder.intent.action.' + 'ENABLE_ACCESSIBILITY_TREE_LOGS' + ), + component=( + 'com.google.androidenv.accessibilityforwarder/com.google.androidenv.accessibilityforwarder.FlagsBroadcastReceiver' + ), + ) + ) + enable_tree_logs_response = self._env.execute_adb_call( + enable_tree_logs_request + ) + if enable_tree_logs_response.status != adb_pb2.AdbResponse.Status.OK: + raise ValueError( + 'Could not enable accessibility tree logging: ' + f'{enable_tree_logs_response}.' + ) + + def _reset_enable_networking_attempts(self) -> None: + self._enable_networking_attepts_left = self._max_enable_networking_attempts + self._enabling_networking_future = None + self._a11y_exception = None + + def get_port(self): + return self._port + + def close(self): + self._server.stop(None) + logging.info('gRPC server stopped') + self._env.close() + + def attempt_enable_networking(self) -> None: + """Attempts to turn on networking within the Android device. + + Attempt to turn on the networking in the Android device, by: + - turning off airplane mode; + - turning on the wifi connection. + """ + self.execute_adb_call( + adb_pb2.AdbRequest( + settings=adb_pb2.AdbRequest.SettingsRequest( + name_space=adb_pb2.AdbRequest.SettingsRequest.Namespace.GLOBAL, + put=adb_pb2.AdbRequest.SettingsRequest.Put( + key='airplane_mode_on', value='0' + ), + ) + ) + ) + time.sleep(1.0) + self.execute_adb_call( + adb_pb2.AdbRequest( + generic=adb_pb2.AdbRequest.GenericRequest( + args=[ + 'shell', + 'svc', + 'wifi', + 'enable', + ] + ) + ) + ) + time.sleep(1.0) + + def _configure_grpc(self) -> None: + """Configure networking and set the gRPC port in the AVD.""" + + if self._disable_other_network_traffic: + self.execute_adb_call( + adb_pb2.AdbRequest( + generic=adb_pb2.AdbRequest.GenericRequest( + args=[ + 'shell', + 'su', + '0', + 'iptables', + '-A', + 'OUTPUT', + '-p', + 'tcp', + '-d', + '10.0.2.2', + '--dport', + str(self._port), + '-j', + 'ACCEPT', + ] + ) + ) + ) + time.sleep(3.0) + self.execute_adb_call( + adb_pb2.AdbRequest( + generic=adb_pb2.AdbRequest.GenericRequest( + args=[ + 'shell', + 'su', + '0', + 'iptables', + '-A', + 'OUTPUT', + '-j', + 'DROP', + ] + ) + ) + ) + time.sleep(3.0) + + self.execute_adb_call( + adb_pb2.AdbRequest( + settings=adb_pb2.AdbRequest.SettingsRequest( + name_space=adb_pb2.AdbRequest.SettingsRequest.Namespace.GLOBAL, + put=adb_pb2.AdbRequest.SettingsRequest.Put( + key='no_proxy', value=f'10.0.2.2:{self._port}' + ), + ) + ) + ) + self.attempt_enable_networking() + self.execute_adb_call( + adb_pb2.AdbRequest( + send_broadcast=adb_pb2.AdbRequest.SendBroadcast( + action=( + 'accessibility_forwarder.intent.action.SET_GRPC --ei' + f' "port" {self._port}' + ), + component=( + 'com.google.androidenv.accessibilityforwarder/com.google.androidenv.accessibilityforwarder.FlagsBroadcastReceiver' + ), + ) + ) + ) + + def _accumulate_and_return_a11y_info( + self, timer: float | None = None, get_env_observation: bool = True + ) -> dict[str, Any]: + """Accumulates and returns the latest a11y tree info and observation. + + Args: + timer: If larger than 0, the system will wait this long for a11y info to + accumulate before it returns a value. + get_env_observation: If False, the corresponding observation is not + introduced here. + + Returns: + a dict with a11y forest under key 'a11y_forest'. All other fields will + provide the observation, if requested. + """ + timer = timer or 0.0 + if timer > 0.0: + time.sleep(timer) + + if get_env_observation: + # Fetch observation. + new_ts = self._env.step({ + 'action_type': np.array( + android_action_type_lib.ActionType.REPEAT, + dtype=self._parent_action_spec['action_type'].dtype, + ), + }) + observation = new_ts.observation + else: + observation = {} + + extras = self.accumulate_new_extras() + forests = a11y_forests.extract_forests_from_task_extras(extras) + if forests: + observation['a11y_forest'] = forests[-1] + else: + observation['a11y_forest'] = None + return observation + + def _fetch_task_extras_and_update_observation( + self, observation: dict[str, Any], timeout: float = 0.0 + ) -> dict[str, Any]: + if timeout > 0.0: + observation = self._accumulate_and_return_a11y_info( + timeout, get_env_observation=True + ) + if not self._add_latest_a11y_info_to_obs: + observation.pop('a11y_forest') + else: + new_obs = self._accumulate_and_return_a11y_info(get_env_observation=False) + if self._add_latest_a11y_info_to_obs: + observation.update(new_obs) + return observation + + def reset(self) -> dm_env.TimeStep: + self._reset_enable_networking_attempts() + self._servicer.pause_and_clear() + timestep = self._env.reset() + self._servicer.resume() + if self._env.stats()['relaunch_count'] > self._relaunch_count: + self._configure_grpc() + self._relaunch_count = self._env.stats()['relaunch_count'] + self._accumulated_extras = {} + timeout = self._a11y_info_timeout or 0.0 + new_observation = self._fetch_task_extras_and_update_observation( + timestep.observation, timeout + ) + timestep = timestep._replace(observation=new_observation) + return timestep + + def step(self, action: Any) -> dm_env.TimeStep: + timeout = float(action.pop('wait_time', self._a11y_info_timeout or 0.0)) + timestep = self._env.step(action) + new_observation = self._fetch_task_extras_and_update_observation( + timestep.observation, timeout=timeout + ) + timestep = timestep._replace(observation=new_observation) + return timestep + + def accumulate_new_extras(self) -> dict[str, Any]: + new_extras = self._fetch_task_extras() + if self._should_accumulate: + for key in new_extras: + if key in self._accumulated_extras: + self._accumulated_extras[key] = np.concatenate( + (self._accumulated_extras[key], new_extras[key]), axis=0 + ) + else: + self._accumulated_extras[key] = new_extras[key] + else: + self._accumulated_extras = new_extras + self._should_accumulate = True + return self._accumulated_extras + + def _fetch_task_extras(self) -> dict[str, Any]: + """Fetches task_extras from the services. + + NOTE: If you want to access the latest a11y information, please use + accumulate_and_return_a11y_info instead. This function has the side effect + of clearing the content from the servicer, hence all the a11y info returned + here won't be accumulated. + + Returns: + A dict with the corresponding task_extras. + + Raises: + EnableNetworkingError: after a fixed number of attempts to revive the a11y + services by re-enabling the network connection. + """ + base_extras = self._env.task_extras(latest_only=False).copy() + if ( + self._enabling_networking_future is None + and 'exception' in base_extras + and base_extras['exception'].shape[0] + ): + self._a11y_exception = base_extras['exception'] + logging.warning( + 'AccessibilityForwarder logged exceptions: %s', self._a11y_exception + ) + if self._enable_networking_attepts_left > 0: + logging.warning( + 'Attempting to enable networking. %s attemps left.', + self._enable_networking_attepts_left - 1, + ) + executor = futures.ThreadPoolExecutor(max_workers=1) + self._enabling_networking_future = executor.submit( + self.attempt_enable_networking + ) + else: + raise EnableNetworkingError( + 'A11y service failed multiple times with' + f' exception.{self._a11y_exception}.' + ) + + if ( + self._enabling_networking_future is not None + and self._enabling_networking_future.done() + ): + self._enabling_networking_future = None + self._enable_networking_attepts_left -= 1 + logging.info('Finished enabling networking.') + + forests = self._servicer.gather_forests() + if forests: + base_extras.update(a11y_forests.package_forests_to_task_extras(forests)) + self._reset_enable_networking_attempts() + events = self._servicer.gather_events() + if events: + base_extras.update(a11y_events.package_events_to_task_extras(events)) + self._reset_enable_networking_attempts() + return base_extras + + def task_extras(self, latest_only: bool = False) -> dict[str, Any]: + if self._accumulated_extras is None: + raise RuntimeError('You must call .reset() before calling .task_extras()') + self._should_accumulate = False + extras = self._accumulated_extras.copy() + if latest_only: + a11y_events.keep_latest_event_only(extras) + a11y_forests.keep_latest_forest_only(extras) + return extras diff --git a/android_env/wrappers/a11y_grpc_wrapper_test.py b/android_env/wrappers/a11y_grpc_wrapper_test.py new file mode 100644 index 0000000..d8ecf7b --- /dev/null +++ b/android_env/wrappers/a11y_grpc_wrapper_test.py @@ -0,0 +1,851 @@ +# coding=utf-8 +# Copyright 2024 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for a11y_grpc_wrapper.""" + +import time +from unittest import mock + +from absl.testing import absltest +from absl.testing import parameterized +from android_env import env_interface +from android_env.proto import adb_pb2 +from android_env.proto.a11y import a11y_pb2 +from android_env.proto.a11y import a11y_pb2_grpc +from android_env.proto.a11y import android_accessibility_forest_pb2 +from android_env.wrappers import a11y_grpc_wrapper +import dm_env +import grpc +import numpy as np + + +def empty_forest() -> ( + android_accessibility_forest_pb2.AndroidAccessibilityForest +): + return android_accessibility_forest_pb2.AndroidAccessibilityForest() + + +def one_empty_window_forest() -> ( + android_accessibility_forest_pb2.AndroidAccessibilityForest +): + forest = android_accessibility_forest_pb2.AndroidAccessibilityForest() + _ = forest.windows.add() + return forest + + +def one_window_one_node_forest() -> ( + android_accessibility_forest_pb2.AndroidAccessibilityForest +): + forest = android_accessibility_forest_pb2.AndroidAccessibilityForest() + window = forest.windows.add() + node = window.tree.nodes.add() + node.class_name = 'foo' + node.is_clickable = True + node.hint_text = 'Foo hint' + return forest + + +def one_window_two_nodes_forest() -> ( + android_accessibility_forest_pb2.AndroidAccessibilityForest +): + forest = android_accessibility_forest_pb2.AndroidAccessibilityForest() + window = forest.windows.add() + node = window.tree.nodes.add() + node.class_name = 'bar' + node.is_clickable = True + node.hint_text = 'Bar hint' + node = window.tree.nodes.add() + node.class_name = 'bar' + node.is_clickable = False + node.hint_text = 'Bar hint 2' + return forest + + +def three_windows_forest() -> ( + android_accessibility_forest_pb2.AndroidAccessibilityForest +): + forest = android_accessibility_forest_pb2.AndroidAccessibilityForest() + _ = forest.windows.add() + window = forest.windows.add() + node = window.tree.nodes.add() + node.class_name = 'foo' + node.is_clickable = True + node.hint_text = 'hint' + window = forest.windows.add() + node = window.tree.nodes.add() + node.class_name = 'baz' + node.is_clickable = True + node.hint_text = 'hint' + node = window.tree.nodes.add() + node.class_name = 'foobar' + node.is_clickable = False + node.hint_text = 'hint' + return forest + + +def empty_dict() -> dict[str, str]: + return {} + + +def single_item_dict() -> dict[str, str]: + return {'foo': 'bar'} + + +def several_long_items_dict() -> dict[str, str]: + return { + 'first_key': 'Lorem ipsum ' * 100, + 'second_key': 'the beginning is the end is' * 100, + } + + +def single_item_dict_with_special_chars() -> dict[str, str]: + return {'foo': 'bar\r\t\nbaz'} + + +def _ok_response(): + return adb_pb2.AdbResponse(status=adb_pb2.AdbResponse.Status.OK) + + +class A11yGrpcWrapperTest(parameterized.TestCase): + + def test_server(self): + base_env = mock.create_autospec( + env_interface.AndroidEnvInterface, instance=True + ) + base_env.task_extras.return_value = {} + base_env.stats.return_value = {'relaunch_count': 0} + base_env.execute_adb_call.return_value = _ok_response() + wrapped_env = a11y_grpc_wrapper.A11yGrpcWrapper(base_env) + wrapped_env.reset() + channel_creds = grpc.local_channel_credentials() + with grpc.secure_channel( + f'[::]:{wrapped_env.get_port()}', channel_creds + ) as channel: + grpc.channel_ready_future(channel).result() + stub = a11y_pb2_grpc.A11yServiceStub(channel) + stub.SendForest(one_window_one_node_forest()) + stub.SendForest(one_window_two_nodes_forest()) + wrapped_env.step({}) + extras = wrapped_env.task_extras(latest_only=False) + self.assertIn('accessibility_tree', extras) + self.assertEqual(extras['accessibility_tree'].shape[0], 2) + + # tests of fetch_task_extras: + # exception occurs (ensure attempt to enable networking) and recovers + # exception occurs and enable networking doesn't help + # exception occurs twice but with a forest sent between + + @parameterized.named_parameters( + ('no_events_or_forests', [], []), + ( + 'no_events', + [], + [one_window_one_node_forest(), one_window_two_nodes_forest()], + ), + ('no_forests', [empty_dict(), single_item_dict()], []), + ( + 'events_and_forests', + [empty_dict(), single_item_dict()], + [one_window_one_node_forest(), one_window_two_nodes_forest()], + ), + ) + @mock.patch.object(time, 'sleep', autospec=True) + @mock.patch.object( + a11y_pb2_grpc, 'add_A11yServiceServicer_to_server', autospec=True + ) + @mock.patch.object(grpc, 'server', autospec=True) + def test_fetch_task_extras( + self, + received_events, + received_forests, + mock_server, + mock_add_servicer, + mock_sleep, + ): + del mock_server, mock_add_servicer, mock_sleep + mock_context = mock.create_autospec(grpc.ServicerContext, instance=True) + base_env = mock.create_autospec( + env_interface.AndroidEnvInterface, instance=True + ) + base_env.task_extras.return_value = { + 'foo': np.array(['bar', 'baz'], dtype='U'), + 'some_key': np.array(['some_value'], dtype='U'), + } + base_env.stats.return_value = {'relaunch_count': 0} + base_env.execute_adb_call.return_value = _ok_response() + wrapped_env = a11y_grpc_wrapper.A11yGrpcWrapper(base_env) + wrapped_env.reset() + for forest in received_forests: + wrapped_env._servicer.SendForest(forest, mock_context) + for event in received_events: + wrapped_env._servicer.SendEvent( + a11y_pb2.EventRequest(event=event), mock_context + ) + with mock.patch.object( + wrapped_env, 'attempt_enable_networking' + ) as mock_attempt_enable_networking: + extras = wrapped_env._fetch_task_extras() + mock_attempt_enable_networking.assert_not_called() + self.assertIn('foo', extras) + np.testing.assert_array_equal(extras['foo'], ['bar', 'baz']) + self.assertIn('some_key', extras) + np.testing.assert_array_equal(extras['some_key'], ['some_value']) + if received_events: + self.assertIn('full_event', extras) + self.assertLen(extras['full_event'], len(received_events)) + for i, event in enumerate(received_events): + event = a11y_pb2.EventRequest(event=event) + np.testing.assert_array_equal(extras['full_event'][i], event) + else: + self.assertNotIn('full_event', extras) + if received_forests: + self.assertIn('accessibility_tree', extras) + self.assertLen(extras['accessibility_tree'], len(received_forests)) + for i, forest in enumerate(received_forests): + np.testing.assert_array_equal(extras['accessibility_tree'][i], forest) + else: + self.assertNotIn('accessibility_tree', extras) + + @mock.patch.object(time, 'sleep', autospec=True) + @mock.patch.object( + a11y_pb2_grpc, 'add_A11yServiceServicer_to_server', autospec=True + ) + @mock.patch.object(grpc, 'server', autospec=True) + def test_fetch_task_extras_enable_networking( + self, + mock_server, + mock_add_servicer, + mock_sleep, + ): + del mock_server, mock_add_servicer, mock_sleep + base_env = mock.create_autospec( + env_interface.AndroidEnvInterface, instance=True + ) + base_env.task_extras.return_value = { + 'foo': np.array(['bar'], dtype='U'), + 'some_key': np.array(['some_value'], dtype='U'), + 'exception': np.array(['fake exception'], dtype='U'), + } + base_env.stats.return_value = {'relaunch_count': 0} + base_env.execute_adb_call.return_value = _ok_response() + wrapped_env = a11y_grpc_wrapper.A11yGrpcWrapper(base_env) + with mock.patch.object( + wrapped_env, 'attempt_enable_networking' + ) as mock_attempt_enable_networking: + extras = wrapped_env._fetch_task_extras() + self.assertNotIn('accessibility_tree', extras) + self.assertNotIn('full_event', extras) + future = wrapped_env._enabling_networking_future + if future is not None: + future.result() + mock_attempt_enable_networking.assert_called_once() + + @mock.patch.object(time, 'sleep', autospec=True) + @mock.patch.object( + a11y_pb2_grpc, 'add_A11yServiceServicer_to_server', autospec=True + ) + @mock.patch.object(grpc, 'server', autospec=True) + def test_fetch_task_extras_enable_networking_twice( + self, + mock_server, + mock_add_servicer, + mock_sleep, + ): + del mock_server, mock_add_servicer, mock_sleep + mock_context = mock.create_autospec(grpc.ServicerContext, instance=True) + base_env = mock.create_autospec( + env_interface.AndroidEnvInterface, instance=True + ) + base_env.task_extras.return_value = { + 'foo': np.array(['bar'], dtype='U'), + 'some_key': np.array(['some_value'], dtype='U'), + } + + base_env.stats.return_value = {'relaunch_count': 0} + base_env.execute_adb_call.return_value = _ok_response() + wrapped_env = a11y_grpc_wrapper.A11yGrpcWrapper(base_env) + wrapped_env.reset() + + base_env.task_extras.return_value = { + 'foo': np.array(['bar'], dtype='U'), + 'some_key': np.array(['some_value'], dtype='U'), + 'exception': np.array(['fake exception'], dtype='U'), + } + with mock.patch.object( + wrapped_env, 'attempt_enable_networking' + ) as mock_attempt_enable_networking: + extras = wrapped_env._fetch_task_extras() + self.assertNotIn('accessibility_tree', extras) + self.assertNotIn('full_event', extras) + future = wrapped_env._enabling_networking_future + if future is not None: + future.result() + mock_attempt_enable_networking.assert_called_once() + # Fixed networking; send a forest so the wrapper knows it worked. + wrapped_env._servicer.SendForest(one_window_one_node_forest(), mock_context) + base_env.task_extras.return_value = { + 'foo': np.array(['bar'], dtype='U'), + 'some_key': np.array(['some_value'], dtype='U'), + } + extras = wrapped_env._fetch_task_extras() + self.assertIn('accessibility_tree', extras) + self.assertEqual(extras['accessibility_tree'].shape[0], 1) + self.assertEqual( + extras['accessibility_tree'][0], one_window_one_node_forest() + ) + + base_env.task_extras.return_value = { + 'foo': np.array(['bar'], dtype='U'), + 'some_key': np.array(['some_value'], dtype='U'), + 'exception': np.array(['fake exception'], dtype='U'), + } + with mock.patch.object( + wrapped_env, 'attempt_enable_networking' + ) as mock_attempt_enable_networking: + extras = wrapped_env._fetch_task_extras() + self.assertNotIn('accessibility_tree', extras) + self.assertNotIn('full_event', extras) + future = wrapped_env._enabling_networking_future + if future is not None: + future.result() + mock_attempt_enable_networking.assert_called_once() + + @mock.patch.object(time, 'sleep', autospec=True) + @mock.patch.object( + a11y_pb2_grpc, 'add_A11yServiceServicer_to_server', autospec=True + ) + @mock.patch.object(grpc, 'server', autospec=True) + def test_task_extras_rasises_with_a11y_info_exception( + self, + mock_sleep, + mock_add_servicer, + mock_server, + ): + del mock_server, mock_add_servicer, mock_sleep + base_env = mock.create_autospec( + env_interface.AndroidEnvInterface, instance=True + ) + base_env.task_extras.return_value = { + 'foo': np.array(['bar'], dtype='U'), + 'some_key': np.array(['some_value'], dtype='U'), + } + + base_env.stats.return_value = {'relaunch_count': 0} + base_env.execute_adb_call.return_value = _ok_response() + base_env.reset.return_value = dm_env.restart(observation={'dummy': 42}) + base_env.step.return_value = dm_env.transition( + observation={'dummy': 42}, reward=0.0 + ) + wrapped_env = a11y_grpc_wrapper.A11yGrpcWrapper( + base_env, + add_latest_a11y_info_to_obs=True, + max_enable_networking_attempts=1, + ) + wrapped_env.reset() + + base_env.task_extras.return_value = { + 'foo': np.array(['bar'], dtype='U'), + 'some_key': np.array(['some_value'], dtype='U'), + 'exception': np.array(['fake exception'], dtype='U'), + } + with mock.patch.object( + wrapped_env, 'attempt_enable_networking' + ) as mock_attempt_enable_networking: + extras = wrapped_env._fetch_task_extras() + self.assertNotIn('accessibility_tree', extras) + self.assertNotIn('full_event', extras) + future = wrapped_env._enabling_networking_future + if future is not None: + future.result() + mock_attempt_enable_networking.assert_called_once() + # The _fetch_task_extras() call inside the next step will force a restart + self.assertRaises( + a11y_grpc_wrapper.EnableNetworkingError, wrapped_env.step, {} + ) + + @mock.patch.object( + a11y_pb2_grpc, 'add_A11yServiceServicer_to_server', autospec=True + ) + @mock.patch.object(grpc, 'server', autospec=True) + def test_configure_grpc( + self, + mock_server, + mock_add_servicer, + ): + del mock_server, mock_add_servicer + base_env = mock.create_autospec( + env_interface.AndroidEnvInterface, instance=True + ) + base_env.task_extras.return_value = { + 'foo': np.array(['bar'], dtype='U'), + 'some_key': np.array(['some_value'], dtype='U'), + } + + base_env.stats.return_value = {'relaunch_count': 1} + base_env.execute_adb_call.return_value = _ok_response() + wrapped_env = a11y_grpc_wrapper.A11yGrpcWrapper(base_env) + with mock.patch.object( + wrapped_env, '_configure_grpc' + ) as mock_configure_grpc: + wrapped_env.reset() + mock_configure_grpc.assert_called_once() + + @mock.patch.object( + a11y_pb2_grpc, 'add_A11yServiceServicer_to_server', autospec=True + ) + @mock.patch.object(grpc, 'server', autospec=True) + def test_task_extras_raises_before_reset( + self, unused_mock_server, unused_mock_add_servicer + ): + base_env = mock.create_autospec( + env_interface.AndroidEnvInterface, instance=True + ) + base_env.stats.return_value = {'relaunch_count': 0} + base_env.execute_adb_call.return_value = _ok_response() + wrapped_env = a11y_grpc_wrapper.A11yGrpcWrapper(base_env) + with self.assertRaisesRegex( + RuntimeError, + r'You must call \.reset\(\) before calling \.task_extras\(\)', + ): + wrapped_env.task_extras(latest_only=False) + + @mock.patch.object( + a11y_pb2_grpc, 'add_A11yServiceServicer_to_server', autospec=True + ) + @mock.patch.object(grpc, 'server', autospec=True) + def test_extras_accumulate_between_steps( + self, mock_server, mock_add_servicer + ): + del mock_server, mock_add_servicer + base_env = mock.create_autospec( + env_interface.AndroidEnvInterface, instance=True + ) + base_env.stats.return_value = {'relaunch_count': 0} + base_env.execute_adb_call.return_value = _ok_response() + base_env.reset.return_value = dm_env.restart(observation={'dummy': 42}) + base_env.step.return_value = dm_env.transition( + observation={'dummy': 42}, reward=0.0 + ) + wrapped_env = a11y_grpc_wrapper.A11yGrpcWrapper( + base_env, add_latest_a11y_info_to_obs=True + ) + with mock.patch.object(wrapped_env, '_fetch_task_extras'): + wrapped_env._fetch_task_extras.return_value = { + 'full_event': np.array(single_item_dict(), ndmin=1, dtype=object), + 'accessibility_tree': np.array(empty_forest(), ndmin=1, dtype=object), + } + timestep = wrapped_env.reset() + self.assertIn('a11y_forest', timestep.observation) + self.assertEqual(timestep.observation['a11y_forest'], empty_forest()) + wrapped_env._fetch_task_extras.return_value = { + 'full_event': np.array(empty_dict(), ndmin=1, dtype=object), + 'accessibility_tree': np.array( + one_window_two_nodes_forest(), ndmin=1, dtype=object + ), + } + timestep = wrapped_env.step({}) + self.assertIn('a11y_forest', timestep.observation) + self.assertEqual( + timestep.observation['a11y_forest'], one_window_two_nodes_forest() + ) + timestep = wrapped_env.step({}) + self.assertIn('a11y_forest', timestep.observation) + self.assertEqual( + timestep.observation['a11y_forest'], one_window_two_nodes_forest() + ) + wrapped_env._fetch_task_extras.return_value = { + 'full_event': np.array(single_item_dict(), ndmin=1, dtype=object), + } + timestep = wrapped_env.step({}) + self.assertIn('a11y_forest', timestep.observation) + self.assertEqual( + timestep.observation['a11y_forest'], one_window_two_nodes_forest() + ) + expected_task_extras = { + 'full_event': np.array( + [ + single_item_dict(), + empty_dict(), + empty_dict(), + single_item_dict(), + ], + dtype=object, + ), + 'accessibility_tree': np.array( + [ + empty_forest(), + one_window_two_nodes_forest(), + one_window_two_nodes_forest(), + ], + dtype=object, + ), + } + expected_task_extras_latest = { + 'full_event': np.array([single_item_dict()], dtype=object), + 'accessibility_tree': np.array( + [one_window_two_nodes_forest()], dtype=object + ), + } + task_extras = wrapped_env.task_extras(latest_only=False) + np.testing.assert_equal( + task_extras['full_event'], expected_task_extras['full_event'] + ) + np.testing.assert_equal( + task_extras['accessibility_tree'], + expected_task_extras['accessibility_tree'], + ) + + task_extras = wrapped_env.task_extras(latest_only=True) + np.testing.assert_equal( + task_extras['full_event'], expected_task_extras_latest['full_event'] + ) + np.testing.assert_equal( + task_extras['accessibility_tree'], + expected_task_extras_latest['accessibility_tree'], + ) + + @mock.patch.object( + a11y_pb2_grpc, 'add_A11yServiceServicer_to_server', autospec=True + ) + @mock.patch.object(grpc, 'server', autospec=True) + def test_a11y_info_disabled( + self, + unused_mock_server, + unused_mock_add_servicer, + ): + base_env = mock.create_autospec( + env_interface.AndroidEnvInterface, instance=True + ) + base_env.action_spec.return_value = { + 'action_type': dm_env.specs.Array(shape=(), dtype=np.int32) + } + base_env.stats.return_value = {'relaunch_count': 0} + base_env.execute_adb_call.return_value = _ok_response() + base_env.reset.return_value = dm_env.restart(observation={'dummy': 42}) + base_env.step.return_value = dm_env.transition( + observation={'dummy': 42}, reward=0.0 + ) + wrapped_env = a11y_grpc_wrapper.A11yGrpcWrapper( + base_env, add_latest_a11y_info_to_obs=False, a11y_info_timeout=1.0 + ) + with mock.patch.object(wrapped_env, '_fetch_task_extras'): + wrapped_env._fetch_task_extras.return_value = { + 'accessibility_tree': np.array(empty_forest(), ndmin=1, dtype=object), + } + timestep = wrapped_env.reset() + self.assertNotIn('a11y_forest', timestep.observation) + timestep = wrapped_env.step({}) + self.assertNotIn('a11y_forest', timestep.observation) + + @mock.patch.object( + a11y_pb2_grpc, 'add_A11yServiceServicer_to_server', autospec=True + ) + @mock.patch.object(grpc, 'server', autospec=True) + def test_a11y_info_with_timer_info_present( + self, + unused_mock_server, + unused_mock_add_servicer, + ): + base_env = mock.create_autospec( + env_interface.AndroidEnvInterface, instance=True + ) + base_env.action_spec.return_value = { + 'action_type': dm_env.specs.Array(shape=(), dtype=np.int32) + } + base_env.stats.return_value = {'relaunch_count': 0} + base_env.execute_adb_call.return_value = _ok_response() + base_env.reset.return_value = dm_env.restart(observation={'dummy': 42}) + base_env.step.return_value = dm_env.transition( + observation={'dummy': 42}, reward=0.0 + ) + wrapped_env = a11y_grpc_wrapper.A11yGrpcWrapper( + base_env, add_latest_a11y_info_to_obs=True, a11y_info_timeout=1.0 + ) + with mock.patch.object(wrapped_env, '_fetch_task_extras'): + wrapped_env._fetch_task_extras.side_effect = [{ + 'accessibility_tree': np.array(empty_forest(), ndmin=1, dtype=object), + }] + timestep = wrapped_env.reset() + self.assertIn('a11y_forest', timestep.observation) + self.assertEqual(timestep.observation['a11y_forest'], empty_forest()) + + @mock.patch.object( + a11y_pb2_grpc, 'add_A11yServiceServicer_to_server', autospec=True + ) + @mock.patch.object(grpc, 'server', autospec=True) + @mock.patch.object(time, 'sleep', autospec=True) + def test_a11y_info_with_timer_task_extra_returned( + self, unused_mock_server, unused_mock_add_servicer, unused_mock_sleep + ): + base_env = mock.create_autospec( + env_interface.AndroidEnvInterface, instance=True + ) + base_env.action_spec.return_value = { + 'action_type': dm_env.specs.Array(shape=(), dtype=np.int32) + } + base_env.stats.return_value = {'relaunch_count': 0} + base_env.execute_adb_call.return_value = _ok_response() + base_env.reset.return_value = dm_env.restart(observation={'dummy': 42}) + base_env.step.return_value = dm_env.transition( + observation={'dummy': 42}, reward=0.0 + ) + wrapped_env = a11y_grpc_wrapper.A11yGrpcWrapper( + base_env, add_latest_a11y_info_to_obs=True, a11y_info_timeout=1.0 + ) + with mock.patch.object(wrapped_env, '_fetch_task_extras'): + wrapped_env._fetch_task_extras.side_effect = [ + { + 'accessibility_tree': np.array( + empty_forest(), ndmin=1, dtype=object + ), + }, + ] + timestep = wrapped_env.reset() + self.assertIn('a11y_forest', timestep.observation) + self.assertEqual(timestep.observation['a11y_forest'], empty_forest()) + + @mock.patch.object( + a11y_pb2_grpc, 'add_A11yServiceServicer_to_server', autospec=True + ) + @mock.patch.object(grpc, 'server', autospec=True) + @mock.patch.object(time, 'sleep', autospec=True) + def test_a11y_info_with_timer_from_action( + self, unused_mock_server, unused_mock_add_servicer, mock_sleep + ): + base_env = mock.create_autospec( + env_interface.AndroidEnvInterface, instance=True + ) + base_env.action_spec.return_value = { + 'action_type': dm_env.specs.Array(shape=(), dtype=np.int32) + } + base_env.stats.return_value = {'relaunch_count': 0} + base_env.execute_adb_call.return_value = _ok_response() + base_env.reset.return_value = dm_env.restart(observation={'dummy': 42}) + base_env.step.return_value = dm_env.transition( + observation={'dummy': 42}, reward=0.0 + ) + wrapped_env = a11y_grpc_wrapper.A11yGrpcWrapper( + base_env, add_latest_a11y_info_to_obs=True, a11y_info_timeout=0.0 + ) + with mock.patch.object(wrapped_env, '_fetch_task_extras'): + wrapped_env._fetch_task_extras.side_effect = [ + { + 'accessibility_tree': np.array( + empty_forest(), ndmin=1, dtype=object + ), + }, + ] + timestep = wrapped_env.step(action={'wait_time': 1.0}) + self.assertIn('a11y_forest', timestep.observation) + mock_sleep.assert_called_once() + self.assertEqual(timestep.observation['a11y_forest'], empty_forest()) + + @mock.patch.object( + a11y_pb2_grpc, 'add_A11yServiceServicer_to_server', autospec=True + ) + @mock.patch.object(grpc, 'server', autospec=True) + def test_task_extras_same_between_calls(self, mock_server, mock_add_servicer): + del mock_server, mock_add_servicer + base_env = mock.create_autospec( + env_interface.AndroidEnvInterface, instance=True + ) + base_env.stats.return_value = {'relaunch_count': 0} + base_env.execute_adb_call.return_value = _ok_response() + wrapped_env = a11y_grpc_wrapper.A11yGrpcWrapper(base_env) + expected_task_extras = { + 'full_event': np.array(single_item_dict(), ndmin=1, dtype=object), + 'accessibility_tree': np.array(empty_forest(), ndmin=1, dtype=object), + } + with mock.patch.object(wrapped_env, '_fetch_task_extras'): + wrapped_env._fetch_task_extras.return_value = expected_task_extras + wrapped_env.reset() + task_extras = wrapped_env.task_extras(latest_only=False) + np.testing.assert_equal( + task_extras['full_event'], expected_task_extras['full_event'] + ) + np.testing.assert_equal( + task_extras['accessibility_tree'], + expected_task_extras['accessibility_tree'], + ) + + task_extras = wrapped_env.task_extras(latest_only=False) + np.testing.assert_equal( + task_extras['full_event'], expected_task_extras['full_event'] + ) + np.testing.assert_equal( + task_extras['accessibility_tree'], + expected_task_extras['accessibility_tree'], + ) + + expected_task_extras = { + 'full_event': np.array(empty_dict(), ndmin=1, dtype=object), + 'accessibility_tree': np.array( + one_window_two_nodes_forest(), ndmin=1, dtype=object + ), + } + with mock.patch.object(wrapped_env, '_fetch_task_extras'): + wrapped_env._fetch_task_extras.return_value = expected_task_extras + wrapped_env.step({}) + task_extras = wrapped_env.task_extras(latest_only=False) + np.testing.assert_equal( + task_extras['full_event'], expected_task_extras['full_event'] + ) + np.testing.assert_equal( + task_extras['accessibility_tree'], + expected_task_extras['accessibility_tree'], + ) + + task_extras = wrapped_env.task_extras(latest_only=False) + np.testing.assert_equal( + task_extras['full_event'], expected_task_extras['full_event'] + ) + np.testing.assert_equal( + task_extras['accessibility_tree'], + expected_task_extras['accessibility_tree'], + ) + + @mock.patch.object( + a11y_pb2_grpc, 'add_A11yServiceServicer_to_server', autospec=True + ) + @mock.patch.object(grpc, 'server', autospec=True) + def test_task_extras_clear_if_called_between_step( + self, mock_server, mock_add_servicer + ): + del mock_server, mock_add_servicer + base_env = mock.create_autospec( + env_interface.AndroidEnvInterface, instance=True + ) + base_env.stats.return_value = {'relaunch_count': 0} + base_env.execute_adb_call.return_value = _ok_response() + wrapped_env = a11y_grpc_wrapper.A11yGrpcWrapper(base_env) + with mock.patch.object(wrapped_env, '_fetch_task_extras'): + expected_task_extras = { + 'full_event': np.array(empty_dict(), ndmin=1, dtype=object), + 'accessibility_tree': np.array(empty_forest(), ndmin=1, dtype=object), + } + wrapped_env._fetch_task_extras.return_value = expected_task_extras + wrapped_env.reset() + task_extras = wrapped_env.task_extras(latest_only=False) + np.testing.assert_equal( + task_extras['full_event'], expected_task_extras['full_event'] + ) + np.testing.assert_equal( + task_extras['accessibility_tree'], + expected_task_extras['accessibility_tree'], + ) + + expected_task_extras = { + 'full_event': np.array(single_item_dict(), ndmin=1, dtype=object), + 'accessibility_tree': np.array(empty_forest(), ndmin=1, dtype=object), + } + wrapped_env._fetch_task_extras.return_value = expected_task_extras + wrapped_env.step({}) + task_extras = wrapped_env.task_extras(latest_only=False) + np.testing.assert_equal( + task_extras['full_event'], expected_task_extras['full_event'] + ) + np.testing.assert_equal( + task_extras['accessibility_tree'], + expected_task_extras['accessibility_tree'], + ) + expected_task_extras = { + 'full_event': np.array(empty_dict(), ndmin=1, dtype=object), + 'accessibility_tree': np.array( + one_window_two_nodes_forest(), ndmin=1, dtype=object + ), + } + wrapped_env._fetch_task_extras.return_value = expected_task_extras + wrapped_env.step({}) + task_extras = wrapped_env.task_extras(latest_only=False) + np.testing.assert_equal( + task_extras['full_event'], expected_task_extras['full_event'] + ) + np.testing.assert_equal( + task_extras['accessibility_tree'], + expected_task_extras['accessibility_tree'], + ) + + @parameterized.named_parameters( + ('none_true', False, False, False, 0), + ('only_install', True, False, False, 1), + ('only_start', False, True, False, 1), + ('only_enable_a11y_tree', False, False, True, 1), + ('install_and_start_no_a11y_tree', True, True, False, 2), + ('install_and_a11y_tree', True, False, True, 2), + ('start_and_a11y_tree', False, True, True, 2), + ('all_true', True, True, True, 3), + ) + @mock.patch.object(time, 'sleep', autospec=True) + def test_apk_install_and_start( + self, + install_a11y_forwarding: bool, + start_a11y_service: bool, + enable_a11y_tree_logs: bool, + expected_adb_calls: int, + unused_mock_sleep, + ): + base_env = mock.create_autospec( + env_interface.AndroidEnvInterface, instance=True + ) + + side_effects = [] + if install_a11y_forwarding: + side_effects.append(_ok_response()) # install response + if start_a11y_service: + side_effects.append(_ok_response()) # start service response + if enable_a11y_tree_logs: + side_effects.append(_ok_response()) # enable_tree_request + + base_env.execute_adb_call.side_effect = side_effects + + _ = a11y_grpc_wrapper.A11yGrpcWrapper( + base_env, + install_a11y_forwarding=install_a11y_forwarding, + start_a11y_service=start_a11y_service, + enable_a11y_tree_info=enable_a11y_tree_logs, + ) + self.assertEqual(base_env.execute_adb_call.call_count, expected_adb_calls) + + @mock.patch.object(time, 'sleep', autospec=True) + def test_component_and_start(self, unused_mock_sleep): + base_env = mock.create_autospec( + env_interface.AndroidEnvInterface, instance=True + ) + + side_effects = [] + side_effects.append(_ok_response()) # install response + side_effects.append(_ok_response()) # start service response + side_effects.append(_ok_response()) # enable_tree_request + + base_env.execute_adb_call.side_effect = side_effects + + _ = a11y_grpc_wrapper.A11yGrpcWrapper( + base_env, + install_a11y_forwarding=True, + start_a11y_service=True, + enable_a11y_tree_info=True, + ) + + # call_args returns a tuple of which the first member is a tuple containing + # the most recent args the mock was called with, and execute_adb_call only + # has one arg (so [0][0] to access the AdbRequest). + self.assertEqual( + base_env.execute_adb_call.call_args[0][0].send_broadcast.component, + 'com.google.androidenv.accessibilityforwarder/com.google.androidenv.accessibilityforwarder.FlagsBroadcastReceiver', + ) + + +if __name__ == '__main__': + absltest.main()