From 806e4635321343c1791474534a7333c95b6540a8 Mon Sep 17 00:00:00 2001 From: Noctua Date: Mon, 15 Jan 2024 11:59:59 +0100 Subject: [PATCH] chore: update charm libraries (#270) Co-authored-by: Github Actions --- .../v2/tls_certificates.py | 179 +++++++++++------- 1 file changed, 114 insertions(+), 65 deletions(-) diff --git a/lib/charms/tls_certificates_interface/v2/tls_certificates.py b/lib/charms/tls_certificates_interface/v2/tls_certificates.py index 99741f5..b8855be 100644 --- a/lib/charms/tls_certificates_interface/v2/tls_certificates.py +++ b/lib/charms/tls_certificates_interface/v2/tls_certificates.py @@ -308,7 +308,7 @@ def _on_all_certificates_invalidated(self, event: AllCertificatesInvalidatedEven # Increment this PATCH version before using `charmcraft publish-lib` or reset # to 0 if you are raising the major API version -LIBPATCH = 20 +LIBPATCH = 21 PYDEPS = ["cryptography", "jsonschema"] @@ -693,6 +693,105 @@ def generate_ca( return cert.public_bytes(serialization.Encoding.PEM) +def get_certificate_extensions( + authority_key_identifier: bytes, + csr: x509.CertificateSigningRequest, + alt_names: Optional[List[str]], + is_ca: bool, +) -> List[x509.Extension]: + """Generates a list of certificate extensions from a CSR and other known information. + + Args: + authority_key_identifier (bytes): Authority key identifier + csr (x509.CertificateSigningRequest): CSR + alt_names (list): List of alt names to put on cert - prefer putting SANs in CSR + is_ca (bool): Whether the certificate is a CA certificate + + Returns: + List[x509.Extension]: List of extensions + """ + cert_extensions_list: List[x509.Extension] = [ + x509.Extension( + oid=ExtensionOID.AUTHORITY_KEY_IDENTIFIER, + value=x509.AuthorityKeyIdentifier( + key_identifier=authority_key_identifier, + authority_cert_issuer=None, + authority_cert_serial_number=None, + ), + critical=False, + ), + x509.Extension( + oid=ExtensionOID.SUBJECT_KEY_IDENTIFIER, + value=x509.SubjectKeyIdentifier.from_public_key(csr.public_key()), + critical=False, + ), + x509.Extension( + oid=ExtensionOID.BASIC_CONSTRAINTS, + critical=True, + value=x509.BasicConstraints(ca=is_ca, path_length=None), + ), + ] + + sans: List[x509.GeneralName] = [] + san_alt_names = [x509.DNSName(name) for name in alt_names] if alt_names else [] + sans.extend(san_alt_names) + try: + loaded_san_ext = csr.extensions.get_extension_for_class(x509.SubjectAlternativeName) + sans.extend( + [x509.DNSName(name) for name in loaded_san_ext.value.get_values_for_type(x509.DNSName)] + ) + sans.extend( + [x509.IPAddress(ip) for ip in loaded_san_ext.value.get_values_for_type(x509.IPAddress)] + ) + sans.extend( + [ + x509.RegisteredID(oid) + for oid in loaded_san_ext.value.get_values_for_type(x509.RegisteredID) + ] + ) + except x509.ExtensionNotFound: + pass + + if sans: + cert_extensions_list.append( + x509.Extension( + oid=ExtensionOID.SUBJECT_ALTERNATIVE_NAME, + critical=False, + value=x509.SubjectAlternativeName(sans), + ) + ) + + if is_ca: + cert_extensions_list.append( + x509.Extension( + ExtensionOID.KEY_USAGE, + critical=True, + value=x509.KeyUsage( + digital_signature=False, + content_commitment=False, + key_encipherment=False, + data_encipherment=False, + key_agreement=False, + key_cert_sign=True, + crl_sign=True, + encipher_only=False, + decipher_only=False, + ), + ) + ) + + existing_oids = {ext.oid for ext in cert_extensions_list} + for extension in csr.extensions: + if extension.oid == ExtensionOID.SUBJECT_ALTERNATIVE_NAME: + continue + if extension.oid in existing_oids: + logger.warning("Extension %s is managed by the TLS provider, ignoring.", extension.oid) + continue + cert_extensions_list.append(extension) + + return cert_extensions_list + + def generate_certificate( csr: bytes, ca: bytes, @@ -730,74 +829,24 @@ def generate_certificate( .serial_number(x509.random_serial_number()) .not_valid_before(datetime.utcnow()) .not_valid_after(datetime.utcnow() + timedelta(days=validity)) - .add_extension( - x509.AuthorityKeyIdentifier( - key_identifier=ca_pem.extensions.get_extension_for_class( - x509.SubjectKeyIdentifier - ).value.key_identifier, - authority_cert_issuer=None, - authority_cert_serial_number=None, - ), - critical=False, - ) - .add_extension( - x509.SubjectKeyIdentifier.from_public_key(csr_object.public_key()), critical=False - ) ) - - extensions_list = csr_object.extensions - san_ext: Optional[x509.Extension] = None - if alt_names: - full_sans_dns = alt_names.copy() + extensions = get_certificate_extensions( + authority_key_identifier=ca_pem.extensions.get_extension_for_class( + x509.SubjectKeyIdentifier + ).value.key_identifier, + csr=csr_object, + alt_names=alt_names, + is_ca=is_ca, + ) + for extension in extensions: try: - loaded_san_ext = csr_object.extensions.get_extension_for_class( - x509.SubjectAlternativeName + certificate_builder = certificate_builder.add_extension( + extval=extension.value, + critical=extension.critical, ) - full_sans_dns.extend(loaded_san_ext.value.get_values_for_type(x509.DNSName)) - except ExtensionNotFound: - pass - finally: - san_ext = Extension( - ExtensionOID.SUBJECT_ALTERNATIVE_NAME, - False, - x509.SubjectAlternativeName([x509.DNSName(name) for name in full_sans_dns]), - ) - if not extensions_list: - extensions_list = x509.Extensions([san_ext]) - - for extension in extensions_list: - if extension.value.oid == ExtensionOID.SUBJECT_ALTERNATIVE_NAME and san_ext: - extension = san_ext - - certificate_builder = certificate_builder.add_extension( - extension.value, - critical=extension.critical, - ) - - if is_ca: - certificate_builder = certificate_builder.add_extension( - x509.BasicConstraints(ca=True, path_length=None), critical=True - ) - certificate_builder = certificate_builder.add_extension( - x509.KeyUsage( - digital_signature=False, - content_commitment=False, - key_encipherment=False, - data_encipherment=False, - key_agreement=False, - key_cert_sign=True, - crl_sign=True, - encipher_only=False, - decipher_only=False, - ), - critical=True, - ) - else: - certificate_builder = certificate_builder.add_extension( - x509.BasicConstraints(ca=False, path_length=None), critical=False - ) + except ValueError as e: + logger.warning("Failed to add extension %s: %s", extension.oid, e) - certificate_builder._version = x509.Version.v3 cert = certificate_builder.sign(private_key, hashes.SHA256()) # type: ignore[arg-type] return cert.public_bytes(serialization.Encoding.PEM)