Skip to content

Commit

Permalink
feat: add support for async rest streaming methods
Browse files Browse the repository at this point in the history
  • Loading branch information
ohmayr committed Sep 11, 2024
1 parent 6a42d1c commit 604a4bf
Show file tree
Hide file tree
Showing 12 changed files with 114 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ from google.api_core import exceptions as core_exceptions
from google.api_core import gapic_v1
from google.api_core import retry_async as retries
from google.api_core import rest_helpers
{# TODO(https://github.com/googleapis/gapic-generator-python/issues/2137): raise an import error if an older version of google.api.core is installed. #}
from google.api_core import rest_streaming_async # type: ignore

from google.protobuf import json_format

Expand Down Expand Up @@ -112,19 +114,19 @@ class Async{{service.name}}RestTransport(_Base{{ service.name }}RestTransport):
def __hash__(self):
return hash("Async{{service.name}}RestTransport.{{method.name}}")

{% if method.http_options and not method.client_streaming and not method.server_streaming and not method.lro and not method.extended_lro and not method.paged_result_field %}
{% if method.http_options and not method.client_streaming and not method.lro and not method.extended_lro and not method.paged_result_field %}
{% set body_spec = method.http_options[0].body %}
{{ shared_macros.response_method(body_spec, is_async=True)|indent(8) }}

{% endif %}{# method.http_options and not method.client_streaming and not method.server_streaming and not method.lro and not method.extended_lro and not method.paged_result_field #}
{% endif %}{# method.http_options and not method.client_streaming and not method.lro and not method.extended_lro and not method.paged_result_field #}
async def __call__(self,
request: {{method.input.ident}}, *,
retry: OptionalRetry=gapic_v1.method.DEFAULT,
timeout: Optional[float]=None,
metadata: Sequence[Tuple[str, str]]=(),
{# TODO(b/362949446): Update the return type as we implement this for different method types. #}
){% if not method.void %} -> {% if not method.server_streaming %}{{method.output.ident}}{% else %}None{% endif %}{% endif %}:
{% if method.http_options and not method.client_streaming and not method.server_streaming and not method.lro and not method.extended_lro and not method.paged_result_field %}
){% if not method.void %} -> {% if not method.server_streaming %}{{method.output.ident}}{% else %}rest_streaming_async.AsyncResponseIterator{% endif %}{% endif %}:
{% if method.http_options and not method.client_streaming and not method.lro and not method.extended_lro and not method.paged_result_field %}
r"""Call the {{- ' ' -}}
{{ (method.name|snake_case).replace('_',' ')|wrap(
width=70, offset=45, indent=8) }}
Expand All @@ -151,14 +153,18 @@ class Async{{service.name}}RestTransport(_Base{{ service.name }}RestTransport):

{% if not method.void %}
# Return the response
{% if method.server_streaming %}
resp = rest_streaming_async.AsyncResponseIterator(response, {{method.output.ident}})
{% else %}
resp = {{method.output.ident}}()
{% if method.output.ident.is_proto_plus_type %}
pb_resp = {{method.output.ident}}.pb(resp)
{% else %}
pb_resp = resp
{% endif %}
{% endif %}{# if method.output.ident.is_proto_plus_type #}
content = await response.read()
json_format.Parse(content, pb_resp, ignore_unknown_fields=True)
{% endif %}{# if method.server_streaming #}
return resp

{% endif %}{# method.void #}
Expand All @@ -167,7 +173,7 @@ class Async{{service.name}}RestTransport(_Base{{ service.name }}RestTransport):
raise NotImplementedError(
"Method {{ method.name }} is not available over REST transport"
)
{% endif %}{# method.http_options and not method.client_streaming and not method.server_streaming and not method.lro and not method.extended_lro and not method.paged_result_field #}
{% endif %}{# method.http_options and not method.client_streaming and not method.lro and not method.extended_lro and not method.paged_result_field #}

{% endfor %}
{% for method in service.methods.values()|sort(attribute="name") %}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ except ImportError: # pragma: NO COVER
import grpc
from grpc.experimental import aio
{% if "rest" in opts.transport %}
from collections.abc import Iterable
from collections.abc import Iterable, AsyncIterable
from google.protobuf import json_format
import json
{% endif %}
Expand Down Expand Up @@ -95,6 +95,11 @@ from google.iam.v1 import policy_pb2 # type: ignore
{% endfilter %}
{{ shared_macros.add_google_api_core_version_header_import(service.version) }}

async def mock_async_gen(data, chunk_size=1):
for i in range(0, len(data)): # pragma: NO COVER
chunk = data[i : i + chunk_size]
yield chunk.encode("utf-8")

def client_cert_source_callback():
return b"cert bytes", b"key bytes"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1054,7 +1054,7 @@ def test_{{ method_name }}_raw_page_lro():
{# NOTE: This guard is added to avoid generating duplicate tests for methods which are tested elsewhere.
# TODO(https://github.com/googleapis/gapic-generator-python/issues/2143): Remove the test `test_{{ method_name }}_rest` from here once the linked issue is resolved.
#}
{% if method.server_streaming or method.lro or method.extended_lro or method.paged_result_field %}
{% if method.lro or method.extended_lro or method.paged_result_field %}
@pytest.mark.parametrize("request_type", [
{{ method.input.ident }},
dict,
Expand Down Expand Up @@ -1909,13 +1909,13 @@ def test_unsupported_parameter_rest_asyncio():
{% endmacro %}

{# is_rest_unsupported_method renders:
# 'True' if transport is async REST.
# 'True' if transport is sync REST and method is a client streaming method.
# 'True' if transport is async REST and method is one of [client_streaming, lro, extended_lro, paged_result_field].
# 'True' if transport is sync REST and method is a client_streaming method.
# 'False' otherwise.
#}
{# NOTE: We will keep updating this method as we add support for methods in async REST. #}
{% macro is_rest_unsupported_method(method, is_async) %}
{%- if method.client_streaming or (is_async and (method.server_streaming or method.lro or method.extended_lro or method.paged_result_field)) -%}
{%- if method.client_streaming or (is_async and (method.lro or method.extended_lro or method.paged_result_field)) -%}
{{'True'}}
{%- else -%}
{{'False'}}
Expand Down Expand Up @@ -2038,7 +2038,7 @@ def test_{{transport_name}}_initialize_client():
{# rest_method_call_success_test generates tests for rest methods
# when they make a successful request.
# NOTE: Currently, this macro does not support the following method
# types: [method.server_streaming, method.lro, method.extended_lro, method.paged_result_field].
# types: [method.lro, method.extended_lro, method.paged_result_field].
# As support is added for the above methods, the relevant guard can be removed from within the macro
# TODO(https://github.com/googleapis/gapic-generator-python/issues/2142): Clean up `rest_required_tests` once support for all the methods metioned above is added here.
#}
Expand All @@ -2051,7 +2051,7 @@ def test_{{transport_name}}_initialize_client():
{# TODO(https://github.com/googleapis/gapic-generator-python/issues/2143): Update the guard below as support we add for each method. Remove it once we have
# all the methods supported in async rest transport that are supported in sync rest transport.
#}
{% if not (method.server_streaming or method.lro or method.extended_lro or method.paged_result_field)%}
{% if not (method.lro or method.extended_lro or method.paged_result_field)%}
@pytest.mark.parametrize("request_type", [
{{ method.input.ident }},
dict,
Expand Down Expand Up @@ -2107,13 +2107,33 @@ def test_{{transport_name}}_initialize_client():
{% endif %}{# method.output.ident.is_proto_plus_type #}
json_return_value = json_format.MessageToJson(return_value)
{% endif %}{# method.void #}
{% if method.server_streaming %}
json_return_value = "[{}]".format(json_return_value)
{% if is_async %}
response_value.content.return_value = mock_async_gen(json_return_value)
{% else %}{# not is_async #}
response_value.iter_content = mock.Mock(return_value=iter(json_return_value))
{% endif %}{# is_async #}
{% else %}{# not method.streaming #}
{% if is_async %}
response_value.read = mock.AsyncMock(return_value=json_return_value.encode('UTF-8'))
{% else %}{# is_async #}
{% else %}{# not is_async #}
response_value.content = json_return_value.encode('UTF-8')
{% endif %}{# is_async #}
{% endif %}{# method.server_streaming #}
req.return_value = response_value
response = {{ await_prefix }}client.{{ method_name }}(request)

{% if method.server_streaming %}
{% if is_async %}
assert isinstance(response, AsyncIterable)
response = await response.__anext__()
{% else %}
assert isinstance(response, Iterable)
response = next(response)
{% endif %}
{% endif %}

# Establish that the response is the type that we expect.
{% if method.void %}
assert response is None
Expand Down
2 changes: 1 addition & 1 deletion noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,7 +442,7 @@ def showcase_unit_w_rest_async(
session.chdir(lib)
# Note: google-api-core and google-auth are re-installed here to override the version installed in constraints.
# TODO(https://github.com/googleapis/python-api-core/pull/694): Update the version of google-api-core once the linked PR is merged.
session.install('--no-cache-dir', '--force-reinstall', "google-api-core[grpc]@git+https://github.com/googleapis/python-api-core.git@7dea20d73878eca93b61bb82ae6ddf335fb3a8ca")
session.install('--no-cache-dir', '--force-reinstall', "google-api-core[grpc]@git+https://github.com/googleapis/python-api-core.git@16038182329055551a32acd0f9f505301be4bcc5")
# TODO(https://github.com/googleapis/google-auth-library-python/pull/1577): Update the version of google-auth once the linked PR is merged.
session.install('--no-cache-dir', '--force-reinstall', "google-auth@git+https://github.com/googleapis/google-auth-library-python.git@add-support-for-async-authorized-session-api")
session.install("aiohttp")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

import grpc
from grpc.experimental import aio
from collections.abc import Iterable
from collections.abc import Iterable, AsyncIterable
from google.protobuf import json_format
import json
import math
Expand Down Expand Up @@ -71,6 +71,11 @@
import google.auth


async def mock_async_gen(data, chunk_size=1):
for i in range(0, len(data)): # pragma: NO COVER
chunk = data[i : i + chunk_size]
yield chunk.encode("utf-8")

def client_cert_source_callback():
return b"cert bytes", b"key bytes"

Expand Down Expand Up @@ -15289,6 +15294,7 @@ def test_batch_get_assets_history_rest_call_success(request_type):
response_value.content = json_return_value.encode('UTF-8')
req.return_value = response_value
response = client.batch_get_assets_history(request)

# Establish that the response is the type that we expect.
assert isinstance(response, asset_service.BatchGetAssetsHistoryResponse)

Expand Down Expand Up @@ -15349,6 +15355,7 @@ def test_create_feed_rest_call_success(request_type):
response_value.content = json_return_value.encode('UTF-8')
req.return_value = response_value
response = client.create_feed(request)

# Establish that the response is the type that we expect.
assert isinstance(response, asset_service.Feed)
assert response.name == 'name_value'
Expand Down Expand Up @@ -15418,6 +15425,7 @@ def test_get_feed_rest_call_success(request_type):
response_value.content = json_return_value.encode('UTF-8')
req.return_value = response_value
response = client.get_feed(request)

# Establish that the response is the type that we expect.
assert isinstance(response, asset_service.Feed)
assert response.name == 'name_value'
Expand Down Expand Up @@ -15482,6 +15490,7 @@ def test_list_feeds_rest_call_success(request_type):
response_value.content = json_return_value.encode('UTF-8')
req.return_value = response_value
response = client.list_feeds(request)

# Establish that the response is the type that we expect.
assert isinstance(response, asset_service.ListFeedsResponse)

Expand Down Expand Up @@ -15542,6 +15551,7 @@ def test_update_feed_rest_call_success(request_type):
response_value.content = json_return_value.encode('UTF-8')
req.return_value = response_value
response = client.update_feed(request)

# Establish that the response is the type that we expect.
assert isinstance(response, asset_service.Feed)
assert response.name == 'name_value'
Expand Down Expand Up @@ -15603,6 +15613,7 @@ def test_delete_feed_rest_call_success(request_type):
response_value.content = json_return_value.encode('UTF-8')
req.return_value = response_value
response = client.delete_feed(request)

# Establish that the response is the type that we expect.
assert response is None

Expand Down Expand Up @@ -15701,6 +15712,7 @@ def test_analyze_iam_policy_rest_call_success(request_type):
response_value.content = json_return_value.encode('UTF-8')
req.return_value = response_value
response = client.analyze_iam_policy(request)

# Establish that the response is the type that we expect.
assert isinstance(response, asset_service.AnalyzeIamPolicyResponse)
assert response.fully_explored is True
Expand Down Expand Up @@ -15778,6 +15790,7 @@ def test_analyze_move_rest_call_success(request_type):
response_value.content = json_return_value.encode('UTF-8')
req.return_value = response_value
response = client.analyze_move(request)

# Establish that the response is the type that we expect.
assert isinstance(response, asset_service.AnalyzeMoveResponse)

Expand Down Expand Up @@ -15835,6 +15848,7 @@ def test_query_assets_rest_call_success(request_type):
response_value.content = json_return_value.encode('UTF-8')
req.return_value = response_value
response = client.query_assets(request)

# Establish that the response is the type that we expect.
assert isinstance(response, asset_service.QueryAssetsResponse)
assert response.job_reference == 'job_reference_value'
Expand Down Expand Up @@ -15897,6 +15911,7 @@ def test_create_saved_query_rest_call_success(request_type):
response_value.content = json_return_value.encode('UTF-8')
req.return_value = response_value
response = client.create_saved_query(request)

# Establish that the response is the type that we expect.
assert isinstance(response, asset_service.SavedQuery)
assert response.name == 'name_value'
Expand Down Expand Up @@ -15963,6 +15978,7 @@ def test_get_saved_query_rest_call_success(request_type):
response_value.content = json_return_value.encode('UTF-8')
req.return_value = response_value
response = client.get_saved_query(request)

# Establish that the response is the type that we expect.
assert isinstance(response, asset_service.SavedQuery)
assert response.name == 'name_value'
Expand Down Expand Up @@ -16050,6 +16066,7 @@ def test_update_saved_query_rest_call_success(request_type):
response_value.content = json_return_value.encode('UTF-8')
req.return_value = response_value
response = client.update_saved_query(request)

# Establish that the response is the type that we expect.
assert isinstance(response, asset_service.SavedQuery)
assert response.name == 'name_value'
Expand Down Expand Up @@ -16109,6 +16126,7 @@ def test_delete_saved_query_rest_call_success(request_type):
response_value.content = json_return_value.encode('UTF-8')
req.return_value = response_value
response = client.delete_saved_query(request)

# Establish that the response is the type that we expect.
assert response is None

Expand Down Expand Up @@ -16164,6 +16182,7 @@ def test_batch_get_effective_iam_policies_rest_call_success(request_type):
response_value.content = json_return_value.encode('UTF-8')
req.return_value = response_value
response = client.batch_get_effective_iam_policies(request)

# Establish that the response is the type that we expect.
assert isinstance(response, asset_service.BatchGetEffectiveIamPoliciesResponse)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

import grpc
from grpc.experimental import aio
from collections.abc import Iterable
from collections.abc import Iterable, AsyncIterable
from google.protobuf import json_format
import json
import math
Expand Down Expand Up @@ -61,6 +61,11 @@
import google.auth


async def mock_async_gen(data, chunk_size=1):
for i in range(0, len(data)): # pragma: NO COVER
chunk = data[i : i + chunk_size]
yield chunk.encode("utf-8")

def client_cert_source_callback():
return b"cert bytes", b"key bytes"

Expand Down Expand Up @@ -3307,6 +3312,7 @@ def test_generate_access_token_rest_call_success(request_type):
response_value.content = json_return_value.encode('UTF-8')
req.return_value = response_value
response = client.generate_access_token(request)

# Establish that the response is the type that we expect.
assert isinstance(response, common.GenerateAccessTokenResponse)
assert response.access_token == 'access_token_value'
Expand Down Expand Up @@ -3364,6 +3370,7 @@ def test_generate_id_token_rest_call_success(request_type):
response_value.content = json_return_value.encode('UTF-8')
req.return_value = response_value
response = client.generate_id_token(request)

# Establish that the response is the type that we expect.
assert isinstance(response, common.GenerateIdTokenResponse)
assert response.token == 'token_value'
Expand Down Expand Up @@ -3422,6 +3429,7 @@ def test_sign_blob_rest_call_success(request_type):
response_value.content = json_return_value.encode('UTF-8')
req.return_value = response_value
response = client.sign_blob(request)

# Establish that the response is the type that we expect.
assert isinstance(response, common.SignBlobResponse)
assert response.key_id == 'key_id_value'
Expand Down Expand Up @@ -3482,6 +3490,7 @@ def test_sign_jwt_rest_call_success(request_type):
response_value.content = json_return_value.encode('UTF-8')
req.return_value = response_value
response = client.sign_jwt(request)

# Establish that the response is the type that we expect.
assert isinstance(response, common.SignJwtResponse)
assert response.key_id == 'key_id_value'
Expand Down
Loading

0 comments on commit 604a4bf

Please sign in to comment.