diff --git a/android_env/components/coordinator.py b/android_env/components/coordinator.py index f1ddc98..21af0b6 100644 --- a/android_env/components/coordinator.py +++ b/android_env/components/coordinator.py @@ -27,9 +27,9 @@ from android_env.components import adb_call_parser from android_env.components import config_classes from android_env.components import errors +from android_env.components import pixel_fns from android_env.components import specs from android_env.components import task_manager as task_manager_lib -from android_env.components import utils from android_env.components.simulators import base_simulator from android_env.proto import adb_pb2 from android_env.proto import state_pb2 @@ -92,7 +92,8 @@ def action_spec(self) -> dict[str, dm_env.specs.Array]: def observation_spec(self) -> dict[str, dm_env.specs.Array]: return specs.base_observation_spec( - height=self._screen_size[0], width=self._screen_size[1]) + height=self._screen_size[0], width=self._screen_size[1] + ) def _update_screen_size(self) -> None: """Sets the screen size from a screenshot ignoring the color channel.""" @@ -109,7 +110,9 @@ def _update_device_orientation(self) -> None: orientation_response = self._adb_call_parser.parse( adb_pb2.AdbRequest( - get_orientation=adb_pb2.AdbRequest.GetOrientationRequest())) + get_orientation=adb_pb2.AdbRequest.GetOrientationRequest() + ) + ) if orientation_response.status != adb_pb2.AdbResponse.Status.OK: logging.error('Got bad orientation: %r', orientation_response) return @@ -180,7 +183,8 @@ def _launch_simulator(self, max_retries: int = 3): while True: if num_tries > max_retries: raise errors.TooManyRestartsError( - 'Maximum number of restart attempts reached.') from latest_error + 'Maximum number of restart attempts reached.' + ) from latest_error logging.info('Simulator launch attempt %d of %d', num_tries, max_retries) self._task_manager.stop() @@ -263,7 +267,11 @@ def _update_settings(self) -> None: settings=adb_pb2.AdbRequest.SettingsRequest( name_space=adb_pb2.AdbRequest.SettingsRequest.Namespace.GLOBAL, put=adb_pb2.AdbRequest.SettingsRequest.Put( - key='policy_control', value=policy_control_value)))) + key='policy_control', value=policy_control_value + ), + ) + ) + ) def _create_adb_call_parser(self): """Creates a new AdbCallParser instance.""" @@ -332,8 +340,11 @@ def _gather_simulator_signals(self) -> dict[str, np.ndarray]: # Get current timestamp and update the delta. now = time.time() - timestamp_delta = (0 if self._latest_observation_time == 0 else - (now - self._latest_observation_time) * 1e6) + timestamp_delta = ( + 0 + if self._latest_observation_time == 0 + else (now - self._latest_observation_time) * 1e6 + ) self._latest_observation_time = now # Grab pixels. @@ -409,10 +420,12 @@ def _prepare_touch_action( width_height = self._screen_size[::-1] for i, finger_action in enumerate(self._split_touch_action(action)): is_touch = ( - finger_action['action_type'] == action_type_lib.ActionType.TOUCH) + finger_action['action_type'] == action_type_lib.ActionType.TOUCH + ) touch_position = finger_action['touch_position'] - touch_pixels = utils.touch_position_to_pixel_position( - touch_position, width_height=width_height) + touch_pixels = pixel_fns.touch_position_to_pixel_position( + touch_position, width_height=width_height + ) touch_events.append((touch_pixels[0], touch_pixels[1], is_touch, i)) return touch_events @@ -495,8 +508,9 @@ def close(self): class InteractionThread(threading.Thread): """A thread that interacts with a simulator.""" - def __init__(self, simulator: base_simulator.BaseSimulator, - interaction_rate_sec: float): + def __init__( + self, simulator: base_simulator.BaseSimulator, interaction_rate_sec: float + ): super().__init__() self._simulator = simulator self._interaction_rate_sec = interaction_rate_sec diff --git a/android_env/components/utils.py b/android_env/components/pixel_fns.py similarity index 100% rename from android_env/components/utils.py rename to android_env/components/pixel_fns.py diff --git a/android_env/components/utils_test.py b/android_env/components/pixel_fns_test.py similarity index 80% rename from android_env/components/utils_test.py rename to android_env/components/pixel_fns_test.py index 8a7d0cf..8202430 100644 --- a/android_env/components/utils_test.py +++ b/android_env/components/pixel_fns_test.py @@ -13,11 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for android_env.components.utils.""" +"""Tests for pixel_fns.""" from absl.testing import absltest from absl.testing import parameterized -from android_env.components import utils +from android_env.components import pixel_fns from dm_env import specs import numpy as np @@ -32,38 +32,42 @@ class UtilsTest(parameterized.TestCase): ) def test_touch_position_to_pixel_position( self, touch_pos, width_height, pixel_pos): - self.assertEqual(utils.touch_position_to_pixel_position( - np.array(touch_pos), width_height), pixel_pos) + self.assertEqual( + pixel_fns.touch_position_to_pixel_position( + np.array(touch_pos), width_height + ), + pixel_pos, + ) def test_transpose_pixels(self): image = np.reshape(np.array(range(12)), (3, 2, 2)) expected = [[[0, 1], [4, 5], [8, 9]], [[2, 3], [6, 7], [10, 11]]] - self.assertEqual(utils.transpose_pixels(image).shape, (2, 3, 2)) - self.assertTrue((utils.transpose_pixels(image) == expected).all()) + self.assertEqual(pixel_fns.transpose_pixels(image).shape, (2, 3, 2)) + self.assertTrue((pixel_fns.transpose_pixels(image) == expected).all()) def test_orient_pixels(self): image = np.reshape(np.array(range(12)), (3, 2, 2)) expected_90 = [[[8, 9], [4, 5], [0, 1]], [[10, 11], [6, 7], [2, 3]]] rot_90 = 1 # LANDSCAPE_90 - rotated = utils.orient_pixels(image, rot_90) + rotated = pixel_fns.orient_pixels(image, rot_90) self.assertEqual(rotated.shape, (2, 3, 2)) self.assertTrue((rotated == expected_90).all()) expected_180 = [[[10, 11], [8, 9]], [[6, 7], [4, 5]], [[2, 3], [0, 1]]] rot_180 = 2 # PORTRAIT_180 - rotated = utils.orient_pixels(image, rot_180) + rotated = pixel_fns.orient_pixels(image, rot_180) self.assertEqual(rotated.shape, (3, 2, 2)) self.assertTrue((rotated == expected_180).all()) expected_270 = [[[2, 3], [6, 7], [10, 11]], [[0, 1], [4, 5], [8, 9]]] rot_270 = 3 # LANDSCAPE_270 - rotated = utils.orient_pixels(image, rot_270) + rotated = pixel_fns.orient_pixels(image, rot_270) self.assertEqual(rotated.shape, (2, 3, 2)) self.assertTrue((rotated == expected_270).all()) rot_0 = 0 # PORTRAIT_0 - rotated = utils.orient_pixels(image, rot_0) + rotated = pixel_fns.orient_pixels(image, rot_0) self.assertEqual(rotated.shape, (3, 2, 2)) self.assertTrue((rotated == image).all()) @@ -75,7 +79,7 @@ def test_convert_int_to_float_bounded_array(self): maximum=[5, 5, 20, 2], name='bounded_array') data = np.array([2, 2, 10, 0], dtype=np.int32) - float_data = utils.convert_int_to_float(data, spec) + float_data = pixel_fns.convert_int_to_float(data, spec) np.testing.assert_equal( np.array([2.0 / 5.0, 1.0 / 4.0, 0.0, 0.5], dtype=np.float32), float_data ) @@ -84,7 +88,7 @@ def test_convert_int_to_float_bounded_array_broadcast(self): spec = specs.BoundedArray( shape=(3,), dtype=np.int16, minimum=2, maximum=4, name='bounded_array') data = np.array([2, 3, 4], dtype=np.int16) - float_data = utils.convert_int_to_float(data, spec) + float_data = pixel_fns.convert_int_to_float(data, spec) np.testing.assert_equal( np.array([0.0, 0.5, 1.0], dtype=np.float32), float_data) @@ -94,7 +98,7 @@ def test_convert_int_to_float_no_bounds(self): dtype=np.int8, # int8 implies min=-128, max=127 name='bounded_array') data = np.array([-128, 0, 127], dtype=np.int16) - float_data = utils.convert_int_to_float(data, spec) + float_data = pixel_fns.convert_int_to_float(data, spec) np.testing.assert_equal( np.array([0.0, 128. / 255., 1.0], dtype=np.float32), float_data) diff --git a/android_env/wrappers/float_pixels_wrapper.py b/android_env/wrappers/float_pixels_wrapper.py index 9617207..43def1f 100644 --- a/android_env/wrappers/float_pixels_wrapper.py +++ b/android_env/wrappers/float_pixels_wrapper.py @@ -15,7 +15,7 @@ """Converts pixel observation to from int to float32 between 0.0 and 1.0.""" -from android_env.components import utils +from android_env.components import pixel_fns from android_env.wrappers import base_wrapper import dm_env from dm_env import specs @@ -35,7 +35,7 @@ def _process_observation( self, observation: dict[str, np.ndarray] ) -> dict[str, np.ndarray]: if self._should_convert_int_to_float: - float_pixels = utils.convert_int_to_float( + float_pixels = pixel_fns.convert_int_to_float( observation['pixels'], self._input_spec ) observation['pixels'] = float_pixels diff --git a/android_env/wrappers/last_action_wrapper.py b/android_env/wrappers/last_action_wrapper.py index 33bfd4b..a09633c 100644 --- a/android_env/wrappers/last_action_wrapper.py +++ b/android_env/wrappers/last_action_wrapper.py @@ -16,7 +16,7 @@ """Extends Android observation with the latest action taken.""" from android_env.components import action_type -from android_env.components import utils +from android_env.components import pixel_fns from android_env.wrappers import base_wrapper import dm_env from dm_env import specs @@ -76,8 +76,9 @@ def _get_last_action_layer(self, pixels: np.ndarray) -> np.ndarray: if ('action_type' in last_action and last_action['action_type'] == action_type.ActionType.TOUCH): touch_position = last_action['touch_position'] - x, y = utils.touch_position_to_pixel_position( - touch_position, width_height=self._screen_dimensions[::-1]) + x, y = pixel_fns.touch_position_to_pixel_position( + touch_position, width_height=self._screen_dimensions[::-1] + ) last_action_layer[y, x] = 255 return last_action_layer diff --git a/examples/run_human_agent.py b/examples/run_human_agent.py index 66c944b..505a729 100644 --- a/examples/run_human_agent.py +++ b/examples/run_human_agent.py @@ -24,7 +24,7 @@ from android_env import loader from android_env.components import action_type from android_env.components import config_classes -from android_env.components import utils +from android_env.components import pixel_fns import dm_env import numpy as np import pygame @@ -116,8 +116,8 @@ def _render_pygame_frame(surface: pygame.Surface, screen: pygame.Surface, """Displays latest observation on pygame surface.""" frame = timestep.observation['pixels'][:, :, :3] # (H x W x C) (RGB) - frame = utils.transpose_pixels(frame) # (W x H x C) - frame = utils.orient_pixels(frame, orientation) + frame = pixel_fns.transpose_pixels(frame) # (W x H x C) + frame = pixel_fns.orient_pixels(frame, orientation) pygame.surfarray.blit_array(surface, frame) pygame.transform.smoothscale(surface, screen.get_size(), screen)