Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Finish implementation of connection to multiple cluster with relation aliases #7

Merged
merged 7 commits into from
Jun 15, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 96 additions & 0 deletions lib/charms/data_platform_libs/v0/database_requires.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ def __init__(
relation_name: str,
database_name: str,
extra_user_roles: str = None,
relations_aliases: List[str] = None,
):
"""Manager of database client relations."""
super().__init__(charm, relation_name)
Expand All @@ -191,13 +192,68 @@ def __init__(
self.local_app = self.charm.model.app
self.local_unit = self.charm.unit
self.relation_name = relation_name
self.relations_aliases = relations_aliases
self.framework.observe(
self.charm.on[relation_name].relation_joined, self._on_relation_joined_event
)
self.framework.observe(
self.charm.on[relation_name].relation_changed, self._on_relation_changed_event
)

# Define custom event names for each alias.
if relations_aliases:
# Ensure the number of aliases match the maximum number
# of connections allowed in the specific relation.
relation_connection_limit = self.charm.meta.requires[relation_name].limit
if len(relations_aliases) != relation_connection_limit:
raise ValueError(
f"The number of aliases must match the maximum number of connections allowed in the relation. "
f"Expected {relation_connection_limit}, got {len(relations_aliases)}"
)

for relation_alias in relations_aliases:
self.on.define_event(f"{relation_alias}_database_created", DatabaseCreatedEvent)
self.on.define_event(
f"{relation_alias}_endpoints_changed", DatabaseEndpointsChangedEvent
)
self.on.define_event(
f"{relation_alias}_read_only_endpoints_changed",
DatabaseReadOnlyEndpointsChangedEvent,
)

def _assign_relation_alias(self, relation_id: int) -> None:
"""Assigns an alias to a relation.

This function writes in the application data bag, therefore,
only the leader unit can call it.

Args:
relation_id: the identifier for a particular relation.
"""
# If this unit isn't the leader or no aliases were provided, return immediately.
if not self.local_unit.is_leader() or not self.relations_aliases:
return

# Return if an alias was already assigned to this relation
# (like when there are more than one unit joining the relation).
if (
self.charm.model.get_relation(self.relation_name, relation_id)
.data[self.local_app]
.get("alias")
):
return

# Retrieve the available aliases (the ones that weren't assigned to any relation).
available_aliases = self.relations_aliases[:]
for relation in self.charm.model.relations[self.relation_name]:
alias = relation.data[self.local_app].get("alias")
if alias:
logger.debug(f"Alias {alias} was already assigned to relation {relation.id}")
available_aliases.remove(alias)

# Set the alias in the application relation databag of the specific relation.
self._update_relation_data(relation_id, {"alias": available_aliases[0]})

def _diff(self, event: RelationChangedEvent) -> Diff:
"""Retrieves the diff of the data in the relation changed databag.

Expand Down Expand Up @@ -233,6 +289,31 @@ def _diff(self, event: RelationChangedEvent) -> Diff:
# Return the diff with all possible changes.
return Diff(added, changed, deleted)

def _emit_aliased_event(self, relation: Relation, event_name: str) -> None:
"""Emit an aliased event to a particular relation if it has an alias.

Args:
relation: a particular relation.
event_name: the name of the event to emit.
"""
alias = self._get_relation_alias(relation.id)
if alias:
getattr(self.on, f"{alias}_{event_name}").emit(relation)

def _get_relation_alias(self, relation_id: int) -> Optional[str]:
"""Returns the relation alias.

Args:
relation_id: the identifier for a particular relation.

Returns:
the relation alias or None if the relation wasn't found.
"""
for relation in self.charm.model.relations[self.relation_name]:
if relation.id == relation_id:
return relation.data[self.local_app].get("alias")
return None

def fetch_relation_data(self) -> dict:
"""Retrieves data from relation.

Expand Down Expand Up @@ -267,6 +348,9 @@ def _update_relation_data(self, relation_id: int, data: dict) -> None:

def _on_relation_joined_event(self, event: RelationJoinedEvent) -> None:
"""Event emitted when the application joins the database relation."""
# If relations aliases were provided, assign one to the relation.
self._assign_relation_alias(event.relation.id)

# Sets both database and extra user roles in the relation
# if the roles are provided. Otherwise, sets only the database.
if self.extra_user_roles:
Expand All @@ -292,20 +376,32 @@ def _on_relation_changed_event(self, event: RelationChangedEvent) -> None:
# Check if the database is created
# (the database charm shared the credentials).
if "username" in diff.added and "password" in diff.added:
# Emit the default event (the one without an alias).
self.on.database_created.emit(event.relation)

# Emit the aliased event (if any).
self._emit_aliased_event(event.relation, "database_created")

# Emit an endpoints changed event if the database
# added or changed this info in the relation databag.
if "endpoints" in diff.added or "endpoints" in diff.changed:
# Emit the default event (the one without an alias).
logger.info(f"endpoints changed on {datetime.now()}")
self.on.endpoints_changed.emit(event.relation)

# Emit the aliased event (if any).
self._emit_aliased_event(event.relation, "endpoints_changed")

# Emit a read only endpoints changed event if the database
# added or changed this info in the relation databag.
if "read-only-endpoints" in diff.added or "read-only-endpoints" in diff.changed:
# Emit the default event (the one without an alias).
logger.info(f"read-only-endpoints changed on {datetime.now()}")
self.on.read_only_endpoints_changed.emit(event.relation)

# Emit the aliased event (if any).
self._emit_aliased_event(event.relation, "read_only_endpoints_changed")

@property
def relations(self) -> List[Relation]:
"""The list of Relation instances associated with this relation_name."""
Expand Down
5 changes: 5 additions & 0 deletions tests/integration/application-charm/metadata.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,8 @@ requires:
interface: database-client
second-database:
interface: database-client
multiple-database-clusters:
interface: database-client
aliased-multiple-database-clusters:
interface: database-client
limit: 2
81 changes: 81 additions & 0 deletions tests/integration/application-charm/src/charm.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,49 @@ def __init__(self, *args):
self.second_database.on.endpoints_changed, self._on_second_database_endpoints_changed
)

# Multiple database clusters charm events (clusters/relations without alias).
database_name = f'{self.app.name.replace("-", "_")}_multiple_database_clusters'
self.database_clusters = DatabaseRequires(
self, "multiple-database-clusters", database_name, EXTRA_USER_ROLES
)
self.framework.observe(
self.database_clusters.on.database_created, self._on_cluster_database_created
)
self.framework.observe(
self.database_clusters.on.endpoints_changed,
self._on_cluster_endpoints_changed,
)

# Multiple database clusters charm events (defined dynamically
# in the database requires charm library, using the provided cluster/relation aliases).
database_name = f'{self.app.name.replace("-", "_")}_aliased_multiple_database_clusters'
cluster_aliases = ["cluster1", "cluster2"] # Aliases for the multiple clusters/relations.
self.aliased_database_clusters = DatabaseRequires(
self,
"aliased-multiple-database-clusters",
database_name,
EXTRA_USER_ROLES,
cluster_aliases,
)
# Each database cluster will have its own events
# with the name having the cluster/relation alias as the prefix.
self.framework.observe(
self.aliased_database_clusters.on.cluster1_database_created,
self._on_cluster1_database_created,
)
self.framework.observe(
self.aliased_database_clusters.on.cluster1_endpoints_changed,
self._on_cluster1_endpoints_changed,
)
self.framework.observe(
self.aliased_database_clusters.on.cluster2_database_created,
self._on_cluster2_database_created,
)
self.framework.observe(
self.aliased_database_clusters.on.cluster2_endpoints_changed,
self._on_cluster2_endpoints_changed,
)

def _on_start(self, _) -> None:
"""Only sets an Active status."""
self.unit.status = ActiveStatus()
Expand All @@ -86,6 +129,44 @@ def _on_second_database_endpoints_changed(self, event: DatabaseEndpointsChangedE
"""Event triggered when the read/write endpoints of the database change."""
logger.info(f"second database endpoints have been changed to: {event.endpoints}")

# Multiple database clusters events observers.
def _on_cluster_database_created(self, event: DatabaseCreatedEvent) -> None:
"""Event triggered when a database was created for this application."""
# Retrieve the credentials using the charm library.
logger.info(
f"cluster {event.relation.app.name} credentials: {event.username} {event.password}"
)
self.unit.status = ActiveStatus(
f"received database credentials for cluster {event.relation.app.name}"
)

def _on_cluster_endpoints_changed(self, event: DatabaseEndpointsChangedEvent) -> None:
"""Event triggered when the read/write endpoints of the database change."""
logger.info(
f"cluster {event.relation.app.name} endpoints have been changed to: {event.endpoints}"
)

# Multiple database clusters events observers (for aliased clusters/relations).
def _on_cluster1_database_created(self, event: DatabaseCreatedEvent) -> None:
"""Event triggered when a database was created for this application."""
# Retrieve the credentials using the charm library.
logger.info(f"cluster1 credentials: {event.username} {event.password}")
self.unit.status = ActiveStatus("received database credentials for cluster1")

def _on_cluster1_endpoints_changed(self, event: DatabaseEndpointsChangedEvent) -> None:
"""Event triggered when the read/write endpoints of the database change."""
logger.info(f"cluster1 endpoints have been changed to: {event.endpoints}")

def _on_cluster2_database_created(self, event: DatabaseCreatedEvent) -> None:
"""Event triggered when a database was created for this application."""
# Retrieve the credentials using the charm library.
logger.info(f"cluster2 credentials: {event.username} {event.password}")
self.unit.status = ActiveStatus("received database credentials for cluster2")

def _on_cluster2_endpoints_changed(self, event: DatabaseEndpointsChangedEvent) -> None:
"""Event triggered when the read/write endpoints of the database change."""
logger.info(f"cluster2 endpoints have been changed to: {event.endpoints}")


if __name__ == "__main__":
main(ApplicationCharm)
37 changes: 31 additions & 6 deletions tests/integration/helpers.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,44 @@
#!/usr/bin/env python3
# Copyright 2022 Canonical Ltd.
# See LICENSE file for licensing details.
import json
from typing import Optional

import yaml
from pytest_operator.plugin import OpsTest


async def build_connection_string(
ops_test: OpsTest, application_name: str, relation_name: str
ops_test: OpsTest,
application_name: str,
relation_name: str,
*,
relation_id: str = None,
relation_alias: str = None,
) -> str:
"""Build a PostgreSQL connection string.

Args:
ops_test: The ops test framework instance
application_name: The name of the application
relation_name: name of the relation to get connection data from
relation_id: id of the relation to get connection data from
relation_alias: alias of the relation (like a connection name)
to get connection data from

Returns:
a PostgreSQL connection string
"""
# Get the connection data exposed to the application through the relation.
database = f'{application_name.replace("-", "_")}_{relation_name.replace("-", "_")}'
username = await get_application_relation_data(
ops_test, application_name, relation_name, "username"
ops_test, application_name, relation_name, "username", relation_id, relation_alias
)
password = await get_application_relation_data(
ops_test, application_name, relation_name, "password"
ops_test, application_name, relation_name, "password", relation_id, relation_alias
)
endpoints = await get_application_relation_data(
ops_test, application_name, relation_name, "endpoints"
ops_test, application_name, relation_name, "endpoints", relation_id, relation_alias
)
host = endpoints.split(",")[0].split(":")[0]

Expand All @@ -42,6 +51,8 @@ async def get_application_relation_data(
application_name: str,
relation_name: str,
key: str,
relation_id: str = None,
relation_alias: str = None,
) -> Optional[str]:
"""Get relation data for an application.

Expand All @@ -50,14 +61,18 @@ async def get_application_relation_data(
application_name: The name of the application
relation_name: name of the relation to get connection data from
key: key of data to be retrieved
relation_id: id of the relation to get connection data from
relation_alias: alias of the relation (like a connection name)
to get connection data from

Returns:
the that that was requested or None
if no data in the relation

Raises:
ValueError if it's not possible to get application unit data
or if there is no data for the particular relation endpoint.
or if there is no data for the particular relation endpoint
and/or alias.
"""
unit_name = f"{application_name}/0"
raw_data = (await ops_test.juju("show-unit", unit_name))[1]
Expand All @@ -66,8 +81,18 @@ async def get_application_relation_data(
data = yaml.safe_load(raw_data)
# Filter the data based on the relation name.
relation_data = [v for v in data[unit_name]["relation-info"] if v["endpoint"] == relation_name]
if relation_id:
# Filter the data based on the relation id.
relation_data = [v for v in relation_data if v["relation-id"] == relation_id]
if relation_alias:
# Filter the data based on the cluster/relation alias.
relation_data = [
v
for v in relation_data
if json.loads(v["application-data"]["data"])["alias"] == relation_alias
]
if len(relation_data) == 0:
raise ValueError(
f"no relation data could be grabbed on relation with endpoint {relation_name}"
f"no relation data could be grabbed on relation with endpoint {relation_name} and alias {relation_alias}"
)
return relation_data[0]["application-data"].get(key)
Loading