Skip to content

Commit

Permalink
Stream client prototype
Browse files Browse the repository at this point in the history
  • Loading branch information
verult committed Jun 20, 2023
1 parent d236276 commit 03186ee
Showing 1 changed file with 188 additions and 2 deletions.
190 changes: 188 additions & 2 deletions cirq-google/cirq_google/engine/engine_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import threading
from typing import (
AsyncIterable,
AsyncIterator,
Awaitable,
Callable,
Dict,
Expand All @@ -42,6 +43,8 @@

_M = TypeVar('_M', bound=proto.Message)
_R = TypeVar('_R')
JobPath = str
MessageId = str


class EngineException(Exception):
Expand Down Expand Up @@ -95,6 +98,136 @@ def instance(cls):
return cls._instance


class ResponseDemux:
"""A event demultiplexer for QuantumRunStreamResponses, as part of the async reactor pattern.
Args:
cancel_callback: Function to be called when the future matching its request argument is
canceled.
"""

def __init__(self):
self._subscribers: Dict[JobPath, Tuple[MessageId, duet.AwaitableFuture]] = {}
self._next_available_message_id = 0

def subscribe(
self, request: quantum.QuantumRunStreamRequest
) -> duet.AwaitableFuture[quantum.QuantumRunStreamResponse]:
"""Assumes the message ID has not been set."""

if 'create_quantum_program_and_job' in request:
job_path = request.create_quantum_program_and_job.quantum_job.name
elif 'create_quantum_job' in request:
job_path = request.create_quantum_job.quantum_job.name
else: # 'get_quantum_result' in request
job_path = request.get_quantum_result.parent

request.message_id = self._next_available_message_id
response_future = duet.AwaitableFuture[quantum.QuantumRunStreamResponse]()
self._subscribers[job_path] = (self._next_available_message_id, response_future)
self._next_available_message_id += 1
return response_future

def publish(self, response: quantum.QuantumRunStreamResponse) -> None:
if 'error' in response:
job_path = next(
(
p
for p, (message_id, _) in self._subscribers.items()
if message_id == response.message_id
),
default='',
)
elif 'job' in response:
job_path = response.job.name
else: # 'result' in response
job_path = response.result.parent

if job_path not in self._subscribers:
return

self._subscribers[job_path].try_set_result(response)
del self._subscribers[job_path]

def publish_exception(self, exception: GoogleAPICallError) -> None:
"""Publishes an exception to all outstanding futures."""
for _, future in self._subscribers.values():
future.try_set_exception(exception)
self._subscribers = {}


class StreamManager:
def __init__(self, grpc_client: quantum.QuantumEngineServiceAsyncClient):
self._grpc_client = grpc_client
self._request_queue = asyncio.Queue()
self._response_demux = ResponseDemux()
self._manage_stream_loop_running = False

def _executor(self) -> AsyncioExecutor:
# We must re-use a single Executor due to multi-threading issues in gRPC
# clients: https://github.com/grpc/grpc/issues/25364.
return AsyncioExecutor.instance()

def _request_iterator(self) -> AsyncIterator[quantum.QuantumRunStreamRequest]:
async def iterator():
yield await self._request_queue.get()

# TODO how to make an iterator properly?
return iterator

async def _manage_stream(self):
"""Keeps the stream alive and routes responses to the appropriate request handler"""
while True:
try:
response_iterable = self._grpc_client.quantum_run_stream(self._request_iterator)
async for response in response_iterable:
self._response_demux.publish(response)
except GoogleAPICallError as e: # TODO what's the right error to check here?
self._response_demux.publish_exception(e)

async def _run_program(
self, project_name: str, program: quantum.QuantumProgram, job: quantum.QuantumJob
) -> quantum.QuantumResult:
"""This method is executed in a separate asyncio Task for each request."""
create_program_and_job_request = quantum.QuantumRunStreamRequest(
parent=project_name,
create_quantum_program_and_job=quantum.CreateQuantumProgramAndJobRequest(
parent=project_name, quantum_program=program, quantum_job=job
),
)
get_result_request = quantum.QuantumRunStreamRequest(
parent=project_name, get_quantum_result=quantum.GetQuantumResultRequest(parent=job.name)
)

response_future = self._response_demux.subscribe(create_program_and_job_request)
await self._request_queue.put(create_program_and_job_request)

response: Optional[quantum.QuantumRunStreamResponse] = None
while response is None:
try:
response = await response_future
except GoogleAPICallError:
response_future = self._response_demux.subscribe(get_result_request)
await self._request_queue.put(get_result_request)

if response.result is not None:
return response.result
# TODO handle QuantumJob response and retryable StreamError.

async def _cancel(self, job_name: str) -> None:
await self._grpc_client.cancel_quantum_job(quantum.CancelQuantumJobRequest(name=job_name))

def send(
self, project_name: str, program: quantum.QuantumProgram, job: quantum.QuantumJob
) -> duet.AwaitableFuture[quantum.QuantumResult]:
"""Sends a request over the stream and returns a future for the result."""
if not self._manage_stream_loop_running:
self._executor.submit(self._manage_stream)
result_future = self._executor.submit(self._run_program, project_name, program, job)
result_future.add_done_callback(lambda _: self._executor.submit(self._cancel, job.name))
# TODO will asyncio run_program task terminate when future is cancelled?
return result_future


class EngineClient:
"""Client for the Quantum Engine API handling protos and gRPC client.
Expand Down Expand Up @@ -148,6 +281,10 @@ async def make_client():

return self._executor.submit(make_client).result()

@cached_property
def stream_manager(self) -> StreamManager:
return StreamManager(self.grpc_client)

async def _send_request_async(self, func: Callable[[_M], Awaitable[_R]], request: _M) -> _R:
"""Sends a request by invoking an asyncio callable."""
return await self._run_retry_async(func, request)
Expand Down Expand Up @@ -277,7 +414,7 @@ async def list_programs_async(
val = _date_or_time_to_filter_expr('created_before', created_before)
filters.append(f"create_time <= {val}")
if has_labels is not None:
for (k, v) in has_labels.items():
for k, v in has_labels.items():
filters.append(f"labels.{k}:{v}")
request = quantum.ListQuantumProgramsRequest(
parent=_project_name(project_id), filter=" AND ".join(filters)
Expand Down Expand Up @@ -528,7 +665,7 @@ async def list_jobs_async(
val = _date_or_time_to_filter_expr('created_before', created_before)
filters.append(f"create_time <= {val}")
if has_labels is not None:
for (k, v) in has_labels.items():
for k, v in has_labels.items():
filters.append(f"labels.{k}:{v}")
if execution_states is not None:
state_filter = []
Expand Down Expand Up @@ -744,6 +881,55 @@ async def get_job_results_async(

get_job_results = duet.sync(get_job_results_async)

def run_job_over_stream(
self,
project_id: str,
program_id: str,
job_id: Optional[str],
processor_ids: Sequence[str],
run_context: any_pb2.Any,
priority: Optional[int] = None,
description: Optional[str] = None,
labels: Optional[Dict[str, str]] = None,
) -> duet.AwaitableFuture[quantum.QuantumResult]:
# Check program to run and program parameters.
if priority and not 0 <= priority < 1000:
raise ValueError('priority must be between 0 and 1000')

# Create job.
job_name = _job_name_from_ids(project_id, program_id, job_id) if job_id else ''
job = quantum.QuantumJob(
name=job_name,
scheduling_config=quantum.SchedulingConfig(
processor_selector=quantum.SchedulingConfig.ProcessorSelector(
processor_names=[
_processor_name_from_ids(project_id, processor_id)
for processor_id in processor_ids
]
)
),
run_context=run_context,
)
if priority:
job.scheduling_config.priority = priority
if description:
job.description = description
if labels:
job.labels.update(labels)
job_request = quantum.CreateQuantumJobRequest(
parent=_program_name_from_ids(project_id, program_id),
quantum_job=job,
overwrite_existing_run_context=False,
)
stream_request = quantum.QuantumRunStreamRequest(
message_id=self._msg_id_generator.generate(),
parent=_project_name(project_id),
create_quantum_job=job_request,
)
return self.stream_manager.send(stream_request)

# TODO NEXT UP: change to sending over QuantumProgram instead...

async def list_processors_async(self, project_id: str) -> List[quantum.QuantumProcessor]:
"""Returns a list of Processors that the user has visibility to in the
current Engine project. The names of these processors are used to
Expand Down

0 comments on commit 03186ee

Please sign in to comment.