Skip to content

Commit

Permalink
Rebase fixes for service pool PR.
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisCummins committed Apr 20, 2022
1 parent 2210f50 commit 1a7016f
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 250 deletions.
238 changes: 0 additions & 238 deletions compiler_gym/envs/compiler_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,244 +65,6 @@ def observation_space_spec(self) -> ObservationSpaceSpec:
def observation_space_spec(
self, observation_space_spec: Optional[ObservationSpaceSpec]
):
<<<<<<< HEAD
=======
"""Construct and initialize a CompilerGym environment.

In normal use you should use :code:`gym.make(...)` rather than calling
the constructor directly.

:param service: The hostname and port of a service that implements the
CompilerGym service interface, or the path of a binary file which
provides the CompilerGym service interface when executed. See
:doc:`/compiler_gym/service` for details.

:param rewards: The reward spaces that this environment supports.
Rewards are typically calculated based on observations generated by
the service. See :class:`Reward <compiler_gym.spaces.Reward>` for
details.

:param benchmark: The benchmark to use for this environment. Either a
URI string, or a :class:`Benchmark
<compiler_gym.datasets.Benchmark>` instance. If not provided, the
first benchmark as returned by
:code:`next(env.datasets.benchmarks())` will be used as the default.

:param observation_space: Compute and return observations at each
:func:`step()` from this space. Accepts a string name or an
:class:`ObservationSpaceSpec
<compiler_gym.views.ObservationSpaceSpec>`. If not provided,
:func:`step()` returns :code:`None` for the observation value. Can
be set later using :meth:`env.observation_space
<compiler_gym.envs.CompilerEnv.observation_space>`. For available
spaces, see :class:`env.observation.spaces
<compiler_gym.views.ObservationView>`.

:param reward_space: Compute and return reward at each :func:`step()`
from this space. Accepts a string name or a :class:`Reward
<compiler_gym.spaces.Reward>`. If not provided, :func:`step()`
returns :code:`None` for the reward value. Can be set later using
:meth:`env.reward_space
<compiler_gym.envs.CompilerEnv.reward_space>`. For available spaces,
see :class:`env.reward.spaces <compiler_gym.views.RewardView>`.

:param action_space: The name of the action space to use. If not
specified, the default action space for this compiler is used.

:param derived_observation_spaces: An optional list of arguments to be
passed to :meth:`env.observation.add_derived_space()
<compiler_gym.views.observation.Observation.add_derived_space>`.

:param connection_settings: The settings used to establish a connection
with the remote service.

:param service_connection: An existing compiler gym service connection
to use.

:param service_pool: A service pool to use for acquiring a service
connection. If not specified, the :meth:`global service pool
<compiler_gym.service.ServiceConnectionPool.get>` is used.

:raises FileNotFoundError: If service is a path to a file that is not
found.

:raises TimeoutError: If the compiler service fails to initialize within
the parameters provided in :code:`connection_settings`.
"""
# NOTE(cummins): Logger argument deprecated and scheduled to be removed
# in release 0.2.3.
if logger:
warnings.warn(
"The `logger` argument is deprecated on CompilerEnv.__init__() "
"and will be removed in a future release. All CompilerEnv "
"instances share a logger named compiler_gym.envs.compiler_env",
DeprecationWarning,
)

self.metadata = {"render.modes": ["human", "ansi"]}

# A compiler service supports multiple simultaneous environments. This
# session ID is used to identify this environment.
self._session_id: Optional[int] = None

self._service_endpoint: Union[str, Path] = service
self._connection_settings = connection_settings or ConnectionOpts()

if service_connection is None:
self._service_pool: Optional[ServiceConnectionPool] = (
ServiceConnectionPool.get() if service_pool is None else service_pool
)
self.service = self._service_pool.acquire(
endpoint=self._service_endpoint,
opts=self._connection_settings,
)
else:
self._service_pool: Optional[ServiceConnectionPool] = service_pool
self.service = service_connection

self.datasets = Datasets(datasets or [])

self.action_space_name = action_space

# If no reward space is specified, generate some from numeric observation spaces
rewards = rewards or [
DefaultRewardFromObservation(obs.name)
for obs in self.service.observation_spaces
if obs.default_observation.WhichOneof("value")
and isinstance(
getattr(
obs.default_observation, obs.default_observation.WhichOneof("value")
),
numbers.Number,
)
]

# The benchmark that is currently being used, and the benchmark that
# will be used on the next call to reset(). These are equal except in
# the gap between the user setting the env.benchmark property while in
# an episode and the next call to env.reset().
self._benchmark_in_use: Optional[Benchmark] = None
self._benchmark_in_use_proto: BenchmarkProto = BenchmarkProto()
self._next_benchmark: Optional[Benchmark] = None
# Normally when the benchmark is changed the updated value is not
# reflected until the next call to reset(). We make an exception for the
# constructor-time benchmark as otherwise the behavior of the benchmark
# property is counter-intuitive:
#
# >>> env = gym.make("example-v0", benchmark="foo")
# >>> env.benchmark
# None
# >>> env.reset()
# >>> env.benchmark
# "foo"
#
# By forcing the _benchmark_in_use URI at constructor time, the first
# env.benchmark above returns the benchmark as expected.
try:
self.benchmark = benchmark or next(self.datasets.benchmarks())
self._benchmark_in_use = self._next_benchmark
except StopIteration:
# StopIteration raised on next(self.datasets.benchmarks()) if there
# are no benchmarks available. This is to allow CompilerEnv to be
# used without any datasets by setting a benchmark before/during the
# first reset() call.
pass

# Process the available action, observation, and reward spaces.
self.action_spaces = [
proto_to_action_space(space) for space in self.service.action_spaces
]

self.observation = self._observation_view_type(
raw_step=self.raw_step,
spaces=self.service.observation_spaces,
)
self.reward = self._reward_view_type(rewards, self.observation)

# Register any derived observation spaces now so that the observation
# space can be set below.
for derived_observation_space in derived_observation_spaces or []:
self.observation.add_derived_space_internal(**derived_observation_space)

# Lazily evaluated version strings.
self._versions: Optional[GetVersionReply] = None

self.action_space: Optional[Space] = None
self.observation_space: Optional[Space] = None

# Mutable state initialized in reset().
self.reward_range: Tuple[float, float] = (-np.inf, np.inf)
self.episode_reward: Optional[float] = None
self.episode_start_time: float = time()
self.actions: List[ActionType] = []

# Initialize the default observation/reward spaces.
self.observation_space_spec: Optional[ObservationSpaceSpec] = None
self.reward_space_spec: Optional[Reward] = None
self.observation_space = observation_space
self.reward_space = reward_space

@property
@deprecated(
    version="0.2.1",
    reason=(
        "The `CompilerEnv.logger` attribute is deprecated. All CompilerEnv "
        "instances share a logger named compiler_gym.envs.compiler_env"
    ),
)
def logger(self):
    """Deprecated accessor for the logger shared by all environments.

    Kept only for backwards compatibility; it simply hands back the
    ``_logger`` object referenced by this module.
    """
    shared_logger = _logger
    return shared_logger

@property
def versions(self) -> GetVersionReply:
    """Get the version numbers from the compiler service.

    The reply is requested from the service on first access and stored in
    ``self._versions``; later accesses return the stored reply.
    """
    reply = self._versions
    if reply is None:
        # First access: issue the GetVersion RPC through the service
        # connection and remember the answer.
        reply = self.service(self.service.stub.GetVersion, GetVersionRequest())
        self._versions = reply
    return reply

@property
def version(self) -> str:
    """The version string of the compiler service.

    Read from the ``service_version`` field of :attr:`versions`.
    """
    version_reply = self.versions
    return version_reply.service_version

@property
def compiler_version(self) -> str:
    """The version string of the underlying compiler that this service supports.

    Read from the ``compiler_version`` field of :attr:`versions`.
    """
    version_reply = self.versions
    return version_reply.compiler_version

def commandline(self) -> str:
    """Return a commandline invocation equivalent to the current environment state.

    This is an extension point for :class:`CompilerEnv
    <compiler_gym.envs.CompilerEnv>` subclasses; the base implementation
    always raises :code:`NotImplementedError`.

    See also :meth:`commandline_to_actions()
    <compiler_gym.envs.CompilerEnv.commandline_to_actions>`.

    :return: A string commandline invocation.

    :raises NotImplementedError: Always, unless overridden by a subclass.
    """
    # No generic commandline exists at this level; subclasses must override.
    raise NotImplementedError("abstract method")

def commandline_to_actions(self, commandline: str) -> List[ActionType]:
"""Interface for :class:`CompilerEnv <compiler_gym.envs.CompilerEnv>`
subclasses to convert from a commandline invocation to a sequence of
actions.

See also :meth:`commandline()
<compiler_gym.envs.CompilerEnv.commandline>`.

Calling this method on a :class:`CompilerEnv
<compiler_gym.envs.CompilerEnv>` instance raises
:code:`NotImplementedError`.

:return: A list of actions.
"""
>>>>>>> 4a874cee (Fix typo in docstring.)
raise NotImplementedError("abstract method")

@property
Expand Down
45 changes: 33 additions & 12 deletions compiler_gym/service/client_service_compiler_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from time import time
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union

from compiler_gym.service.connection_pool import ServiceConnectionPool
import numpy as np
from deprecated.sphinx import deprecated
from gym.spaces import Space
Expand All @@ -32,6 +31,10 @@
SessionNotFound,
)
from compiler_gym.service.connection import ServiceIsClosed
from compiler_gym.service.connection_pool import (
ServiceConnectionPool,
ServiceConnectionPoolBase,
)
from compiler_gym.service.proto import ActionSpace, AddBenchmarkRequest
from compiler_gym.service.proto import Benchmark as BenchmarkProto
from compiler_gym.service.proto import (
Expand Down Expand Up @@ -136,6 +139,7 @@ def __init__(
reward_space: Optional[Union[str, Reward]] = None,
action_space: Optional[str] = None,
derived_observation_spaces: Optional[List[Dict[str, Any]]] = None,
service_message_converters: ServiceMessageConverters = None,
connection_settings: Optional[ConnectionOpts] = None,
service_connection: Optional[CompilerGymServiceConnection] = None,
service_pool: Optional[ServiceConnectionPool] = None,
Expand Down Expand Up @@ -187,6 +191,9 @@ def __init__(
passed to :meth:`env.observation.add_derived_space()
<compiler_gym.views.observation.Observation.add_derived_space>`.
:param service_message_converters: Custom converters for action spaces
and actions.
:param connection_settings: The settings used to establish a connection
with the remote service.
Expand Down Expand Up @@ -234,7 +241,7 @@ def __init__(
self._service_pool: Optional[ServiceConnectionPoolBase] = service_pool
self.service = service_connection

self.datasets = Datasets(datasets or [])
self._datasets = Datasets(datasets or [])

self.action_space_name = action_space

Expand Down Expand Up @@ -282,9 +289,16 @@ def __init__(
# first reset() call.
pass

self.service_message_converters = (
ServiceMessageConverters()
if service_message_converters is None
else service_message_converters
)

# Process the available action, observation, and reward spaces.
self.action_spaces = [
proto_to_action_space(space) for space in self.service.action_spaces
self.service_message_converters.action_space_converter(space)
for space in self.service.action_spaces
]

self.observation = self._observation_view_type(
Expand All @@ -308,7 +322,7 @@ def __init__(
self.reward_range: Tuple[float, float] = (-np.inf, np.inf)
self.episode_reward: Optional[float] = None
self.episode_start_time: float = time()
self.actions: List[ActionType] = []
self._actions: List[ActionType] = []

# Initialize the default observation/reward spaces.
self.observation_space_spec: Optional[ObservationSpaceSpec] = None
Expand Down Expand Up @@ -544,10 +558,11 @@ def _init_kwargs(self) -> Dict[str, Any]:
"benchmark": self.benchmark,
"connection_settings": self._connection_settings,
"service": self._service_endpoint,
"service_pool": self._service_pool,
}

def fork(self) -> "ClientServiceCompilerEnv":
if not self.in_episode:
if not self.in_episode:
actions = self.actions.copy()
self.reset()
if actions:
Expand Down Expand Up @@ -603,7 +618,7 @@ def fork(self) -> "ClientServiceCompilerEnv":
# Copy over the mutable episode state.
new_env.episode_reward = self.episode_reward
new_env.episode_start_time = self.episode_start_time
new_env.actions = self.actions.copy()
new_env._actions = self.actions.copy() # pylint: disable=protected-access

return new_env

Expand Down Expand Up @@ -698,7 +713,7 @@ def reset( # pylint: disable=arguments-differ
does not have a default benchmark to select from.
"""

def _retry(error) -> Optional[ObservationType]:
def _retry(error) -> Optional[ObservationType]:
"""Abort and retry on error."""
# Log the error that we are recovering from, but treat
# ServiceIsClosed errors as unimportant since we know what causes
Expand Down Expand Up @@ -837,11 +852,13 @@ def _call_with_error(
self.observation.session_id = reply.session_id
self.reward.get_cost = self.observation.__getitem__
self.episode_start_time = time()
self.actions = []
self._actions: List[ActionType] = []

# If the action space has changed, update it.
if reply.HasField("new_action_space"):
self.action_space = proto_to_action_space(reply.new_action_space)
self.action_space = self.service_message_converters.action_space_converter(
reply.new_action_space
)

self.reward.reset(benchmark=self.benchmark, observation_view=self.observation)
if self.reward_space:
Expand Down Expand Up @@ -905,12 +922,14 @@ def raw_step(
}

# Record the actions.
self.actions += actions
self._actions += actions

# Send the request to the backend service.
request = StepRequest(
session_id=self._session_id,
action=[Event(int64_value=a) for a in actions],
action=[
self.service_message_converters.action_converter(a) for a in actions
],
observation_space=[
observation_space.index for observation_space in observations_to_compute
],
Expand Down Expand Up @@ -954,7 +973,9 @@ def raw_step(

# If the action space has changed, update it.
if reply.HasField("new_action_space"):
self.action_space = proto_to_action_space(reply.new_action_space)
self.action_space = self.service_message_converters.action_space_converter(
reply.new_action_space
)

# Translate observations to python representations.
if len(reply.observation) != len(observations_to_compute):
Expand Down

0 comments on commit 1a7016f

Please sign in to comment.