diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index ac836fe9e..8bd8db03c 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -59,7 +59,6 @@ jobs: else wget https://github.com/llvm/llvm-project/releases/download/llvmorg-10.0.0/clang+llvm-10.0.0-x86_64-linux-gnu-ubuntu-18.04.tar.xz -O ~/llvm.tar.xz fi - # TODO(cummins): Remove 'v' debugging flag: mkdir ~/llvm && tar xvf ~/llvm.tar.xz --strip-components 1 -C ~/llvm rm ~/llvm.tar.xz echo "Unpacked, testing for expected file:" diff --git a/benchmarks/bench_test.py b/benchmarks/bench_test.py index 497f73e53..373e30491 100644 --- a/benchmarks/bench_test.py +++ b/benchmarks/bench_test.py @@ -20,9 +20,10 @@ import gym import pytest -import examples.example_compiler_gym_service as dummy +import examples.example_compiler_gym_service # noqa Environment import. +import examples.example_compiler_gym_service as dummy # noqa Environment import. from compiler_gym.envs import CompilerEnv, LlvmEnv, llvm -from compiler_gym.service import CompilerGymServiceConnection +from compiler_gym.service import CompilerGymServiceConnection, ConnectionOpts from compiler_gym.service.client_service_compiler_env import ClientServiceCompilerEnv from tests.pytest_plugins.llvm import OBSERVATION_SPACE_NAMES, REWARD_SPACE_NAMES from tests.test_main import main @@ -46,10 +47,10 @@ def env(request) -> CompilerEnv: @pytest.mark.parametrize( "env_id", - ["llvm-v0", "example-cc-v0", "example-py-v0"], - ids=["llvm", "dummy-cc", "dummy-py"], + ["llvm-v0", "example-cc-v0", "example-py-v0", "loop_tool-v0"], + ids=["llvm", "dummy-cc", "dummy-py", "loop_tool"], ) -def test_make_local(benchmark, env_id): +def test_make_env(benchmark, env_id): benchmark(lambda: gym.make(env_id).close()) @@ -64,7 +65,7 @@ def test_make_local(benchmark, env_id): ) def test_make_service(benchmark, args): service_binary, env_class = args - service = CompilerGymServiceConnection(service_binary) + service = CompilerGymServiceConnection(service_binary, ConnectionOpts()) try: benchmark(lambda: env_class(service=service.connection.url).close()) finally: diff --git a/compiler_gym/bin/service.py b/compiler_gym/bin/service.py index f49e9f7e6..9bdfd1c97 100644 --- a/compiler_gym/bin/service.py +++ b/compiler_gym/bin/service.py @@ -266,7 +266,9 @@ def main(argv): if FLAGS.run_on_port: assert FLAGS.env, "Must specify an --env to run" - settings = ConnectionOpts(script_args=["--port", str(FLAGS.run_on_port)]) + settings = ConnectionOpts( + script_args=frozenset(["--port", str(FLAGS.run_on_port)]) + ) with gym.make(FLAGS.env, connection_settings=settings) as env: print( f"=== Started a service on port {FLAGS.run_on_port}. Use C-c to terminate. ===" diff --git a/compiler_gym/envs/gcc/__init__.py b/compiler_gym/envs/gcc/__init__.py index a17c4f95b..c55fd8ccf 100644 --- a/compiler_gym/envs/gcc/__init__.py +++ b/compiler_gym/envs/gcc/__init__.py @@ -16,7 +16,7 @@ register( id="gcc-v0", - entry_point="compiler_gym.envs.gcc:GccEnv", + entry_point="compiler_gym.envs.gcc.gcc_env:make", kwargs={"service": GCC_SERVICE_BINARY}, ) diff --git a/compiler_gym/envs/gcc/gcc_env.py b/compiler_gym/envs/gcc/gcc_env.py index fa77837a1..245446659 100644 --- a/compiler_gym/envs/gcc/gcc_env.py +++ b/compiler_gym/envs/gcc/gcc_env.py @@ -7,6 +7,7 @@ import json import pickle from pathlib import Path +from threading import Lock from typing import Any, Dict, List, Optional, Union from compiler_gym.datasets import Benchmark @@ -15,6 +16,7 @@ from compiler_gym.envs.gcc.gcc_rewards import AsmSizeReward, ObjSizeReward from compiler_gym.service import ConnectionOpts from compiler_gym.service.client_service_compiler_env import ClientServiceCompilerEnv +from compiler_gym.service.connection_pool import ServiceConnectionPoolBase from compiler_gym.spaces import Reward from compiler_gym.util.decorators import memoized_property from compiler_gym.util.gym_type_hints import ObservationType, OptionalArgumentValue @@ -63,9 +65,11 @@ def __init__( :raises ServiceInitError: If the requested GCC version cannot be used. """ - connection_settings = connection_settings or ConnectionOpts() # Pass the executable path via an environment variable - connection_settings.script_env = {"CC": gcc_bin} + connection_settings = connection_settings or ConnectionOpts() + connection_settings.script_env = connection_settings.script_env.set( + "CC", gcc_bin + ) # Eagerly create a GCC compiler instance now because: # @@ -76,6 +80,13 @@ def __init__( # initialization may time out. Gcc(bin=gcc_bin) + # NOTE(github.com/facebookresearch/CompilerGym/pull/583): The GCC + # environment stalls on the StartSession() RPC call when service + # connection caching is enabled. I believe this has something to do with + # the runtime code generation, but have not been able to diagnose it + # yet. For now, disable service connection caching for GCC environments. + kwargs["service_pool"] = ServiceConnectionPoolBase() + super().__init__( *args, **kwargs, @@ -88,6 +99,9 @@ def __init__( ) self._timeout = timeout + def commandline_to_actions(self, commandline: str) -> List[int]: + return NotImplementedError + def reset( self, benchmark: Optional[Union[str, Benchmark]] = None, @@ -213,3 +227,19 @@ def _init_kwargs(self) -> Dict[str, Any]: "gcc_bin": self.gcc_spec.gcc.bin, **super()._init_kwargs(), } + + +_GCC_ENV_DOCKER_CONSTRUCTOR_LOCK = Lock() + + +def make(*args, gcc_bin: Union[str, Path] = DEFAULT_GCC, **kwargs): + """Construct a GccEnv class using a lock to ensure thread exclusivity. + + This is to prevent multiple threads running the docker initialization + routines simultaneously as this can cause issues with the docker API. + """ + if gcc_bin.startswith("docker:"): + with _GCC_ENV_DOCKER_CONSTRUCTOR_LOCK: + return GccEnv(*args, gcc_bin=gcc_bin, **kwargs) + else: + return GccEnv(*args, gcc_bin=gcc_bin, **kwargs) diff --git a/compiler_gym/envs/llvm/datasets/cbench.py b/compiler_gym/envs/llvm/datasets/cbench.py index 9cd97c2df..fe979f1e5 100644 --- a/compiler_gym/envs/llvm/datasets/cbench.py +++ b/compiler_gym/envs/llvm/datasets/cbench.py @@ -234,6 +234,10 @@ def download_cBench_runtime_data() -> bool: if (cbench_data / "unpacked").is_file(): return False else: + logger.warning( + "Installing the cBench runtime inputs. This may take a few moments ..." + ) + # Clean up any partially-extracted data directory. if cbench_data.is_dir(): shutil.rmtree(cbench_data) diff --git a/compiler_gym/requirements.txt b/compiler_gym/requirements.txt index b2f06105f..45d51cfb4 100644 --- a/compiler_gym/requirements.txt +++ b/compiler_gym/requirements.txt @@ -2,6 +2,7 @@ absl-py>=0.10.0 deprecated>=1.2.12 docker>=4.0.0 fasteners>=0.15 +frozendict>=1.0.0 grpcio>=1.32.0,<1.44.0 gym>=0.18.0,<0.21 humanize>=2.6.0 diff --git a/compiler_gym/service/BUILD b/compiler_gym/service/BUILD index 194882571..1ce2c4cd7 100644 --- a/compiler_gym/service/BUILD +++ b/compiler_gym/service/BUILD @@ -9,14 +9,15 @@ py_library( name = "service", srcs = [ "__init__.py", + "connection_pool.py", ], visibility = ["//visibility:public"], deps = [ - ":compilation_session", - ":connection", # TODO(github.com/facebookresearch/CompilerGym/pull/633): - # add this after circular dependencies are resolved + # add this after circular dependencies are resolved: # ":client_service_compiler_env", + ":compilation_session", + ":connection", ":service_cache", "//compiler_gym/errors", "//compiler_gym/service/proto", diff --git a/compiler_gym/service/CMakeLists.txt b/compiler_gym/service/CMakeLists.txt index 3e129fe76..eb6cf5d26 100644 --- a/compiler_gym/service/CMakeLists.txt +++ b/compiler_gym/service/CMakeLists.txt @@ -10,12 +10,13 @@ cg_py_library( service SRCS "__init__.py" + "connection_pool.py" DEPS + # TODO(github.com/facebookresearch/CompilerGym/pull/633): + # add this after circular dependencies are resolved: + # ::client_service_compiler_env ::compilation_session ::connection - # TODO(github.com/facebookresearch/CompilerGym/pull/633): - # add this after circular dependencies are resolved - #::client_service_compiler_env ::service_cache compiler_gym::errors::errors compiler_gym::service::proto::proto diff --git a/compiler_gym/service/__init__.py b/compiler_gym/service/__init__.py index ddd10fc14..98074b0d1 100644 --- a/compiler_gym/service/__init__.py +++ b/compiler_gym/service/__init__.py @@ -14,12 +14,14 @@ ServiceTransportError, SessionNotFound, ) +from compiler_gym.service.connection_pool import ServiceConnectionPool __all__ = [ - "CompilerGymServiceConnection", "CompilationSession", + "CompilerGymServiceConnection", "ConnectionOpts", "EnvironmentNotSupported", + "ServiceConnectionPool", "ServiceError", "ServiceInitError", "ServiceIsClosed", diff --git a/compiler_gym/service/client_service_compiler_env.py b/compiler_gym/service/client_service_compiler_env.py index d46ea0527..2c96e7a82 100644 --- a/compiler_gym/service/client_service_compiler_env.py +++ b/compiler_gym/service/client_service_compiler_env.py @@ -31,6 +31,10 @@ ValidationError, ) from compiler_gym.service import CompilerGymServiceConnection, ConnectionOpts +from compiler_gym.service.connection_pool import ( + ServiceConnectionPool, + ServiceConnectionPoolBase, +) from compiler_gym.service.proto import ActionSpace, AddBenchmarkRequest from compiler_gym.service.proto import Benchmark as BenchmarkProto from compiler_gym.service.proto import ( @@ -138,6 +142,7 @@ def __init__( service_message_converters: ServiceMessageConverters = None, connection_settings: Optional[ConnectionOpts] = None, service_connection: Optional[CompilerGymServiceConnection] = None, + service_pool: Optional[ServiceConnectionPool] = None, logger: Optional[logging.Logger] = None, ): """Construct and initialize a CompilerGym environment. @@ -167,7 +172,7 @@ def __init__( `. If not provided, :func:`step()` returns :code:`None` for the observation value. Can be set later using :meth:`env.observation_space - `. For available + `. For available spaces, see :class:`env.observation.spaces `. @@ -176,7 +181,7 @@ def __init__( `. If not provided, :func:`step()` returns :code:`None` for the reward value. Can be set later using :meth:`env.reward_space - `. For available spaces, + `. For available spaces, see :class:`env.reward.spaces `. :param action_space: The name of the action space to use. If not @@ -186,7 +191,8 @@ def __init__( passed to :meth:`env.observation.add_derived_space() `. - :param service_message_converters: Custom converters for action spaces and actions. + :param service_message_converters: Custom converters for action spaces + and actions. :param connection_settings: The settings used to establish a connection with the remote service. @@ -194,6 +200,10 @@ def __init__( :param service_connection: An existing compiler gym service connection to use. + :param service_pool: A service pool to use for acquiring a service + connection. If not specified, the :meth:`global service pool + ` is used. + :raises FileNotFoundError: If service is a path to a file that is not found. @@ -204,9 +214,10 @@ def __init__( # in release 0.2.3. if logger: warnings.warn( - "The `logger` argument is deprecated on ClientServiceCompilerEnv.__init__() " - "and will be removed in a future release. All ClientServiceCompilerEnv " - "instances share a logger named compiler_gym.service.client_service_compiler_env", + "The `logger` argument is deprecated on " + "ClientServiceCompilerEnv.__init__() and will be removed in a " + "future release. All ClientServiceCompilerEnv instances share " + "a logger named compiler_gym.envs.compiler_env", DeprecationWarning, ) @@ -219,10 +230,18 @@ def __init__( self._service_endpoint: Union[str, Path] = service self._connection_settings = connection_settings or ConnectionOpts() - self.service = service_connection or CompilerGymServiceConnection( - endpoint=self._service_endpoint, - opts=self._connection_settings, - ) + if service_connection is None: + self._service_pool: Optional[ServiceConnectionPoolBase] = ( + ServiceConnectionPool.get() if service_pool is None else service_pool + ) + self.service = self._service_pool.acquire( + endpoint=self._service_endpoint, + opts=self._connection_settings, + ) + else: + self._service_pool: Optional[ServiceConnectionPoolBase] = service_pool + self.service = service_connection + self._datasets = Datasets(datasets or []) self.action_space_name = action_space @@ -266,9 +285,9 @@ def __init__( self._benchmark_in_use = self._next_benchmark except StopIteration: # StopIteration raised on next(self.datasets.benchmarks()) if there - # are no benchmarks available. This is to allow ClientServiceCompilerEnv to be - # used without any datasets by setting a benchmark before/during the - # first reset() call. + # are no benchmarks available. This is to allow + # ClientServiceCompilerEnv to be used without any datasets by + # setting a benchmark before/during the first reset() call. pass self.service_message_converters = ( @@ -302,13 +321,13 @@ def __init__( # Mutable state initialized in reset(). self._reward_range: Tuple[float, float] = (-np.inf, np.inf) - self.episode_reward = None + self.episode_reward: Optional[float] = None self.episode_start_time: float = time() self._actions: List[ActionType] = [] # Initialize the default observation/reward spaces. - self.observation_space_spec = None - self.reward_space_spec = None + self.observation_space_spec: Optional[ObservationSpaceSpec] = None + self.reward_space_spec: Optional[Reward] = None self.observation_space = observation_space self.reward_space = reward_space @@ -392,6 +411,8 @@ def commandline(self) -> str: """Calling this method on a :class:`ClientServiceCompilerEnv ` instance raises :code:`NotImplementedError`. + + :return: A string commandline invocation. """ raise NotImplementedError("abstract method") @@ -542,6 +563,7 @@ def _init_kwargs(self) -> Dict[str, Any]: "benchmark": self.benchmark, "connection_settings": self._connection_settings, "service": self._service_endpoint, + "service_pool": self._service_pool, } def fork(self) -> "ClientServiceCompilerEnv": @@ -601,7 +623,7 @@ def fork(self) -> "ClientServiceCompilerEnv": # Copy over the mutable episode state. new_env.episode_reward = self.episode_reward new_env.episode_start_time = self.episode_start_time - new_env._actions = self.actions.copy() + new_env._actions = self.actions.copy() # pylint: disable=protected-access return new_env @@ -687,7 +709,7 @@ def _retry(error) -> Optional[ObservationType]: ) log_severity("%s during reset(): %s", type(error).__name__, error) - if self.service: + if self.service is not None: try: self.service.close() except ServiceError as e: @@ -699,6 +721,7 @@ def _retry(error) -> Optional[ObservationType]: e, type(e).__name__, ) + self.service = None if retry_count >= self._connection_settings.init_max_attempts: @@ -734,8 +757,15 @@ def _call_with_error( # Start a new service if required. if self.service is None: - self.service = CompilerGymServiceConnection( - self._service_endpoint, self._connection_settings + self.service = ( + CompilerGymServiceConnection( + self._service_endpoint, self._connection_settings + ) + if self._service_pool is None + else self._service_pool.acquire( + endpoint=self._service_endpoint, + opts=self._connection_settings, + ) ) self.action_space_name = action_space or self.action_space_name @@ -810,7 +840,7 @@ def _call_with_error( self.observation.session_id = reply.session_id self.reward.get_cost = self.observation.__getitem__ self.episode_start_time = time() - self._actions = [] + self._actions: List[ActionType] = [] # If the action space has changed, update it. if reply.HasField("new_action_space"): @@ -852,15 +882,17 @@ def raw_step( and rewards are lists. :raises SessionNotFound: If :meth:`reset() - ` has not been called. + ` has not been + called. .. warning:: Don't call this method directly, use :meth:`step() - ` or :meth:`multistep() + ` or + :meth:`multistep() ` instead. The - :meth:`raw_step() ` method is an - implementation detail. + :meth:`raw_step() ` + method is an implementation detail. """ if not self.in_episode: raise SessionNotFound("Must call reset() before step()") @@ -1258,8 +1290,9 @@ def validate(self, state: Optional[CompilerEnvState] = None) -> ValidationResult def send_param(self, key: str, value: str) -> str: """Send a single parameter to the compiler service. - See :meth:`send_params() ` - for more information. + See :meth:`send_params() + ` for more + information. :param key: The parameter key. diff --git a/compiler_gym/service/connection.py b/compiler_gym/service/connection.py index e8fd9778d..38b760b57 100644 --- a/compiler_gym/service/connection.py +++ b/compiler_gym/service/connection.py @@ -10,11 +10,12 @@ from pathlib import Path from signal import Signals from time import sleep, time -from typing import Dict, Iterable, List, Optional, TypeVar, Union +from typing import Dict, FrozenSet, Iterable, List, Optional, TypeVar, Union import grpc from deprecated.sphinx import deprecated -from pydantic import BaseModel +from frozendict import frozendict +from pydantic import BaseModel, root_validator import compiler_gym.errors from compiler_gym.service.proto import ( @@ -40,8 +41,8 @@ # Spurious error UNAVAILABLE "Trying to connect an http1.x server". # https://putridparrot.com/blog/the-unavailable-trying-to-connect-an-http1-x-server-grpc-error/ ("grpc.enable_http_proxy", 0), - # Disable TCP port re-use to mitigate port conflict errors when starting - # many services in parallel. Context: + # Disable TCP port reuse to mitigate port conflict errors when starting many + # services in parallel. Context: # https://github.com/facebookresearch/CompilerGym/issues/572 ("grpc.so_reuseport", 0), ] @@ -49,7 +50,14 @@ logger = logging.getLogger(__name__) -class ConnectionOpts(BaseModel): +class HashableBaseModel(BaseModel): + """A pydantic model that is hashable. Requires that all fields are hashable.""" + + def __hash__(self): + return hash((type(self),) + tuple(self.__dict__.values())) + + +class ConnectionOpts(HashableBaseModel): """The options used to configure a connection to a service.""" rpc_call_max_seconds: float = 300 @@ -91,19 +99,25 @@ class ConnectionOpts(BaseModel): always_send_benchmark_on_reset: bool = False """Send the full benchmark program data to the compiler service on ever call to :meth:`env.reset() `. This is more - efficient in cases where the majority of calls to - :meth:`env.reset() ` uses a different - benchmark. In case of benchmark re-use, leave this :code:`False`. + efficient in cases where the majority of calls to :meth:`env.reset() + ` uses a different benchmark. In case + of benchmark reuse, leave this :code:`False`. """ - script_args: List[str] = [] + script_args: FrozenSet[str] = frozenset([]) """If the service is started from a local script, this set of args is used on the command line. No effect when used for existing sockets.""" - script_env: Dict[str, str] = {} + script_env: Dict[str, str] = frozendict({}) """If the service is started from a local script, this set of env vars is used on the command line. No effect when used for existing sockets.""" + @root_validator + def freeze_types(cls, values): + values["script_args"] = frozenset(values["script_args"]) + values["script_env"] = frozendict(values["script_env"]) + return values + # Deprecated since v0.2.4. # This type is for backwards compatibility that will be removed in a future release. @@ -301,7 +315,7 @@ def __init__( port_init_max_seconds: float, rpc_init_max_seconds: float, process_exit_max_seconds: float, - script_args: List[str], + script_args: FrozenSet[str], script_env: Dict[str, str], ): """Constructor. @@ -323,7 +337,7 @@ def __init__( f"--working_dir={self.cache.path}", ] # Add any custom arguments - cmd += script_args + cmd += list(script_args) # Set the root of the runfiles directory. env = os.environ.copy() @@ -532,10 +546,8 @@ def close(self): def __repr__(self): if self.process.poll() is None: - return ( - f"Connection to service at {self.url} running on PID {self.process.pid}" - ) - return f"Connection to dead service at {self.url}" + return f"ManagedConnection({self.url}, pid={self.process.pid})" + return f"ManagedConnection({self.url}, not running)" class UnmanagedConnection(Connection): @@ -575,25 +587,26 @@ def __init__(self, url: str, rpc_init_max_seconds: float): super().__init__(channel, url) def __repr__(self): - return f"Connection to unmanaged service {self.url}" + return f"UnmanagedConnection({self.url})" class CompilerGymServiceConnection: """A connection to a compiler gym service. There are two types of service connections: managed and unmanaged. The type - of connection is determined by the endpoint. If a "host:port" URL is provided, - an unmanaged connection is created. If the path of a file is provided, a - managed connection is used. The difference between a managed and unmanaged - connection is that with a managed connection, the lifecycle of the service - if controlled by the client connection. That is, when a managed connection - is created, a service subprocess is started by executing the specified path. - When the connection is closed, the subprocess is terminated. With an - unmanaged connection, if the service fails is goes offline, the client will - fail. - - This class provides a common abstraction between the two types of connection, - and provides a call method for invoking remote procedures on the service. + of connection is determined by the endpoint. If a "host:port" URL is + provided, an unmanaged connection is created. If the path of a file is + provided, a managed connection is used. The difference between a managed and + unmanaged connection is that with a managed connection, the lifecycle of the + service if controlled by the client connection. That is, when a managed + connection is created, a service subprocess is started by executing the + specified path. When the connection is closed, the subprocess is terminated. + With an unmanaged connection, if the service fails is goes offline, the + client will fail. + + This class provides a common abstraction between the two types of + connection, and provides a call method for invoking remote procedures on the + service. Example usage of an unmanaged service connection: @@ -622,7 +635,9 @@ class CompilerGymServiceConnection: :ivar stub: A CompilerGymServiceStub that can be used as the first argument to :py:meth:`__call__()` to specify an RPC method to call. + :ivar action_spaces: A list of action spaces provided by the service. + :ivar observation_spaces: A list of observation spaces provided by the service. """ @@ -630,20 +645,40 @@ class CompilerGymServiceConnection: def __init__( self, endpoint: Union[str, Path], - opts: ConnectionOpts = None, + opts: ConnectionOpts, + owning_service_pool: Optional["ServiceConnectionPool"] = None, # noqa: F821 ): """Constructor. + .. note:: + + Starting new services is expensive. Consider using the + :class:`ServiceConnectionPool + ` class to manage + services rather than constructing them yourself. + :param endpoint: The connection endpoint. Either the URL of a service, e.g. "localhost:8080", or the path of a local service binary. + :param opts: The connection options. + + :param owning_service_pool: A backref to the owning + :class:`ServiceConnectionPool + `, if this service is + managed by one. + :raises ValueError: If the provided options are invalid. - :raises FileNotFoundError: In case opts.local_service_binary is not found. + + :raises FileNotFoundError: In case opts.local_service_binary is not + found. + :raises TimeoutError: In case the service failed to start within opts.init_max_seconds seconds. """ + self.released = False self.endpoint = endpoint self.opts = opts or ConnectionOpts() + self.owning_service_pool = owning_service_pool self.connection = None self.stub = None self._establish_connection() @@ -726,21 +761,73 @@ def _create_connection( ) def __repr__(self): - if self.connection is None: - return f"Closed connection to {self.endpoint}" - return str(self.endpoint) + return f"CompilerGymServiceConnection({self.connection or 'detached'})" @property def closed(self) -> bool: """Whether the connection is closed.""" - return self.connection is None + # Defensive hasattr() because this property is accessed by destructor. + if hasattr(self, "connection"): + return self.connection is None + return True + + def acquire(self) -> "CompilerGymServiceConnection": + """Mark this connection as in-use.""" + if not self.released: + raise TypeError( + "Attempting to acquire a connection that is already acquired." + ) + self.released = False + return self - def close(self): + def release(self) -> None: + """Mark this connection as not in-use.""" + if not self.released: + self.owning_service_pool.release(self) + self.released = True + + def shutdown(self): + """Shut down the connection. + + Once a connection has been shutdown, it cannot be reused. + """ if self.closed: return - self.connection.close() + + try: + self.connection.close() + except ServiceError as e: + # close() can raise ServiceError if the service exists with a + # non-zero return code. We swallow the error here as we are + # disposing of the service. + logger.debug( + "Ignoring service error during shutdown attempt: %s (%s)", + e, + type(e).__name__, + ) self.connection = None + def close(self): + """Mark this connection as closed. + + If the service is managed by a :class:`ServiceConnectionPool + `, this will indicate to the + pool that the connection is ready to be reused. If the service is not + managed by a pool, this will shut it down. + """ + if self.owned_by_service_pool: + self.release() + else: + self.shutdown() + + @property + def owned_by_service_pool(self): + # Defensive hasattr() test because this property is accessed by the + # destructor, where the object could be in a partially initialized + # state. + if hasattr(self, "owning_service_pool"): + return self.owning_service_pool is not None + def __del__(self): # Don't let the subprocess be orphaned if user forgot to close(), or # if an exception was thrown. @@ -825,3 +912,12 @@ def __call__( retry_wait_backoff_exponent or self.opts.retry_wait_backoff_exponent ), ) + + def __enter__(self) -> "CompilerGymServiceConnection": + """Support for 'with' statements.""" + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """Support for 'with' statements.""" + self.close() + return False diff --git a/compiler_gym/service/connection_pool.py b/compiler_gym/service/connection_pool.py new file mode 100644 index 000000000..96ac5642b --- /dev/null +++ b/compiler_gym/service/connection_pool.py @@ -0,0 +1,219 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +"""This module contains a reusable pool of service connections.""" +import atexit +import logging +from collections import defaultdict +from pathlib import Path +from threading import Lock +from typing import Dict, List, Set, Tuple + +from compiler_gym.service.connection import CompilerGymServiceConnection, ConnectionOpts + +logger = logging.getLogger(__name__) + +# We identify connections by the binary path and set of connection opts. +ServiceConnectionCacheKey = Tuple[Path, ConnectionOpts] + + +class ServiceConnectionPoolBase: + """A class that provides the base interface for service connection pools.""" + + def acquire( + self, endpoint: Path, opts: ConnectionOpts + ) -> CompilerGymServiceConnection: + return CompilerGymServiceConnection( + endpoint=endpoint, opts=opts, owning_service_pool=self + ) + + def release(self, service: CompilerGymServiceConnection) -> None: + pass + + +class ServiceConnectionPool(ServiceConnectionPoolBase): + """An object pool for compiler service connections. + + This class implements a thread-safe pool for compiler service connections. + This enables compiler service connections to be reused, avoiding the + expensive initialization of a new service. + + There is a global instance of this class, available via the static + :meth:`ServiceConnectionPool.get() + ` method. To use the pool, + acquire a reference to the global instance, and call the + :meth:`ServiceConnectionPool.acquire() + ` method to construct and + return service connections: + + >>> pool = ServiceConnectionPool.get() + >>> with pool.acquire(Path("/path/to/service"), ConnectionOpts()) as service: + ... # Do something with the service. + + When a service is closed (by calling :meth:`service.close() + `), it is + automatically released back to the pool so that a future request for the + same type of service will reuse the connection. + + :ivar pool: A pool of service connections that are ready for use. + + :vartype pool: Dict[Tuple[Path, ConnectionOpts], + List[CompilerGymServiceConnection]] + + :ivar allocated: The set of service connections that are currently in use. + + :vartype allocated: Set[CompilerGymServiceConnection] + """ + + def __init__(self): + """""" + self._lock = Lock() + self.pool: Dict[ + ServiceConnectionCacheKey, List[CompilerGymServiceConnection] + ] = defaultdict(list) + self.allocated: Set[CompilerGymServiceConnection] = set() + + # Add a flag to indicate a closed connection pool because of + # out-of-order execution of destructors and the atexit callback. + self.closed = False + + atexit.register(self.close) + + def acquire( + self, endpoint: Path, opts: ConnectionOpts + ) -> CompilerGymServiceConnection: + """Acquire a service connection from the pool. + + If an existing connection is available in the pool, it is returned. + Otherwise, a new connection is created. + """ + key: ServiceConnectionCacheKey = (endpoint, opts) + with self._lock: + if self.closed: + # This should never happen. + raise TypeError("ServiceConnectionPool is closed") + + if self.pool[key]: + service = self.pool[key].pop().acquire() + logger.debug( + "Reusing %s, %d environments remaining in pool", + service.connection.url, + len(self.pool[key]), + ) + else: + # No free service connections, construct a new one. + service = CompilerGymServiceConnection( + endpoint=endpoint, opts=opts, owning_service_pool=self + ) + logger.debug("Created %s", service.connection.url) + + self.allocated.add(service) + + return service + + def release(self, service: CompilerGymServiceConnection) -> None: + """Release a service connection back to the pool. + + .. note:: + + This method is called automatically by the :meth:`service.close() + ` method of + acquired service connections. You do not have to call this method + yourself. + """ + key: ServiceConnectionCacheKey = (service.endpoint, service.opts) + with self._lock: + # During shutdown, the shutdown routine for this + # ServiceConnectionPool may be called before the destructor of + # the managed CompilerGymServiceConnection objects. + if self.closed: + return + + if service not in self.allocated: + logger.debug("Discarding service that does not belong to pool") + return + + self.allocated.remove(service) + + # Only managed processes have a process attribute. + if hasattr(service.connection, "process"): + # A dead service cannot be reused, discard it. + if service.closed or service.connection.process.poll() is not None: + logger.debug("Discarding service with dead process") + return + # A service that has been shutdown cannot be reused, discard it. + if not service.connection: + logger.debug("Discarding service that has no connection") + return + + self.pool[key].append(service) + + logger.debug("Released %s, pool size %d", service.connection.url, self.size) + + def __contains__(self, service: CompilerGymServiceConnection): + """Check if a service connection is managed by the pool.""" + key: ServiceConnectionCacheKey = (service.endpoint, service.opts) + return service in self.allocated or service in self.pool[key] + + @property + def size(self): + """Return the total number of connections in the pool.""" + return sum(len(x) for x in self.pool.values()) + len(self.allocated) + + def __len__(self): + return self.size + + def close(self) -> None: + """Close the pool, terminating all connections. + + Once closed, the pool cannot be used again. It is safe to call this + method more than once. + """ + with self._lock: + if self.closed: + return + + try: + logger.debug( + "Closing the service connection pool with %d cached and %d live connections", + self.size, + len(self.allocated), + ) + except ValueError: + # As this method is invoked by the atexit callback, the logger + # may already have closed its streams, in which case a + # ValueError is raised. + pass + + for connections in self.pool.values(): + for connection in connections: + connection.shutdown() + self.pool = defaultdict(list) + for connection in self.allocated: + connection.shutdown() + self.allocated = set() + self.closed = True + + def __del__(self) -> None: + self.close() + + def __enter__(self) -> "ServiceConnectionPool": + """Support for "with" statement.""" + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + """Support for "with" statement.""" + self.close() + return False + + @staticmethod + def get() -> "ServiceConnectionPool": + """Return the global instance of the service connection pool.""" + return _SERVICE_CONNECTION_POOL + + def __repr__(self) -> str: + return f"{type(self).__name__}(size={self.size})" + + +_SERVICE_CONNECTION_POOL = ServiceConnectionPool() diff --git a/compiler_gym/util/flags/README.md b/compiler_gym/util/flags/README.md index 38180ffd9..5751b87dd 100644 --- a/compiler_gym/util/flags/README.md +++ b/compiler_gym/util/flags/README.md @@ -2,8 +2,8 @@ This directory contains modules that define command line flags for use by `compiler_gym.bin` and other scripts. The reason for defining flags here is to -allow flag names to be re-used across scripts without causing -multiple-definition errors when the scripts are imported. +allow flag names to be reused across scripts without causing multiple-definition +errors when the scripts are imported. Using these flags requires that the absl flags library is initialized. As such they should not be used in the core library. diff --git a/compiler_gym/wrappers/commandline.py b/compiler_gym/wrappers/commandline.py index 976b339e5..79dbcab25 100644 --- a/compiler_gym/wrappers/commandline.py +++ b/compiler_gym/wrappers/commandline.py @@ -143,7 +143,7 @@ def __init__( flag=env.action_space.flags[a], description=env.action_space.descriptions[a], ) - for a in (env.action_space.flags.index(f) for f in flags) + for a in (env.action_space[f] for f in flags) ], name=f"{type(self).__name__}<{name or env.action_space.name}, {len(flags)}>", ) diff --git a/docs/source/compiler_gym/service.rst b/docs/source/compiler_gym/service.rst index 2f474dc1c..5e0033b4d 100644 --- a/docs/source/compiler_gym/service.rst +++ b/docs/source/compiler_gym/service.rst @@ -38,7 +38,7 @@ ClientServiceCompilerEnv .. automethod:: __init__ -The connection object +The Connection Object --------------------- .. autoclass:: CompilerGymServiceConnection @@ -47,7 +47,8 @@ The connection object .. automethod:: __init__ .. automethod:: __call__ -Configuring the connection + +Configuring the Connection -------------------------- The :class:`ConnectionOpts ` object is used @@ -57,6 +58,13 @@ to configure the options used for managing a service connection. :members: +Re-using Connections +-------------------- + +.. autoclass:: ServiceConnectionPool + :members: + + Exceptions ---------- diff --git a/docs/source/envs/gcc.rst b/docs/source/envs/gcc.rst index 93efd42cb..45dd8bf97 100644 --- a/docs/source/envs/gcc.rst +++ b/docs/source/envs/gcc.rst @@ -42,9 +42,11 @@ On Linux, install Docker using: "deb [arch=amd64 signed-by=/usr/share/keyrings/docker-archive-keyring.gpg] https://download.docker.com/linux/ubuntu \ $(lsb_release -cs) stable" | sudo tee /etc/apt/sources.list.d/docker.list > /dev/null sudo apt-get update && sudo apt-get install docker-ce docker-ce-cli containerd.io + sudo usermod -aG docker $USER + su - $USER See the `official documentation `_ for -alternative installation options. +more details and alternative installation options. On both Linux and macOS, use the following command to check if Docker is working: diff --git a/examples/gcc_autotuning/info.py b/examples/gcc_autotuning/info.py index 9b56b94a4..72ebc38c5 100644 --- a/examples/gcc_autotuning/info.py +++ b/examples/gcc_autotuning/info.py @@ -52,6 +52,7 @@ def info( if not dfs: print("No results") + return df = pd.concat(dfs) df = df.groupby(["timestamp", "search"])[["scaled_size"]].agg(geometric_mean) diff --git a/examples/gcc_autotuning/tune.py b/examples/gcc_autotuning/tune.py index 5a702cbdc..efa90f05c 100644 --- a/examples/gcc_autotuning/tune.py +++ b/examples/gcc_autotuning/tune.py @@ -5,7 +5,6 @@ """Autotuning script for GCC command line options.""" import random from itertools import islice, product -from multiprocessing import Lock from pathlib import Path from typing import NamedTuple @@ -64,10 +63,6 @@ "objective", "obj_size", ["asm_size", "obj_size"], "Which objective to use" ) -# Lock to prevent multiple processes all calling compiler_gym.make("gcc-v0") -# simultaneously as this can cause issues with the docker API. -GCC_ENV_CONSTRUCTOR_LOCK = Lock() - def random_search(env: CompilerEnv): best = float("inf") @@ -160,10 +155,7 @@ def scaled_best(self) -> float: def run_search(search: str, benchmark: str, seed: int) -> SearchResult: """Run a search and return the search class instance.""" - with GCC_ENV_CONSTRUCTOR_LOCK: - env = compiler_gym.make("gcc-v0", gcc_bin=FLAGS.gcc_bin) - - try: + with compiler_gym.make("gcc-v0", gcc_bin=FLAGS.gcc_bin) as env: random.seed(seed) np.random.seed(seed) @@ -172,8 +164,6 @@ def run_search(search: str, benchmark: str, seed: int) -> SearchResult: baseline_size = objective(env) env.reset(benchmark=benchmark) best_size = _SEARCH_FUNCTIONS[search](env) - finally: - env.close() return SearchResult( search=search, diff --git a/examples/gcc_autotuning/tune_test.py b/examples/gcc_autotuning/tune_test.py index d6a00a480..cadab187b 100644 --- a/examples/gcc_autotuning/tune_test.py +++ b/examples/gcc_autotuning/tune_test.py @@ -2,6 +2,7 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import shutil import subprocess import sys from functools import lru_cache @@ -24,33 +25,34 @@ def docker_is_available() -> bool: return False -@lru_cache(maxsize=2) -def system_gcc_is_available() -> bool: +def system_has_functional_gcc(gcc_path: str) -> bool: """Return whether there is a system GCC available.""" try: stdout = subprocess.check_output( - ["gcc", "--version"], universal_newlines=True, stderr=subprocess.DEVNULL + [gcc_path, "--version"], + universal_newlines=True, + stderr=subprocess.DEVNULL, + timeout=30, ) # On some systems "gcc" may alias to a different compiler, so check for # the presence of the name "gcc" in the first line of output. return "gcc" in stdout.split("\n")[0].lower() - except (subprocess.CalledProcessError, FileNotFoundError): + except ( + subprocess.CalledProcessError, + FileNotFoundError, + subprocess.TimeoutExpired, + ): return False -def system_gcc_path() -> str: - """Return the path of the system GCC as a string.""" - return subprocess.check_output( - ["which", "gcc"], universal_newlines=True, stderr=subprocess.DEVNULL - ).strip() - - +@lru_cache(maxsize=1) def gcc_bins() -> Iterable[str]: """Return a list of available GCCs.""" if docker_is_available(): yield "docker:gcc:11.2.0" - if system_gcc_is_available(): - yield system_gcc_path() + system_gcc = shutil.which("gcc") + if system_gcc and system_has_functional_gcc(system_gcc): + yield system_gcc @pytest.fixture(scope="module", params=gcc_bins()) @@ -58,6 +60,7 @@ def gcc_bin(request) -> str: return request.param +@pytest.mark.timeout(600) @pytest.mark.parametrize("search", ["random", "hillclimb", "genetic"]) def test_tune_smoke_test(search: str, gcc_bin: str, capsys, tmpdir: Path): tmpdir = Path(tmpdir) diff --git a/examples/tabular_q.py b/examples/tabular_q.py index da7419c25..94124af93 100644 --- a/examples/tabular_q.py +++ b/examples/tabular_q.py @@ -120,7 +120,7 @@ def rollout(qtable, env, printout=False): for i in range(FLAGS.episode_length): a = select_action(qtable, observation, i) action_seq.append(a) - observation, reward, done, info = env.step(env.action_space.flags.index(a)) + observation, reward, done, info = env.step(env.action_space[a]) rewards.append(reward) if done: break @@ -146,10 +146,10 @@ def train(q_table, env): hashed = make_q_table_key(observation, a, current_length) if hashed not in q_table: q_table[hashed] = 0 - # Take a stap in the environment, record the reward and state transition. + # Take a step in the environment, record the reward and state transition. # Effectively we are evaluating the policy by taking a step in the # environment. - observation, reward, done, info = env.step(env.action_space.flags.index(a)) + observation, reward, done, info = env.step(env.action_space[a]) if done: break current_length += 1 diff --git a/tests/compiler_env_test.py b/tests/compiler_env_test.py index 0049dbaa9..7d73c157c 100644 --- a/tests/compiler_env_test.py +++ b/tests/compiler_env_test.py @@ -9,7 +9,7 @@ from compiler_gym.envs import llvm from compiler_gym.envs.llvm import LlvmEnv -from compiler_gym.service.connection import CompilerGymServiceConnection +from compiler_gym.service.connection import CompilerGymServiceConnection, ConnectionOpts from tests.test_main import main pytest_plugins = ["tests.pytest_plugins.llvm"] @@ -174,7 +174,7 @@ def test_step_session_id_not_found(env: LlvmEnv): @pytest.fixture(scope="function") def remote_env() -> LlvmEnv: """A test fixture that yields a connection to a remote service.""" - service = CompilerGymServiceConnection(llvm.LLVM_SERVICE_BINARY) + service = CompilerGymServiceConnection(llvm.LLVM_SERVICE_BINARY, ConnectionOpts()) try: with LlvmEnv(service=service.connection.url) as env: yield env diff --git a/tests/llvm/BUILD b/tests/llvm/BUILD index d3639abde..be4e75f97 100644 --- a/tests/llvm/BUILD +++ b/tests/llvm/BUILD @@ -110,7 +110,7 @@ py_test( "//compiler_gym/third_party/cbench:crc32", ], deps = [ - "//compiler_gym/envs", + "//compiler_gym", "//tests:test_main", "//tests/pytest_plugins:llvm", ], diff --git a/tests/llvm/CMakeLists.txt b/tests/llvm/CMakeLists.txt index f1c5403ff..e47939754 100644 --- a/tests/llvm/CMakeLists.txt +++ b/tests/llvm/CMakeLists.txt @@ -108,7 +108,7 @@ cg_py_test( DATA compiler_gym::third_party::cbench::crc32 DEPS - compiler_gym::envs::envs + compiler_gym::compiler_gym tests::pytest_plugins::llvm tests::test_main ) diff --git a/tests/llvm/action_space_test.py b/tests/llvm/action_space_test.py index 669958052..5a377c0b0 100644 --- a/tests/llvm/action_space_test.py +++ b/tests/llvm/action_space_test.py @@ -17,12 +17,12 @@ def test_commandline_no_actions(env: LlvmEnv): def test_commandline(env: LlvmEnv): env.reset(benchmark="cbench-v1/crc32") - env.step(env.action_space.flags.index("-mem2reg")) - env.step(env.action_space.flags.index("-reg2mem")) + env.step(env.action_space["-mem2reg"]) + env.step(env.action_space["-reg2mem"]) assert env.commandline() == "opt -mem2reg -reg2mem input.bc -o output.bc" assert env.commandline_to_actions(env.commandline()) == [ - env.action_space.flags.index("-mem2reg"), - env.action_space.flags.index("-reg2mem"), + env.action_space["-mem2reg"], + env.action_space["-reg2mem"], ] diff --git a/tests/llvm/custom_benchmarks_test.py b/tests/llvm/custom_benchmarks_test.py index e7372710e..0339c9e58 100644 --- a/tests/llvm/custom_benchmarks_test.py +++ b/tests/llvm/custom_benchmarks_test.py @@ -39,11 +39,12 @@ def test_reset_invalid_benchmark(env: LlvmEnv): def test_invalid_benchmark_data(env: LlvmEnv): benchmark = Benchmark.from_file_contents( - "benchmark://new", "Invalid bitcode".encode("utf-8") + "benchmark://test_invalid_benchmark_data", "Invalid bitcode".encode("utf-8") ) with pytest.raises( - ValueError, match='Failed to parse LLVM bitcode: "benchmark://new"' + ValueError, + match='Failed to parse LLVM bitcode: "benchmark://test_invalid_benchmark_data"', ): env.reset(benchmark=benchmark) @@ -51,11 +52,11 @@ def test_invalid_benchmark_data(env: LlvmEnv): def test_invalid_benchmark_missing_file(env: LlvmEnv): benchmark = Benchmark( BenchmarkProto( - uri="benchmark://new", + uri="benchmark://test_invalid_benchmark_missing_file", ) ) - with pytest.raises(ValueError, match="No program set"): + with pytest.raises(ValueError, match="No program set in Benchmark:"): env.reset(benchmark=benchmark) @@ -64,7 +65,9 @@ def test_benchmark_path_empty_file(env: LlvmEnv): tmpdir = Path(tmpdir) (tmpdir / "test.bc").touch() - benchmark = Benchmark.from_file("benchmark://new", tmpdir / "test.bc") + benchmark = Benchmark.from_file( + "benchmark://test_benchmark_path_empty_file", tmpdir / "test.bc" + ) with pytest.raises(ValueError, match="Failed to parse LLVM bitcode"): env.reset(benchmark=benchmark) @@ -76,7 +79,9 @@ def test_invalid_benchmark_path_contents(env: LlvmEnv): with open(str(tmpdir / "test.bc"), "w") as f: f.write("Invalid bitcode") - benchmark = Benchmark.from_file("benchmark://new", tmpdir / "test.bc") + benchmark = Benchmark.from_file( + "benchmark://test_invalid_benchmark_path_contents", tmpdir / "test.bc" + ) with pytest.raises(ValueError, match="Failed to parse LLVM bitcode"): env.reset(benchmark=benchmark) @@ -85,7 +90,8 @@ def test_invalid_benchmark_path_contents(env: LlvmEnv): def test_benchmark_path_invalid_scheme(env: LlvmEnv): benchmark = Benchmark( BenchmarkProto( - uri="benchmark://new", program=File(uri="invalid_scheme://test") + uri="benchmark://test_benchmark_path_invalid_scheme", + program=File(uri="invalid_scheme://test"), ), ) @@ -100,16 +106,20 @@ def test_benchmark_path_invalid_scheme(env: LlvmEnv): def test_custom_benchmark(env: LlvmEnv): - benchmark = Benchmark.from_file("benchmark://new", EXAMPLE_BITCODE_FILE) + benchmark = Benchmark.from_file( + "benchmark://test_custom_benchmark", EXAMPLE_BITCODE_FILE + ) env.reset(benchmark=benchmark) - assert env.benchmark == "benchmark://new" + assert env.benchmark == "benchmark://test_custom_benchmark" def test_custom_benchmark_constructor(): - benchmark = Benchmark.from_file("benchmark://new", EXAMPLE_BITCODE_FILE) + benchmark = Benchmark.from_file( + "benchmark://test_custom_benchmark_constructor", EXAMPLE_BITCODE_FILE + ) with gym.make("llvm-v0", benchmark=benchmark) as env: env.reset() - assert env.benchmark == "benchmark://new" + assert env.benchmark == "benchmark://test_custom_benchmark_constructor" def test_make_benchmark_single_bitcode(env: LlvmEnv): diff --git a/tests/llvm/datasets/cbench_validate_test.py b/tests/llvm/datasets/cbench_validate_test.py index 377f720d3..6b5b974d1 100644 --- a/tests/llvm/datasets/cbench_validate_test.py +++ b/tests/llvm/datasets/cbench_validate_test.py @@ -12,14 +12,14 @@ pytest_plugins = ["tests.pytest_plugins.llvm"] -@pytest.mark.timeout(600) +@pytest.mark.timeout(900) # Validation can take a long time! def test_validate_benchmark_semantics(env: LlvmEnv, validatable_cbench_uri: str): """Run the validation routine on all benchmarks.""" env.reward_space = "IrInstructionCount" env.reset(benchmark=validatable_cbench_uri) # Run a single step. - env.step(env.action_space.flags.index("-mem2reg")) + env.step(env.action_space["-mem2reg"]) # Validate the environment state. result: ValidationResult = env.validate() @@ -41,7 +41,7 @@ def test_non_validatable_benchmark_validate( env.reset(benchmark=non_validatable_cbench_uri) # Run a single step. - env.step(env.action_space.flags.index("-mem2reg")) + env.step(env.action_space["-mem2reg"]) # Validate the environment state. result: ValidationResult = env.validate() diff --git a/tests/llvm/fork_env_test.py b/tests/llvm/fork_env_test.py index 55779a945..8d30f3acd 100644 --- a/tests/llvm/fork_env_test.py +++ b/tests/llvm/fork_env_test.py @@ -8,7 +8,13 @@ import pytest -from compiler_gym.envs import LlvmEnv +import compiler_gym +from compiler_gym.envs.llvm import LLVM_SERVICE_BINARY, LlvmEnv +from compiler_gym.service import ( + CompilerGymServiceConnection, + ConnectionOpts, + ServiceError, +) from compiler_gym.util.runfiles_path import runfiles_path from tests.test_main import main @@ -31,70 +37,78 @@ def test_with_statement(env: LlvmEnv): assert env.in_episode -def test_fork_child_process_is_not_orphaned(env: LlvmEnv): - env.reset("cbench-v1/crc32") - with env.fork() as fkd: - # Check that both environments share the same service. - assert isinstance(env.service.connection.process, subprocess.Popen) - assert isinstance(fkd.service.connection.process, subprocess.Popen) +def test_fork_child_process_is_not_orphaned(): + service = CompilerGymServiceConnection(LLVM_SERVICE_BINARY, ConnectionOpts()) - assert env.service.connection.process.pid == fkd.service.connection.process.pid - process = env.service.connection.process + with compiler_gym.make("llvm-v0", service_connection=service) as env: + env.reset("cbench-v1/crc32") + with env.fork() as fkd: + # Check that both environments share the same service. + assert isinstance(env.service.connection.process, subprocess.Popen) + assert isinstance(fkd.service.connection.process, subprocess.Popen) - # Sanity check that both services are alive. - assert not env.service.connection.process.poll() - assert not fkd.service.connection.process.poll() + assert ( + env.service.connection.process.pid == fkd.service.connection.process.pid + ) + process = env.service.connection.process - # Close the parent service. - env.close() + # Sanity check that both services are alive. + assert not env.service.connection.process.poll() + assert not fkd.service.connection.process.poll() - # Check that the service is still alive. - assert not env.service - assert not fkd.service.connection.process.poll() + # Close the parent service. + env.close() - # Close the forked service. - fkd.close() + # Check that the service is still alive. + assert not env.service + assert not fkd.service.connection.process.poll() - # Check that the service has been killed. - assert process.poll() is not None + # Close the forked service. + fkd.close() + # Check that the service has been killed. + assert process.poll() is not None -def test_fork_chain_child_processes_are_not_orphaned(env: LlvmEnv): - env.reset("cbench-v1/crc32") - - # Create a chain of forked environments. - a = env.fork() - b = a.fork() - c = b.fork() - d = c.fork() - try: - # Sanity check that they share the same underlying service. - assert ( - env.service.connection.process - == a.service.connection.process - == b.service.connection.process - == c.service.connection.process - == d.service.connection.process - ) - proc = env.service.connection.process - # Kill the forked environments one by one. - a.close() - assert proc.poll() is None - b.close() - assert proc.poll() is None - c.close() - assert proc.poll() is None - d.close() - assert proc.poll() is None - # Kill the final environment, refcount 0, service is closed. - env.close() - assert proc.poll() is not None - finally: - a.close() - b.close() - c.close() - d.close() +def test_fork_chain_child_processes_are_not_orphaned(env: LlvmEnv): + service = CompilerGymServiceConnection(LLVM_SERVICE_BINARY, ConnectionOpts()) + + with compiler_gym.make("llvm-v0", service_connection=service) as env: + env.reset() + + # Create a chain of forked environments. + a = env.fork() + b = a.fork() + c = b.fork() + d = c.fork() + + try: + # Sanity check that they share the same underlying service. + assert ( + env.service.connection.process + == a.service.connection.process + == b.service.connection.process + == c.service.connection.process + == d.service.connection.process + ) + proc = env.service.connection.process + # Kill the forked environments one by one. + a.close() + assert proc.poll() is None + b.close() + assert proc.poll() is None + c.close() + assert proc.poll() is None + d.close() + assert proc.poll() is None + # Kill the final environment, refcount 0, service is closed. + env.close() + assert proc.poll() is not None + finally: + a.close() + b.close() + c.close() + d.close() def test_fork_before_reset(env: LlvmEnv): @@ -187,7 +201,7 @@ def test_fork_modified_ir_is_the_same(env: LlvmEnv): env.reset("cbench-v1/crc32") # Apply an action that modifies the benchmark. - _, _, done, info = env.step(env.action_space.flags.index("-mem2reg")) + _, _, done, info = env.step(env.action_space["-mem2reg"]) assert not done assert not info["action_had_no_effect"] @@ -195,8 +209,8 @@ def test_fork_modified_ir_is_the_same(env: LlvmEnv): assert "\n".join(env.ir.split("\n")[1:]) == "\n".join(fkd.ir.split("\n")[1:]) # Apply another action. - _, _, done, info = env.step(env.action_space.flags.index("-gvn")) - _, _, done, info = fkd.step(fkd.action_space.flags.index("-gvn")) + _, _, done, info = env.step(env.action_space["-gvn"]) + _, _, done, info = fkd.step(fkd.action_space["-gvn"]) assert not done assert not info["action_had_no_effect"] @@ -213,7 +227,10 @@ def test_fork_rewards(env: LlvmEnv, reward_space: str): env.reward_space = reward_space env.reset("cbench-v1/dijkstra") - actions = [env.action_space.flags.index(n) for n in ["-mem2reg", "-simplifycfg"]] + actions = [ + env.action_space["-mem2reg"], + env.action_space["-simplifycfg"], + ] forked = env.fork() try: @@ -231,24 +248,41 @@ def test_fork_previous_cost_reward_update(env: LlvmEnv): env.reward_space = "IrInstructionCount" env.reset("cbench-v1/crc32") - env.step(env.action_space.flags.index("-mem2reg")) + env.step(env.action_space["-mem2reg"]) with env.fork() as fkd: - _, a, _, _ = env.step(env.action_space.flags.index("-mem2reg")) - _, b, _, _ = fkd.step(env.action_space.flags.index("-mem2reg")) + _, a, _, _ = env.step(env.action_space["-mem2reg"]) + _, b, _, _ = fkd.step(env.action_space["-mem2reg"]) assert a == b def test_fork_previous_cost_lazy_reward_update(env: LlvmEnv): env.reset("cbench-v1/crc32") - env.step(env.action_space.flags.index("-mem2reg")) - env.reward["IrInstructionCount"] + env.step(env.action_space["-mem2reg"]) + env.reward["IrInstructionCount"] # noqa with env.fork() as fkd: - env.step(env.action_space.flags.index("-mem2reg")) - fkd.step(env.action_space.flags.index("-mem2reg")) + env.step(env.action_space["-mem2reg"]) + fkd.step(env.action_space["-mem2reg"]) assert env.reward["IrInstructionCount"] == fkd.reward["IrInstructionCount"] +def test_forked_service_dies(env: LlvmEnv): + """Test that if the service dies on a forked environment, each environment + creates new, independent services. + """ + with env.fork() as fkd: + assert env.service == fkd.service + try: + fkd.service.connection.close() + except ServiceError: + pass # shutdown() raises service error if in-episode. + fkd.service.close() + + env.reset() + fkd.reset() + assert env.service != fkd.service + + if __name__ == "__main__": main() diff --git a/tests/llvm/llvm_env_test.py b/tests/llvm/llvm_env_test.py index 8402098c7..3fbf0bcec 100644 --- a/tests/llvm/llvm_env_test.py +++ b/tests/llvm/llvm_env_test.py @@ -20,7 +20,7 @@ from compiler_gym.envs import CompilerEnv, llvm from compiler_gym.envs.llvm.llvm_env import LlvmEnv from compiler_gym.errors import ServiceError -from compiler_gym.service.connection import CompilerGymServiceConnection +from compiler_gym.service.connection import CompilerGymServiceConnection, ConnectionOpts from tests.pytest_plugins import llvm as llvm_plugin from tests.test_main import main @@ -34,7 +34,9 @@ def env(request) -> CompilerEnv: with gym.make("llvm-v0") as env: yield env else: - service = CompilerGymServiceConnection(llvm.LLVM_SERVICE_BINARY) + service = CompilerGymServiceConnection( + llvm.LLVM_SERVICE_BINARY, ConnectionOpts() + ) try: with LlvmEnv(service=service.connection.url) as env: yield env @@ -90,7 +92,7 @@ def test_connection_dies_default_reward(env: LlvmEnv): # with env.reset() above. For UnmanagedConnection, this error will not be # raised. try: - env.service.close() + env.service.shutdown() except ServiceError as e: assert "Service exited with returncode " in str(e) @@ -114,7 +116,7 @@ def test_connection_dies_default_reward_negated(env: LlvmEnv): # with env.reset() above. For UnmanagedConnection, this error will not be # raised. try: - env.service.close() + env.service.shutdown() except ServiceError as e: assert "Service exited with returncode " in str(e) @@ -144,7 +146,7 @@ def test_apply_state(env: LlvmEnv): """Test that apply() on a clean environment produces same state.""" env.reward_space = "IrInstructionCount" env.reset(benchmark="cbench-v1/crc32") - env.step(env.action_space.flags.index("-mem2reg")) + env.step(env.action_space["-mem2reg"]) with gym.make("llvm-v0", reward_space="IrInstructionCount") as other: other.apply(env.state) @@ -174,7 +176,7 @@ def test_same_reward_after_reset(env: LlvmEnv): env.reward_space = "IrInstructionCount" env.benchmark = "cbench-v1/dijkstra" - action = env.action_space.flags.index("-instcombine") + action = env.action_space["-instcombine"] env.reset() _, reward_a, _, _ = env.step(action) @@ -201,7 +203,7 @@ def test_ir_sha1(env: LlvmEnv, tmpwd: Path): env.reset(benchmark="cbench-v1/crc32") before = env.ir_sha1 - _, _, done, info = env.step(env.action_space.flags.index("-mem2reg")) + _, _, done, info = env.step(env.action_space["-mem2reg"]) assert not done, info assert not info["action_had_no_effect"], "sanity check failed, action had no effect" @@ -218,8 +220,8 @@ def test_step_multiple_actions_list(env: LlvmEnv): """Pass a list of actions to step().""" env.reset(benchmark="cbench-v1/crc32") actions = [ - env.action_space.flags.index("-mem2reg"), - env.action_space.flags.index("-reg2mem"), + env.action_space["-mem2reg"], + env.action_space["-reg2mem"], ] _, _, done, _ = env.multistep(actions) assert not done @@ -230,14 +232,14 @@ def test_step_multiple_actions_generator(env: LlvmEnv): """Pass an iterable of actions to step().""" env.reset(benchmark="cbench-v1/crc32") actions = ( - env.action_space.flags.index("-mem2reg"), - env.action_space.flags.index("-reg2mem"), + env.action_space["-mem2reg"], + env.action_space["-reg2mem"], ) _, _, done, _ = env.multistep(actions) assert not done assert env.actions == [ - env.action_space.flags.index("-mem2reg"), - env.action_space.flags.index("-reg2mem"), + env.action_space["-mem2reg"], + env.action_space["-reg2mem"], ] diff --git a/tests/llvm/reward_spaces_test.py b/tests/llvm/reward_spaces_test.py index e9214500b..27a7db627 100644 --- a/tests/llvm/reward_spaces_test.py +++ b/tests/llvm/reward_spaces_test.py @@ -23,7 +23,7 @@ def test_instruction_count_reward(env: LlvmEnv): env.reset(benchmark="cbench-v1/crc32") assert env.observation.IrInstructionCount() == CRC32_INSTRUCTION_COUNT - action = env.action_space.flags.index("-reg2mem") + action = env.action_space["-reg2mem"] env.step(action) assert env.observation.IrInstructionCount() == CRC32_INSTRUCTION_COUNT_AFTER_REG2MEM diff --git a/tests/llvm/service_connection_test.py b/tests/llvm/service_connection_test.py index d7a4f87e2..6a37e40b3 100644 --- a/tests/llvm/service_connection_test.py +++ b/tests/llvm/service_connection_test.py @@ -12,7 +12,7 @@ from compiler_gym.envs.llvm.llvm_env import LlvmEnv from compiler_gym.errors import ServiceError from compiler_gym.service.client_service_compiler_env import ClientServiceCompilerEnv -from compiler_gym.service.connection import CompilerGymServiceConnection +from compiler_gym.service.connection import CompilerGymServiceConnection, ConnectionOpts from compiler_gym.third_party.autophase import AUTOPHASE_FEATURE_DIM from tests.test_main import main @@ -27,7 +27,9 @@ def env(request) -> ClientServiceCompilerEnv: with gym.make("llvm-v0") as env: yield env else: - service = CompilerGymServiceConnection(llvm.LLVM_SERVICE_BINARY) + service = CompilerGymServiceConnection( + llvm.LLVM_SERVICE_BINARY, ConnectionOpts() + ) try: with LlvmEnv(service=service.connection.url) as env: yield env @@ -45,7 +47,7 @@ def test_service_env_dies_reset(env: ClientServiceCompilerEnv): # with env.reset() above. For UnmanagedConnection, this error will not be # raised. try: - env.service.close() + env.service.shutdown() except ServiceError as e: assert "Service exited with returncode " in str(e) diff --git a/tests/llvm/validate_test.py b/tests/llvm/validate_test.py index dc65b34e5..526c4cfd7 100644 --- a/tests/llvm/validate_test.py +++ b/tests/llvm/validate_test.py @@ -16,6 +16,7 @@ pytest_plugins = ["tests.pytest_plugins.llvm"] +@pytest.mark.timeout(900) # Validation can take a long time! def test_validate_state_no_reward(): state = CompilerEnvState( benchmark="benchmark://cbench-v1/crc32", @@ -30,6 +31,7 @@ def test_validate_state_no_reward(): assert str(result) == "✅ cbench-v1/crc32" +@pytest.mark.timeout(900) # Validation can take a long time! def test_validate_state_with_reward(): state = CompilerEnvState( benchmark="benchmark://cbench-v1/crc32", @@ -46,6 +48,7 @@ def test_validate_state_with_reward(): assert str(result) == "✅ cbench-v1/crc32 0.0000" +@pytest.mark.timeout(900) # Validation can take a long time! def test_validate_state_invalid_reward(): state = CompilerEnvState( benchmark="benchmark://cbench-v1/crc32", @@ -64,6 +67,7 @@ def test_validate_state_invalid_reward(): ) +@pytest.mark.timeout(900) # Validation can take a long time! def test_validate_state_without_state_reward(): """Validating state when state has no reward value.""" state = CompilerEnvState( @@ -102,6 +106,7 @@ def test_validate_state_without_env_reward(): assert not result.reward_validation_failed +@pytest.mark.timeout(900) # Validation can take a long time! def test_no_validation_callback_for_custom_benchmark(env: LlvmEnv): """Test that a custom benchmark has no validation callback.""" with tempfile.TemporaryDirectory() as d: diff --git a/tests/mlir/mlir_env_test.py b/tests/mlir/mlir_env_test.py index 35294d7ad..c03958cfb 100644 --- a/tests/mlir/mlir_env_test.py +++ b/tests/mlir/mlir_env_test.py @@ -12,7 +12,7 @@ import compiler_gym from compiler_gym.envs import CompilerEnv, mlir from compiler_gym.envs.mlir import MlirEnv -from compiler_gym.service.connection import CompilerGymServiceConnection +from compiler_gym.service.connection import CompilerGymServiceConnection, ConnectionOpts from compiler_gym.spaces import ( Box, Dict, @@ -36,7 +36,9 @@ def env(request) -> CompilerEnv: with gym.make("mlir-v0") as env: yield env else: - service = CompilerGymServiceConnection(mlir.MLIR_SERVICE_BINARY) + service = CompilerGymServiceConnection( + mlir.MLIR_SERVICE_BINARY, ConnectionOpts() + ) try: with MlirEnv(service=service.connection.url) as env: yield env diff --git a/tests/pytest_plugins/gcc.py b/tests/pytest_plugins/gcc.py index 730364bd7..e419edfc8 100644 --- a/tests/pytest_plugins/gcc.py +++ b/tests/pytest_plugins/gcc.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. """Pytest fixtures for the GCC CompilerGym environments.""" +import shutil import subprocess from functools import lru_cache from typing import Iterable @@ -13,38 +14,44 @@ from tests.pytest_plugins.common import docker_is_available -@lru_cache(maxsize=2) -def system_gcc_is_available() -> bool: +def system_has_functional_gcc(gcc_path: str) -> bool: """Return whether there is a system GCC available.""" try: stdout = subprocess.check_output( - ["gcc", "--version"], universal_newlines=True, stderr=subprocess.DEVNULL + [gcc_path, "--version"], + universal_newlines=True, + stderr=subprocess.DEVNULL, + timeout=30, ) # On some systems "gcc" may alias to a different compiler, so check for # the presence of the name "gcc" in the first line of output. return "gcc" in stdout.split("\n")[0].lower() - except (subprocess.CalledProcessError, FileNotFoundError): + except ( + subprocess.CalledProcessError, + FileNotFoundError, + subprocess.TimeoutExpired, + ): return False -def system_gcc_path() -> str: - """Return the path of the system GCC as a string.""" - return subprocess.check_output( - ["which", "gcc"], universal_newlines=True, stderr=subprocess.DEVNULL - ).strip() - - -def gcc_environment_is_supported() -> bool: - """Return whether the requirements for the GCC environment are met.""" - return docker_is_available() or system_gcc_is_available() +@lru_cache +def system_gcc_is_available(): + return system_has_functional_gcc(shutil.which("gcc")) +@lru_cache def gcc_bins() -> Iterable[str]: """Return a list of available GCCs.""" if docker_is_available(): yield "docker:gcc:11.2.0" - if system_gcc_is_available(): - yield system_gcc_path() + system_gcc = shutil.which("gcc") + if system_gcc and system_has_functional_gcc(system_gcc): + yield system_gcc + + +def gcc_environment_is_supported() -> bool: + """Return whether the requirements for the GCC environment are met.""" + return len(list(gcc_bins())) > 0 @pytest.fixture(scope="module", params=gcc_bins()) diff --git a/tests/pytest_plugins/llvm.py b/tests/pytest_plugins/llvm.py index 3eeccf313..b8ebacdf5 100644 --- a/tests/pytest_plugins/llvm.py +++ b/tests/pytest_plugins/llvm.py @@ -93,7 +93,7 @@ def non_validatable_cbench_uri(request) -> str: @pytest.fixture(scope="function") def env() -> LlvmEnv: - """Create an LLVM environment.""" + """Test fixture that yields an environment.""" with gym.make("llvm-v0") as env_: yield env_ diff --git a/tests/service/BUILD b/tests/service/BUILD index c5b115c42..86e6e52ae 100644 --- a/tests/service/BUILD +++ b/tests/service/BUILD @@ -17,12 +17,23 @@ py_test( ], ) +py_test( + name = "connection_pool_test", + srcs = ["connection_pool_test.py"], + deps = [ + "//compiler_gym", + "//compiler_gym/service", + "//tests:test_main", + "//tests/pytest_plugins:llvm", + ], +) + py_test( name = "service_cache_test", timeout = "short", srcs = ["service_cache_test.py"], deps = [ - "//compiler_gym/service:service_cache", + "//compiler_gym/service", "//tests:test_main", ], ) diff --git a/tests/service/CMakeLists.txt b/tests/service/CMakeLists.txt index b0fd0fd3b..22d41faa9 100644 --- a/tests/service/CMakeLists.txt +++ b/tests/service/CMakeLists.txt @@ -15,7 +15,20 @@ if(COMPILER_GYM_ENABLE_LLVM_ENV) compiler_gym::compiler_gym compiler_gym::envs::envs compiler_gym::errors::errors + compiler_gym::service::service compiler_gym::service::service_cache tests::test_main ) + + cg_py_test( + NAME + connection_pool_test + SRCS + "connection_pool_test.py" + DEPS + compiler_gym::errors::errors + compiler_gym::service::service + tests::pytest_plugins::llvm + tests::test_main + ) endif() diff --git a/tests/service/connection_pool_test.py b/tests/service/connection_pool_test.py new file mode 100644 index 000000000..cb6773d39 --- /dev/null +++ b/tests/service/connection_pool_test.py @@ -0,0 +1,188 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +"""Unit tests for compiler_gym/service/connection_pool.py.""" + +import pytest + +import compiler_gym +from compiler_gym.envs.llvm import LLVM_SERVICE_BINARY +from compiler_gym.errors import ServiceError +from compiler_gym.service import ConnectionOpts, ServiceConnectionPool +from tests.test_main import main + +pytest_plugins = ["tests.pytest_plugins.llvm"] + + +@pytest.fixture(scope="function") +def pool() -> ServiceConnectionPool: + with ServiceConnectionPool() as pool_: + yield pool_ + + +def test_service_pool_with_statement(): + with ServiceConnectionPool() as pool: + assert not pool.closed + assert pool.closed + + +def test_service_pool_double_close(pool: ServiceConnectionPool): + assert not pool.closed + pool.close() + assert pool.closed + pool.close() + assert pool.closed + + +def test_service_pool_acquire_release(pool: ServiceConnectionPool): + service = pool.acquire(LLVM_SERVICE_BINARY, ConnectionOpts()) + assert service in pool + service.release() + assert service in pool + + +def test_service_pool_contains(pool: ServiceConnectionPool): + with ServiceConnectionPool() as other_pool: + with pool.acquire(LLVM_SERVICE_BINARY, ConnectionOpts()) as service: + assert service in pool + assert service not in other_pool + assert service not in ServiceConnectionPool.get() + + # Service remains in pool after release. + assert service in pool + + +def test_service_pool_close_frees_service(pool: ServiceConnectionPool): + service = pool.acquire(LLVM_SERVICE_BINARY, ConnectionOpts()) + assert not service.closed + pool.close() + assert service.closed + + +def test_service_pool_service_is_not_closed(pool: ServiceConnectionPool): + service = None + service = pool.acquire(LLVM_SERVICE_BINARY, ConnectionOpts()) + service.close() + assert not service.closed + + +def test_service_pool_with_service_is_not_closed(pool: ServiceConnectionPool): + service = None + with pool.acquire(LLVM_SERVICE_BINARY, ConnectionOpts()) as service: + assert not service.closed + assert not service.closed + + +def test_service_pool_with_env_is_not_closed(pool: ServiceConnectionPool): + with compiler_gym.make("llvm-v0", service_pool=pool) as env: + service = env.service + assert not service.closed + assert not service.closed + + +def test_service_pool_fork(pool: ServiceConnectionPool): + with compiler_gym.make("llvm-v0", service_pool=pool) as env: + env.reset() + with env.fork() as fkd: + fkd.reset() + assert env.service == fkd.service + assert not env.service.closed + assert not env.service.closed + + +def test_service_pool_release_service(pool: ServiceConnectionPool): + service = pool.acquire(LLVM_SERVICE_BINARY, ConnectionOpts()) + service.close() + # A released service remains alive. + assert not service.closed + + +def test_service_pool_release_dead_service(pool: ServiceConnectionPool): + service = pool.acquire(LLVM_SERVICE_BINARY, ConnectionOpts()) + service.shutdown() + assert service.closed + service.close() + # A dead service cannot be reused, discard it. + assert service not in pool + + +def test_service_pool_size(pool: ServiceConnectionPool): + assert pool.size == 0 + assert len(pool) == pool.size + + with pool.acquire(LLVM_SERVICE_BINARY, ConnectionOpts()): + assert pool.size == 1 + assert len(pool) == pool.size + with pool.acquire(LLVM_SERVICE_BINARY, ConnectionOpts()): + assert pool.size == 2 + assert len(pool) == pool.size + + +def test_service_pool_make_release(pool: ServiceConnectionPool): + with compiler_gym.make("llvm-v0", service_pool=pool) as a: + assert len(pool) == 1 + with compiler_gym.make("llvm-v0", service_pool=pool) as b: + a_service = a.service + b_service = b.service + assert a_service != b_service + assert len(pool) == 2 + + with compiler_gym.make("llvm-v0", service_pool=pool) as c: + c_service = c.service + assert a_service == c_service + assert a_service != b_service + assert pool.size == 2 + + +def test_service_pool_make_release_loop(pool: ServiceConnectionPool): + for _ in range(5): + with compiler_gym.make("llvm-v0", service_pool=pool): + assert pool.size == 1 + assert pool.size == 1 + + +def test_service_pool_environment_restarts_service(pool: ServiceConnectionPool): + with compiler_gym.make("llvm-v0", service_pool=pool) as env: + old_service = env.service + env.service.shutdown() + env.service.close() + assert env.service.closed + + # For environment to restart service. + env.reset() + assert not env.service.closed + + new_service = env.service + assert new_service in pool + assert old_service not in pool + + +def test_service_pool_forked_service_dies(pool: ServiceConnectionPool): + with compiler_gym.make("llvm-v0", service_pool=pool) as env: + with env.fork() as fkd: + assert env.service == fkd.service + try: + fkd.service.connection.close() + except ServiceError: + pass # shutdown() raises service error if in-episode. + + env.reset() + fkd.reset() + assert env.service != fkd.service + assert env.service in pool + assert fkd.service in pool + + +def test_service_pool_forked_environment_ends_scope(pool: ServiceConnectionPool): + """Test that the original service does not close when the forked environment + goes out of scope.""" + with compiler_gym.make("llvm-v0", service_pool=pool) as env: + with env.fork() as fkd: + assert env.service == fkd.service + assert not env.service.closed + assert not env.service.closed + + +if __name__ == "__main__": + main() diff --git a/tests/service/connection_test.py b/tests/service/connection_test.py index c51e58463..63e051ba5 100644 --- a/tests/service/connection_test.py +++ b/tests/service/connection_test.py @@ -33,7 +33,7 @@ def dead_connection() -> CompilerGymServiceConnection: def test_create_invalid_options(): with pytest.raises(TypeError, match="No endpoint provided for service connection"): - CompilerGymServiceConnection("") + CompilerGymServiceConnection("", ConnectionOpts()) def test_create_channel_failed_subprocess( @@ -89,16 +89,13 @@ def test_call_stub_negative_timeout(connection: CompilerGymServiceConnection): def test_ManagedConnection_repr(connection: CompilerGymServiceConnection): cnx = connection.connection - assert ( - repr(cnx) - == f"Connection to service at {cnx.url} running on PID {cnx.process.pid}" - ) + assert repr(cnx) == f"ManagedConnection({cnx.url}, pid={cnx.process.pid})" # Kill the service. cnx.process.terminate() cnx.process.communicate() - assert repr(cnx) == f"Connection to dead service at {cnx.url}" + assert repr(cnx) == f"ManagedConnection({cnx.url}, not running)" if __name__ == "__main__":