Skip to content

Commit

Permalink
Use SSL by default (#67)
Browse files Browse the repository at this point in the history
This PR enable SSL by default and makes the default SSL option
configurable for `parse_grpc_uri()`.
  • Loading branch information
llucax authored Aug 5, 2024
2 parents 2c55d40 + 47738a4 commit 3fd2eed
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 17 deletions.
3 changes: 2 additions & 1 deletion RELEASE_NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@

## Upgrading

<!-- Here goes notes on how to upgrade from previous versions, including deprecations and what they should be replaced with -->
- The `parse_grpc_uri` function (and `BaseApiClient` constructor) now enables SSL by default (`ssl=false` should be passed to disable it).
- The `parse_grpc_uri` function now accepts an optional `default_ssl` parameter to set the default value for the `ssl` parameter when not present in the URI.

## New Features

Expand Down
15 changes: 11 additions & 4 deletions src/frequenz/client/base/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,19 @@ def _to_bool(value: str) -> bool:


def parse_grpc_uri(
uri: str, channel_type: type[ChannelT], /, *, default_port: int = 9090
uri: str,
channel_type: type[ChannelT],
/,
*,
default_port: int = 9090,
default_ssl: bool = True,
) -> ChannelT:
"""Create a grpclib client channel from a URI.
The URI must have the following format:
```
grpc://hostname[:port][?ssl=false]
grpc://hostname[:port][?ssl=<bool>]
```
A few things to consider about URI components:
Expand All @@ -39,14 +44,15 @@ def parse_grpc_uri(
- If the port is omitted, the `default_port` is used.
- If a query parameter is passed many times, the last value is used.
- The only supported query parameter is `ssl`, which must be a boolean value and
defaults to `false`.
defaults to the `default_ssl` argument if not present.
- Boolean query parameters can be specified with the following values
(case-insensitive): `true`, `1`, `on`, `false`, `0`, `off`.
Args:
uri: The gRPC URI specifying the connection parameters.
channel_type: The type of channel to create.
default_port: The default port number to use if the URI does not specify one.
default_ssl: The default SSL setting to use if the URI does not specify one.
Returns:
A grpclib client channel object.
Expand All @@ -69,7 +75,8 @@ def parse_grpc_uri(
)

options = {k: v[-1] for k, v in parse_qs(parsed_uri.query).items()}
ssl = _to_bool(options.pop("ssl", "false"))
ssl_option = options.pop("ssl", None)
ssl = _to_bool(ssl_option) if ssl_option is not None else default_ssl
if options:
raise ValueError(
f"Unexpected query parameters {options!r} in the URI '{uri}'",
Expand Down
61 changes: 49 additions & 12 deletions tests/test_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"""Test cases for the channel module."""

from dataclasses import dataclass
from typing import NotRequired, TypedDict
from unittest import mock

import pytest
Expand All @@ -12,8 +13,8 @@
from frequenz.client.base.channel import parse_grpc_uri

VALID_URLS = [
("grpc://localhost", "localhost", 9090, False),
("grpc://localhost:1234", "localhost", 1234, False),
("grpc://localhost", "localhost", 9090, True),
("grpc://localhost:1234", "localhost", 1234, True),
("grpc://localhost:1234?ssl=true", "localhost", 1234, True),
("grpc://localhost:1234?ssl=false", "localhost", 1234, False),
("grpc://localhost:1234?ssl=1", "localhost", 1234, True),
Expand All @@ -29,12 +30,25 @@
]


class _CreateChannelKwargs(TypedDict):
default_port: NotRequired[int]
default_ssl: NotRequired[bool]


@pytest.mark.parametrize("uri, host, port, ssl", VALID_URLS)
def test_grpclib_parse_uri_ok(
@pytest.mark.parametrize(
"default_port", [None, 9090, 1234], ids=lambda x: f"default_port={x}"
)
@pytest.mark.parametrize(
"default_ssl", [None, True, False], ids=lambda x: f"default_ssl={x}"
)
def test_grpclib_parse_uri_ok( # pylint: disable=too-many-arguments
uri: str,
host: str,
port: int,
ssl: bool,
default_port: int | None,
default_ssl: bool | None,
) -> None:
"""Test successful parsing of gRPC URIs using grpclib."""

Expand All @@ -44,24 +58,39 @@ class _FakeChannel:
port: int
ssl: bool

kwargs = _CreateChannelKwargs()
if default_port is not None:
kwargs["default_port"] = default_port
if default_ssl is not None:
kwargs["default_ssl"] = default_ssl

expected_port = port if f":{port}" in uri or default_port is None else default_port
expected_ssl = ssl if "ssl" in uri or default_ssl is None else default_ssl

with mock.patch(
"frequenz.client.base.channel._grpchacks.grpclib_create_channel",
return_value=_FakeChannel(host, port, ssl),
):
channel = parse_grpc_uri(uri, _grpchacks.GrpclibChannel)
) as create_channel_mock:
channel = parse_grpc_uri(uri, _grpchacks.GrpclibChannel, **kwargs)

assert isinstance(channel, _FakeChannel)
assert channel.host == host
assert channel.port == port
assert channel.ssl == ssl
create_channel_mock.assert_called_once_with(host, expected_port, expected_ssl)


@pytest.mark.parametrize("uri, host, port, ssl", VALID_URLS)
def test_grpcio_parse_uri_ok(
@pytest.mark.parametrize(
"default_port", [None, 9090, 1234], ids=lambda x: f"default_port={x}"
)
@pytest.mark.parametrize(
"default_ssl", [None, True, False], ids=lambda x: f"default_ssl={x}"
)
def test_grpcio_parse_uri_ok( # pylint: disable=too-many-arguments,too-many-locals
uri: str,
host: str,
port: int,
ssl: bool,
default_port: int | None,
default_ssl: bool | None,
) -> None:
"""Test successful parsing of gRPC URIs using grpcio."""
expected_channel = mock.MagicMock(
Expand All @@ -70,6 +99,14 @@ def test_grpcio_parse_uri_ok(
expected_credentials = mock.MagicMock(
name="mock_credentials", spec=_grpchacks.GrpcioChannel
)
expected_port = port if f":{port}" in uri or default_port is None else default_port
expected_ssl = ssl if "ssl" in uri or default_ssl is None else default_ssl

kwargs = _CreateChannelKwargs()
if default_port is not None:
kwargs["default_port"] = default_port
if default_ssl is not None:
kwargs["default_ssl"] = default_ssl

with (
mock.patch(
Expand All @@ -85,11 +122,11 @@ def test_grpcio_parse_uri_ok(
return_value=expected_credentials,
) as ssl_channel_credentials_mock,
):
channel = parse_grpc_uri(uri, _grpchacks.GrpcioChannel)
channel = parse_grpc_uri(uri, _grpchacks.GrpcioChannel, **kwargs)

assert channel == expected_channel
expected_target = f"{host}:{port}"
if ssl:
expected_target = f"{host}:{expected_port}"
if expected_ssl:
ssl_channel_credentials_mock.assert_called_once_with()
secure_channel_mock.assert_called_once_with(
expected_target, expected_credentials
Expand Down

0 comments on commit 3fd2eed

Please sign in to comment.