Skip to content

Commit

Permalink
Rename components/utils.py to pixel_fns.py.
Browse files Browse the repository at this point in the history
The functions in this file are only related to manipulating screenshots, and
the name `utils.py` is uninformative. `pixel_fns.py` is short and is easier to
understand.

PiperOrigin-RevId: 669374500
  • Loading branch information
kenjitoyama authored and copybara-github committed Sep 3, 2024
1 parent dd539fe commit 6d286f2
Show file tree
Hide file tree
Showing 6 changed files with 52 additions and 33 deletions.
38 changes: 26 additions & 12 deletions android_env/components/coordinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@
from android_env.components import adb_call_parser
from android_env.components import config_classes
from android_env.components import errors
from android_env.components import pixel_fns
from android_env.components import specs
from android_env.components import task_manager as task_manager_lib
from android_env.components import utils
from android_env.components.simulators import base_simulator
from android_env.proto import adb_pb2
from android_env.proto import state_pb2
Expand Down Expand Up @@ -92,7 +92,8 @@ def action_spec(self) -> dict[str, dm_env.specs.Array]:

def observation_spec(self) -> dict[str, dm_env.specs.Array]:
return specs.base_observation_spec(
height=self._screen_size[0], width=self._screen_size[1])
height=self._screen_size[0], width=self._screen_size[1]
)

def _update_screen_size(self) -> None:
"""Sets the screen size from a screenshot ignoring the color channel."""
Expand All @@ -109,7 +110,9 @@ def _update_device_orientation(self) -> None:

orientation_response = self._adb_call_parser.parse(
adb_pb2.AdbRequest(
get_orientation=adb_pb2.AdbRequest.GetOrientationRequest()))
get_orientation=adb_pb2.AdbRequest.GetOrientationRequest()
)
)
if orientation_response.status != adb_pb2.AdbResponse.Status.OK:
logging.error('Got bad orientation: %r', orientation_response)
return
Expand Down Expand Up @@ -180,7 +183,8 @@ def _launch_simulator(self, max_retries: int = 3):
while True:
if num_tries > max_retries:
raise errors.TooManyRestartsError(
'Maximum number of restart attempts reached.') from latest_error
'Maximum number of restart attempts reached.'
) from latest_error
logging.info('Simulator launch attempt %d of %d', num_tries, max_retries)

self._task_manager.stop()
Expand Down Expand Up @@ -263,7 +267,11 @@ def _update_settings(self) -> None:
settings=adb_pb2.AdbRequest.SettingsRequest(
name_space=adb_pb2.AdbRequest.SettingsRequest.Namespace.GLOBAL,
put=adb_pb2.AdbRequest.SettingsRequest.Put(
key='policy_control', value=policy_control_value))))
key='policy_control', value=policy_control_value
),
)
)
)

def _create_adb_call_parser(self):
"""Creates a new AdbCallParser instance."""
Expand Down Expand Up @@ -332,8 +340,11 @@ def _gather_simulator_signals(self) -> dict[str, np.ndarray]:

# Get current timestamp and update the delta.
now = time.time()
timestamp_delta = (0 if self._latest_observation_time == 0 else
(now - self._latest_observation_time) * 1e6)
timestamp_delta = (
0
if self._latest_observation_time == 0
else (now - self._latest_observation_time) * 1e6
)
self._latest_observation_time = now

# Grab pixels.
Expand Down Expand Up @@ -409,10 +420,12 @@ def _prepare_touch_action(
width_height = self._screen_size[::-1]
for i, finger_action in enumerate(self._split_touch_action(action)):
is_touch = (
finger_action['action_type'] == action_type_lib.ActionType.TOUCH)
finger_action['action_type'] == action_type_lib.ActionType.TOUCH
)
touch_position = finger_action['touch_position']
touch_pixels = utils.touch_position_to_pixel_position(
touch_position, width_height=width_height)
touch_pixels = pixel_fns.touch_position_to_pixel_position(
touch_position, width_height=width_height
)
touch_events.append((touch_pixels[0], touch_pixels[1], is_touch, i))
return touch_events

Expand Down Expand Up @@ -495,8 +508,9 @@ def close(self):
class InteractionThread(threading.Thread):
"""A thread that interacts with a simulator."""

def __init__(self, simulator: base_simulator.BaseSimulator,
interaction_rate_sec: float):
def __init__(
self, simulator: base_simulator.BaseSimulator, interaction_rate_sec: float
):
super().__init__()
self._simulator = simulator
self._interaction_rate_sec = interaction_rate_sec
Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for android_env.components.utils."""
"""Tests for pixel_fns."""

from absl.testing import absltest
from absl.testing import parameterized
from android_env.components import utils
from android_env.components import pixel_fns
from dm_env import specs
import numpy as np

Expand All @@ -32,38 +32,42 @@ class UtilsTest(parameterized.TestCase):
)
def test_touch_position_to_pixel_position(
self, touch_pos, width_height, pixel_pos):
self.assertEqual(utils.touch_position_to_pixel_position(
np.array(touch_pos), width_height), pixel_pos)
self.assertEqual(
pixel_fns.touch_position_to_pixel_position(
np.array(touch_pos), width_height
),
pixel_pos,
)

def test_transpose_pixels(self):
image = np.reshape(np.array(range(12)), (3, 2, 2))
expected = [[[0, 1], [4, 5], [8, 9]], [[2, 3], [6, 7], [10, 11]]]
self.assertEqual(utils.transpose_pixels(image).shape, (2, 3, 2))
self.assertTrue((utils.transpose_pixels(image) == expected).all())
self.assertEqual(pixel_fns.transpose_pixels(image).shape, (2, 3, 2))
self.assertTrue((pixel_fns.transpose_pixels(image) == expected).all())

def test_orient_pixels(self):
image = np.reshape(np.array(range(12)), (3, 2, 2))

expected_90 = [[[8, 9], [4, 5], [0, 1]], [[10, 11], [6, 7], [2, 3]]]
rot_90 = 1 # LANDSCAPE_90
rotated = utils.orient_pixels(image, rot_90)
rotated = pixel_fns.orient_pixels(image, rot_90)
self.assertEqual(rotated.shape, (2, 3, 2))
self.assertTrue((rotated == expected_90).all())

expected_180 = [[[10, 11], [8, 9]], [[6, 7], [4, 5]], [[2, 3], [0, 1]]]
rot_180 = 2 # PORTRAIT_180
rotated = utils.orient_pixels(image, rot_180)
rotated = pixel_fns.orient_pixels(image, rot_180)
self.assertEqual(rotated.shape, (3, 2, 2))
self.assertTrue((rotated == expected_180).all())

expected_270 = [[[2, 3], [6, 7], [10, 11]], [[0, 1], [4, 5], [8, 9]]]
rot_270 = 3 # LANDSCAPE_270
rotated = utils.orient_pixels(image, rot_270)
rotated = pixel_fns.orient_pixels(image, rot_270)
self.assertEqual(rotated.shape, (2, 3, 2))
self.assertTrue((rotated == expected_270).all())

rot_0 = 0 # PORTRAIT_0
rotated = utils.orient_pixels(image, rot_0)
rotated = pixel_fns.orient_pixels(image, rot_0)
self.assertEqual(rotated.shape, (3, 2, 2))
self.assertTrue((rotated == image).all())

Expand All @@ -75,7 +79,7 @@ def test_convert_int_to_float_bounded_array(self):
maximum=[5, 5, 20, 2],
name='bounded_array')
data = np.array([2, 2, 10, 0], dtype=np.int32)
float_data = utils.convert_int_to_float(data, spec)
float_data = pixel_fns.convert_int_to_float(data, spec)
np.testing.assert_equal(
np.array([2.0 / 5.0, 1.0 / 4.0, 0.0, 0.5], dtype=np.float32), float_data
)
Expand All @@ -84,7 +88,7 @@ def test_convert_int_to_float_bounded_array_broadcast(self):
spec = specs.BoundedArray(
shape=(3,), dtype=np.int16, minimum=2, maximum=4, name='bounded_array')
data = np.array([2, 3, 4], dtype=np.int16)
float_data = utils.convert_int_to_float(data, spec)
float_data = pixel_fns.convert_int_to_float(data, spec)
np.testing.assert_equal(
np.array([0.0, 0.5, 1.0], dtype=np.float32), float_data)

Expand All @@ -94,7 +98,7 @@ def test_convert_int_to_float_no_bounds(self):
dtype=np.int8, # int8 implies min=-128, max=127
name='bounded_array')
data = np.array([-128, 0, 127], dtype=np.int16)
float_data = utils.convert_int_to_float(data, spec)
float_data = pixel_fns.convert_int_to_float(data, spec)
np.testing.assert_equal(
np.array([0.0, 128. / 255., 1.0], dtype=np.float32), float_data)

Expand Down
4 changes: 2 additions & 2 deletions android_env/wrappers/float_pixels_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

"""Converts pixel observation to from int to float32 between 0.0 and 1.0."""

from android_env.components import utils
from android_env.components import pixel_fns
from android_env.wrappers import base_wrapper
import dm_env
from dm_env import specs
Expand All @@ -35,7 +35,7 @@ def _process_observation(
self, observation: dict[str, np.ndarray]
) -> dict[str, np.ndarray]:
if self._should_convert_int_to_float:
float_pixels = utils.convert_int_to_float(
float_pixels = pixel_fns.convert_int_to_float(
observation['pixels'], self._input_spec
)
observation['pixels'] = float_pixels
Expand Down
7 changes: 4 additions & 3 deletions android_env/wrappers/last_action_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
"""Extends Android observation with the latest action taken."""

from android_env.components import action_type
from android_env.components import utils
from android_env.components import pixel_fns
from android_env.wrappers import base_wrapper
import dm_env
from dm_env import specs
Expand Down Expand Up @@ -76,8 +76,9 @@ def _get_last_action_layer(self, pixels: np.ndarray) -> np.ndarray:
if ('action_type' in last_action and
last_action['action_type'] == action_type.ActionType.TOUCH):
touch_position = last_action['touch_position']
x, y = utils.touch_position_to_pixel_position(
touch_position, width_height=self._screen_dimensions[::-1])
x, y = pixel_fns.touch_position_to_pixel_position(
touch_position, width_height=self._screen_dimensions[::-1]
)
last_action_layer[y, x] = 255

return last_action_layer
Expand Down
6 changes: 3 additions & 3 deletions examples/run_human_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from android_env import loader
from android_env.components import action_type
from android_env.components import config_classes
from android_env.components import utils
from android_env.components import pixel_fns
import dm_env
import numpy as np
import pygame
Expand Down Expand Up @@ -116,8 +116,8 @@ def _render_pygame_frame(surface: pygame.Surface, screen: pygame.Surface,
"""Displays latest observation on pygame surface."""

frame = timestep.observation['pixels'][:, :, :3] # (H x W x C) (RGB)
frame = utils.transpose_pixels(frame) # (W x H x C)
frame = utils.orient_pixels(frame, orientation)
frame = pixel_fns.transpose_pixels(frame) # (W x H x C)
frame = pixel_fns.orient_pixels(frame, orientation)

pygame.surfarray.blit_array(surface, frame)
pygame.transform.smoothscale(surface, screen.get_size(), screen)
Expand Down

0 comments on commit 6d286f2

Please sign in to comment.