From 1a7016f659f6f637e505879968439b87cb6949b5 Mon Sep 17 00:00:00 2001 From: Chris Cummins Date: Tue, 19 Apr 2022 15:50:51 -0700 Subject: [PATCH] Rebase fixes for service pool PR. --- compiler_gym/envs/compiler_env.py | 238 ------------------ .../service/client_service_compiler_env.py | 45 +++- 2 files changed, 33 insertions(+), 250 deletions(-) diff --git a/compiler_gym/envs/compiler_env.py b/compiler_gym/envs/compiler_env.py index 58bd5ad4be..9fcdc59249 100644 --- a/compiler_gym/envs/compiler_env.py +++ b/compiler_gym/envs/compiler_env.py @@ -65,244 +65,6 @@ def observation_space_spec(self) -> ObservationSpaceSpec: def observation_space_spec( self, observation_space_spec: Optional[ObservationSpaceSpec] ): -<<<<<<< HEAD -======= - """Construct and initialize a CompilerGym environment. - - In normal use you should use :code:`gym.make(...)` rather than calling - the constructor directly. - - :param service: The hostname and port of a service that implements the - CompilerGym service interface, or the path of a binary file which - provides the CompilerGym service interface when executed. See - :doc:`/compiler_gym/service` for details. - - :param rewards: The reward spaces that this environment supports. - Rewards are typically calculated based on observations generated by - the service. See :class:`Reward ` for - details. - - :param benchmark: The benchmark to use for this environment. Either a - URI string, or a :class:`Benchmark - ` instance. If not provided, the - first benchmark as returned by - :code:`next(env.datasets.benchmarks())` will be used as the default. - - :param observation_space: Compute and return observations at each - :func:`step()` from this space. Accepts a string name or an - :class:`ObservationSpaceSpec - `. If not provided, - :func:`step()` returns :code:`None` for the observation value. Can - be set later using :meth:`env.observation_space - `. For available - spaces, see :class:`env.observation.spaces - `. 
- - :param reward_space: Compute and return reward at each :func:`step()` - from this space. Accepts a string name or a :class:`Reward - `. If not provided, :func:`step()` - returns :code:`None` for the reward value. Can be set later using - :meth:`env.reward_space - `. For available spaces, - see :class:`env.reward.spaces `. - - :param action_space: The name of the action space to use. If not - specified, the default action space for this compiler is used. - - :param derived_observation_spaces: An optional list of arguments to be - passed to :meth:`env.observation.add_derived_space() - `. - - :param connection_settings: The settings used to establish a connection - with the remote service. - - :param service_connection: An existing compiler gym service connection - to use. - - :param service_pool: A service pool to use for acquiring a service - connection. If not specified, the :meth:`global service pool - ` is used. - - :raises FileNotFoundError: If service is a path to a file that is not - found. - - :raises TimeoutError: If the compiler service fails to initialize within - the parameters provided in :code:`connection_settings`. - """ - # NOTE(cummins): Logger argument deprecated and scheduled to be removed - # in release 0.2.3. - if logger: - warnings.warn( - "The `logger` argument is deprecated on CompilerEnv.__init__() " - "and will be removed in a future release. All CompilerEnv " - "instances share a logger named compiler_gym.envs.compiler_env", - DeprecationWarning, - ) - - self.metadata = {"render.modes": ["human", "ansi"]} - - # A compiler service supports multiple simultaneous environments. This - # session ID is used to identify this environment. 
- self._session_id: Optional[int] = None - - self._service_endpoint: Union[str, Path] = service - self._connection_settings = connection_settings or ConnectionOpts() - - if service_connection is None: - self._service_pool: Optional[ServiceConnectionPool] = ( - ServiceConnectionPool.get() if service_pool is None else service_pool - ) - self.service = self._service_pool.acquire( - endpoint=self._service_endpoint, - opts=self._connection_settings, - ) - else: - self._service_pool: Optional[ServiceConnectionPool] = service_pool - self.service = service_connection - - self.datasets = Datasets(datasets or []) - - self.action_space_name = action_space - - # If no reward space is specified, generate some from numeric observation spaces - rewards = rewards or [ - DefaultRewardFromObservation(obs.name) - for obs in self.service.observation_spaces - if obs.default_observation.WhichOneof("value") - and isinstance( - getattr( - obs.default_observation, obs.default_observation.WhichOneof("value") - ), - numbers.Number, - ) - ] - - # The benchmark that is currently being used, and the benchmark that - # will be used on the next call to reset(). These are equal except in - # the gap between the user setting the env.benchmark property while in - # an episode and the next call to env.reset(). - self._benchmark_in_use: Optional[Benchmark] = None - self._benchmark_in_use_proto: BenchmarkProto = BenchmarkProto() - self._next_benchmark: Optional[Benchmark] = None - # Normally when the benchmark is changed the updated value is not - # reflected until the next call to reset(). We make an exception for the - # constructor-time benchmark as otherwise the behavior of the benchmark - # property is counter-intuitive: - # - # >>> env = gym.make("example-v0", benchmark="foo") - # >>> env.benchmark - # None - # >>> env.reset() - # >>> env.benchmark - # "foo" - # - # By forcing the _benchmark_in_use URI at constructor time, the first - # env.benchmark above returns the benchmark as expected. 
- try: - self.benchmark = benchmark or next(self.datasets.benchmarks()) - self._benchmark_in_use = self._next_benchmark - except StopIteration: - # StopIteration raised on next(self.datasets.benchmarks()) if there - # are no benchmarks available. This is to allow CompilerEnv to be - # used without any datasets by setting a benchmark before/during the - # first reset() call. - pass - - # Process the available action, observation, and reward spaces. - self.action_spaces = [ - proto_to_action_space(space) for space in self.service.action_spaces - ] - - self.observation = self._observation_view_type( - raw_step=self.raw_step, - spaces=self.service.observation_spaces, - ) - self.reward = self._reward_view_type(rewards, self.observation) - - # Register any derived observation spaces now so that the observation - # space can be set below. - for derived_observation_space in derived_observation_spaces or []: - self.observation.add_derived_space_internal(**derived_observation_space) - - # Lazily evaluated version strings. - self._versions: Optional[GetVersionReply] = None - - self.action_space: Optional[Space] = None - self.observation_space: Optional[Space] = None - - # Mutable state initialized in reset(). - self.reward_range: Tuple[float, float] = (-np.inf, np.inf) - self.episode_reward: Optional[float] = None - self.episode_start_time: float = time() - self.actions: List[ActionType] = [] - - # Initialize the default observation/reward spaces. - self.observation_space_spec: Optional[ObservationSpaceSpec] = None - self.reward_space_spec: Optional[Reward] = None - self.observation_space = observation_space - self.reward_space = reward_space - - @property - @deprecated( - version="0.2.1", - reason=( - "The `CompilerEnv.logger` attribute is deprecated. 
All CompilerEnv " - "instances share a logger named compiler_gym.envs.compiler_env" - ), - ) - def logger(self): - return _logger - - @property - def versions(self) -> GetVersionReply: - """Get the version numbers from the compiler service.""" - if self._versions is None: - self._versions = self.service( - self.service.stub.GetVersion, GetVersionRequest() - ) - return self._versions - - @property - def version(self) -> str: - """The version string of the compiler service.""" - return self.versions.service_version - - @property - def compiler_version(self) -> str: - """The version string of the underlying compiler that this service supports.""" - return self.versions.compiler_version - - def commandline(self) -> str: - """Interface for :class:`CompilerEnv <compiler_gym.envs.CompilerEnv>` - subclasses to provide an equivalent commandline invocation to the - current environment state. - - See also :meth:`commandline_to_actions() - <compiler_gym.envs.CompilerEnv.commandline_to_actions>`. - - Calling this method on a :class:`CompilerEnv - <compiler_gym.envs.CompilerEnv>` instance raises - :code:`NotImplementedError`. - - :return: A string commandline invocation. - """ - raise NotImplementedError("abstract method") - - def commandline_to_actions(self, commandline: str) -> List[ActionType]: - """Interface for :class:`CompilerEnv <compiler_gym.envs.CompilerEnv>` - subclasses to convert from a commandline invocation to a sequence of - actions. - - See also :meth:`commandline() - <compiler_gym.envs.CompilerEnv.commandline>`. - - Calling this method on a :class:`CompilerEnv - <compiler_gym.envs.CompilerEnv>` instance raises - :code:`NotImplementedError`. - - :return: A list of actions. - """ ->>>>>>> 4a874cee (Fix typo in docstring.) 
raise NotImplementedError("abstract method") @property diff --git a/compiler_gym/service/client_service_compiler_env.py b/compiler_gym/service/client_service_compiler_env.py index cbb8d31857..d2af1b53eb 100644 --- a/compiler_gym/service/client_service_compiler_env.py +++ b/compiler_gym/service/client_service_compiler_env.py @@ -14,7 +14,6 @@ from time import time from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union -from compiler_gym.service.connection_pool import ServiceConnectionPool import numpy as np from deprecated.sphinx import deprecated from gym.spaces import Space @@ -32,6 +31,10 @@ SessionNotFound, ) from compiler_gym.service.connection import ServiceIsClosed +from compiler_gym.service.connection_pool import ( + ServiceConnectionPool, + ServiceConnectionPoolBase, +) from compiler_gym.service.proto import ActionSpace, AddBenchmarkRequest from compiler_gym.service.proto import Benchmark as BenchmarkProto from compiler_gym.service.proto import ( @@ -136,6 +139,7 @@ def __init__( reward_space: Optional[Union[str, Reward]] = None, action_space: Optional[str] = None, derived_observation_spaces: Optional[List[Dict[str, Any]]] = None, + service_message_converters: ServiceMessageConverters = None, connection_settings: Optional[ConnectionOpts] = None, service_connection: Optional[CompilerGymServiceConnection] = None, service_pool: Optional[ServiceConnectionPool] = None, @@ -187,6 +191,9 @@ def __init__( passed to :meth:`env.observation.add_derived_space() `. + :param service_message_converters: Custom converters for action spaces + and actions. + :param connection_settings: The settings used to establish a connection with the remote service. 
@@ -234,7 +241,7 @@ def __init__( self._service_pool: Optional[ServiceConnectionPoolBase] = service_pool self.service = service_connection - self.datasets = Datasets(datasets or []) + self._datasets = Datasets(datasets or []) self.action_space_name = action_space @@ -282,9 +289,16 @@ def __init__( # first reset() call. pass + self.service_message_converters = ( + ServiceMessageConverters() + if service_message_converters is None + else service_message_converters + ) + # Process the available action, observation, and reward spaces. self.action_spaces = [ - proto_to_action_space(space) for space in self.service.action_spaces + self.service_message_converters.action_space_converter(space) + for space in self.service.action_spaces ] self.observation = self._observation_view_type( @@ -308,7 +322,7 @@ def __init__( self.reward_range: Tuple[float, float] = (-np.inf, np.inf) self.episode_reward: Optional[float] = None self.episode_start_time: float = time() - self.actions: List[ActionType] = [] + self._actions: List[ActionType] = [] # Initialize the default observation/reward spaces. self.observation_space_spec: Optional[ObservationSpaceSpec] = None @@ -544,10 +558,11 @@ def _init_kwargs(self) -> Dict[str, Any]: "benchmark": self.benchmark, "connection_settings": self._connection_settings, "service": self._service_endpoint, + "service_pool": self._service_pool, } def fork(self) -> "ClientServiceCompilerEnv": - if not self.in_episode: + if not self.in_episode: actions = self.actions.copy() self.reset() if actions: @@ -603,7 +618,7 @@ def fork(self) -> "ClientServiceCompilerEnv": # Copy over the mutable episode state. new_env.episode_reward = self.episode_reward new_env.episode_start_time = self.episode_start_time - new_env.actions = self.actions.copy() + new_env._actions = self.actions.copy() # pylint: disable=protected-access return new_env @@ -698,7 +713,7 @@ def reset( # pylint: disable=arguments-differ does not have a default benchmark to select from. 
""" - def _retry(error) -> Optional[ObservationType]: + def _retry(error) -> Optional[ObservationType]: """Abort and retry on error.""" # Log the error that we are recovering from, but treat # ServiceIsClosed errors as unimportant since we know what causes @@ -837,11 +852,13 @@ def _call_with_error( self.observation.session_id = reply.session_id self.reward.get_cost = self.observation.__getitem__ self.episode_start_time = time() - self.actions = [] + self._actions: List[ActionType] = [] # If the action space has changed, update it. if reply.HasField("new_action_space"): - self.action_space = proto_to_action_space(reply.new_action_space) + self.action_space = self.service_message_converters.action_space_converter( + reply.new_action_space + ) self.reward.reset(benchmark=self.benchmark, observation_view=self.observation) if self.reward_space: @@ -905,12 +922,14 @@ def raw_step( } # Record the actions. - self.actions += actions + self._actions += actions # Send the request to the backend service. request = StepRequest( session_id=self._session_id, - action=[Event(int64_value=a) for a in actions], + action=[ + self.service_message_converters.action_converter(a) for a in actions + ], observation_space=[ observation_space.index for observation_space in observations_to_compute ], @@ -954,7 +973,9 @@ def raw_step( # If the action space has changed, update it. if reply.HasField("new_action_space"): - self.action_space = proto_to_action_space(reply.new_action_space) + self.action_space = self.service_message_converters.action_space_converter( + reply.new_action_space + ) # Translate observations to python representations. if len(reply.observation) != len(observations_to_compute):