From 03186eec0fc0deb4903d0b29729c70718a3bdb92 Mon Sep 17 00:00:00 2001 From: Cheng Xing Date: Tue, 20 Jun 2023 21:15:21 +0000 Subject: [PATCH] Stream client prototype --- .../cirq_google/engine/engine_client.py | 190 +++++++++++++++++- 1 file changed, 188 insertions(+), 2 deletions(-) diff --git a/cirq-google/cirq_google/engine/engine_client.py b/cirq-google/cirq_google/engine/engine_client.py index a3c93b72a1ec..3b9590937406 100644 --- a/cirq-google/cirq_google/engine/engine_client.py +++ b/cirq-google/cirq_google/engine/engine_client.py @@ -18,6 +18,7 @@ import threading from typing import ( AsyncIterable, + AsyncIterator, Awaitable, Callable, Dict, @@ -42,6 +43,8 @@ _M = TypeVar('_M', bound=proto.Message) _R = TypeVar('_R') +JobPath = str +MessageId = str class EngineException(Exception): @@ -95,6 +98,136 @@ def instance(cls): return cls._instance +class ResponseDemux: + """A event demultiplexer for QuantumRunStreamResponses, as part of the async reactor pattern. + Args: + cancel_callback: Function to be called when the future matching its request argument is + canceled. + """ + + def __init__(self): + self._subscribers: Dict[JobPath, Tuple[MessageId, duet.AwaitableFuture]] = {} + self._next_available_message_id = 0 + + def subscribe( + self, request: quantum.QuantumRunStreamRequest + ) -> duet.AwaitableFuture[quantum.QuantumRunStreamResponse]: + """Assumes the message ID has not been set.""" + + if 'create_quantum_program_and_job' in request: + job_path = request.create_quantum_program_and_job.quantum_job.name + elif 'create_quantum_job' in request: + job_path = request.create_quantum_job.quantum_job.name + else: # 'get_quantum_result' in request + job_path = request.get_quantum_result.parent + + request.message_id = self._next_available_message_id + response_future = duet.AwaitableFuture[quantum.QuantumRunStreamResponse]() + self._subscribers[job_path] = (self._next_available_message_id, response_future) + self._next_available_message_id += 1 + return response_future + + def publish(self, response: quantum.QuantumRunStreamResponse) -> None: + if 'error' in response: + job_path = next( + ( + p + for p, (message_id, _) in self._subscribers.items() + if message_id == response.message_id + ), + default='', + ) + elif 'job' in response: + job_path = response.job.name + else: # 'result' in response + job_path = response.result.parent + + if job_path not in self._subscribers: + return + + self._subscribers[job_path].try_set_result(response) + del self._subscribers[job_path] + + def publish_exception(self, exception: GoogleAPICallError) -> None: + """Publishes an exception to all outstanding futures.""" + for _, future in self._subscribers.values(): + future.try_set_exception(exception) + self._subscribers = {} + + +class StreamManager: + def __init__(self, grpc_client: quantum.QuantumEngineServiceAsyncClient): + self._grpc_client = grpc_client + self._request_queue = asyncio.Queue() + self._response_demux = ResponseDemux() + self._manage_stream_loop_running = False + + def _executor(self) -> AsyncioExecutor: + # We must re-use a single Executor due to multi-threading issues in gRPC + # clients: https://github.com/grpc/grpc/issues/25364. + return AsyncioExecutor.instance() + + def _request_iterator(self) -> AsyncIterator[quantum.QuantumRunStreamRequest]: + async def iterator(): + yield await self._request_queue.get() + + # TODO how to make an iterator properly? + return iterator + + async def _manage_stream(self): + """Keeps the stream alive and routes responses to the appropriate request handler""" + while True: + try: + response_iterable = self._grpc_client.quantum_run_stream(self._request_iterator) + async for response in response_iterable: + self._response_demux.publish(response) + except GoogleAPICallError as e: # TODO what's the right error to check here? + self._response_demux.publish_exception(e) + + async def _run_program( + self, project_name: str, program: quantum.QuantumProgram, job: quantum.QuantumJob + ) -> quantum.QuantumResult: + """This method is executed in a separate asyncio Task for each request.""" + create_program_and_job_request = quantum.QuantumRunStreamRequest( + parent=project_name, + create_quantum_program_and_job=quantum.CreateQuantumProgramAndJobRequest( + parent=project_name, quantum_program=program, quantum_job=job + ), + ) + get_result_request = quantum.QuantumRunStreamRequest( + parent=project_name, get_quantum_result=quantum.GetQuantumResultRequest(parent=job.name) + ) + + response_future = self._response_demux.subscribe(create_program_and_job_request) + await self._request_queue.put(create_program_and_job_request) + + response: Optional[quantum.QuantumRunStreamResponse] = None + while response is None: + try: + response = await response_future + except GoogleAPICallError: + response_future = self._response_demux.subscribe(get_result_request) + await self._request_queue.put(get_result_request) + + if response.result is not None: + return response.result + # TODO handle QuantumJob response and retryable StreamError. + + async def _cancel(self, job_name: str) -> None: + await self._grpc_client.cancel_quantum_job(quantum.CancelQuantumJobRequest(name=job_name)) + + def send( + self, project_name: str, program: quantum.QuantumProgram, job: quantum.QuantumJob + ) -> duet.AwaitableFuture[quantum.QuantumResult]: + """Sends a request over the stream and returns a future for the result.""" + if not self._manage_stream_loop_running: + self._executor.submit(self._manage_stream) + result_future = self._executor.submit(self._run_program, project_name, program, job) + result_future.add_done_callback(lambda _: self._executor.submit(self._cancel, job.name)) + # TODO will asyncio run_program task terminate when future is cancelled? + return result_future + + class EngineClient: """Client for the Quantum Engine API handling protos and gRPC client. @@ -148,6 +281,10 @@ async def make_client(): return self._executor.submit(make_client).result() + @cached_property + def stream_manager(self) -> StreamManager: + return StreamManager(self.grpc_client) + async def _send_request_async(self, func: Callable[[_M], Awaitable[_R]], request: _M) -> _R: """Sends a request by invoking an asyncio callable.""" return await self._run_retry_async(func, request) @@ -277,7 +414,7 @@ async def list_programs_async( val = _date_or_time_to_filter_expr('created_before', created_before) filters.append(f"create_time <= {val}") if has_labels is not None: - for (k, v) in has_labels.items(): + for k, v in has_labels.items(): filters.append(f"labels.{k}:{v}") request = quantum.ListQuantumProgramsRequest( parent=_project_name(project_id), filter=" AND ".join(filters) @@ -528,7 +665,7 @@ async def list_jobs_async( val = _date_or_time_to_filter_expr('created_before', created_before) filters.append(f"create_time <= {val}") if has_labels is not None: - for (k, v) in has_labels.items(): + for k, v in has_labels.items(): filters.append(f"labels.{k}:{v}") if execution_states is not None: state_filter = [] @@ -744,6 +881,55 @@ async def get_job_results_async( get_job_results = duet.sync(get_job_results_async) + def run_job_over_stream( + self, + project_id: str, + program_id: str, + job_id: Optional[str], + processor_ids: Sequence[str], + run_context: any_pb2.Any, + priority: Optional[int] = None, + description: Optional[str] = None, + labels: Optional[Dict[str, str]] = None, + ) -> duet.AwaitableFuture[quantum.QuantumResult]: + # Check program to run and program parameters. + if priority and not 0 <= priority < 1000: + raise ValueError('priority must be between 0 and 1000') + + # Create job. + job_name = _job_name_from_ids(project_id, program_id, job_id) if job_id else '' + job = quantum.QuantumJob( + name=job_name, + scheduling_config=quantum.SchedulingConfig( + processor_selector=quantum.SchedulingConfig.ProcessorSelector( + processor_names=[ + _processor_name_from_ids(project_id, processor_id) + for processor_id in processor_ids + ] + ) + ), + run_context=run_context, + ) + if priority: + job.scheduling_config.priority = priority + if description: + job.description = description + if labels: + job.labels.update(labels) + job_request = quantum.CreateQuantumJobRequest( + parent=_program_name_from_ids(project_id, program_id), + quantum_job=job, + overwrite_existing_run_context=False, + ) + stream_request = quantum.QuantumRunStreamRequest( + message_id=self._msg_id_generator.generate(), + parent=_project_name(project_id), + create_quantum_job=job_request, + ) + return self.stream_manager.send(stream_request) + + # TODO NEXT UP: change to sending over QuantumProgram instead... + async def list_processors_async(self, project_id: str) -> List[quantum.QuantumProcessor]: """Returns a list of Processors that the user has visibility to in the current Engine project. The names of these processors are used to