From 23eb940d6bdf44390cacc606fd1d4d985c2ea71e Mon Sep 17 00:00:00 2001 From: Robin Krahl Date: Thu, 16 Nov 2023 21:23:17 +0100 Subject: [PATCH] Use enums instead of literals Using enums makes it easier to use the functions in a typesafe way and to validate user input. Fixes: https://github.com/Nitrokey/nethsm-sdk-py/issues/69 --- nethsm/__init__.py | 165 +++++++++++++++++++----------------- tests/conftest.py | 41 +++++---- tests/test_nethsm_config.py | 22 ++--- tests/test_nethsm_keys.py | 20 ++--- tests/test_nethsm_users.py | 4 +- 5 files changed, 135 insertions(+), 117 deletions(-) diff --git a/nethsm/__init__.py b/nethsm/__init__.py index 898655a..4795e5d 100644 --- a/nethsm/__init__.py +++ b/nethsm/__init__.py @@ -18,7 +18,7 @@ from dataclasses import dataclass from datetime import datetime from io import BufferedReader, FileIO -from typing import TYPE_CHECKING, Any, Iterator, Literal, Mapping, Optional, Union, cast +from typing import TYPE_CHECKING, Any, Iterator, Mapping, Optional, Union, cast from urllib.parse import urlencode import urllib3 @@ -79,6 +79,13 @@ class UnattendedBootStatus(enum.Enum): ON = "on" OFF = "off" + @staticmethod + def from_string(s: str) -> "UnattendedBootStatus": + for status in UnattendedBootStatus: + if status.value == s: + return status + raise ValueError(f"Unsupported unattended boot status {s}") + class KeyType(enum.Enum): RSA = "RSA" @@ -97,11 +104,6 @@ def from_string(s: str) -> "KeyType": raise ValueError(f"Unsupported key type {s}") -KeyTypeLitteral = Literal[ - "RSA", "Curve25519", "EC_P224", "EC_P256", "EC_P384", "EC_P521", "Generic" -] - - class KeyMechanism(enum.Enum): RSA_DECRYPTION_RAW = "RSA_Decryption_RAW" RSA_DECRYPTION_PKCS1 = "RSA_Decryption_PKCS1" @@ -123,33 +125,24 @@ class KeyMechanism(enum.Enum): AES_ENCRYPTION_CBC = "AES_Encryption_CBC" AES_DECRYPTION_CBC = "AES_Decryption_CBC" - -KeyMechanismLiteral = Literal[ - "RSA_Decryption_RAW", - "RSA_Decryption_PKCS1", - "RSA_Decryption_OAEP_MD5", - "RSA_Decryption_OAEP_SHA1", - "RSA_Decryption_OAEP_SHA224", - "RSA_Decryption_OAEP_SHA256", - "RSA_Decryption_OAEP_SHA384", - "RSA_Decryption_OAEP_SHA512", - "RSA_Signature_PKCS1", - "RSA_Signature_PSS_MD5", - "RSA_Signature_PSS_SHA1", - "RSA_Signature_PSS_SHA224", - "RSA_Signature_PSS_SHA256", - "RSA_Signature_PSS_SHA384", - "RSA_Signature_PSS_SHA512", - "EdDSA_Signature", - "ECDSA_Signature", - "AES_Encryption_CBC", - "AES_Decryption_CBC", -] + @staticmethod + def from_string(s: str) -> "KeyMechanism": + for key_mechanism in KeyMechanism: + if key_mechanism.value == s: + return key_mechanism + raise ValueError(f"Unsupported key mechanism {s}") class EncryptMode(enum.Enum): AES_CBC = "AES_CBC" + @staticmethod + def from_string(s: str) -> "EncryptMode": + for mode in EncryptMode: + if mode.value == s: + return mode + raise ValueError(f"Unsupported encrypt mode {s}") + class DecryptMode(enum.Enum): RAW = "RAW" @@ -162,6 +155,13 @@ class DecryptMode(enum.Enum): OAEP_SHA512 = "OAEP_SHA512" AES_CBC = "AES_CBC" + @staticmethod + def from_string(s: str) -> "DecryptMode": + for mode in DecryptMode: + if mode.value == s: + return mode + raise ValueError(f"Unsupported decrypt mode {s}") + class SignMode(enum.Enum): PKCS1 = "PKCS1" @@ -174,6 +174,13 @@ class SignMode(enum.Enum): EDDSA = "EdDSA" ECDSA = "ECDSA" + @staticmethod + def from_string(s: str) -> "SignMode": + for mode in SignMode: + if mode.value == s: + return mode + raise ValueError(f"Unsupported sign mode {s}") + class TlsKeyType(enum.Enum): RSA = "RSA" @@ -183,6 +190,13 @@ class TlsKeyType(enum.Enum): EC_P384 = "EC_P384" EC_P521 = "EC_P521" + @staticmethod + def from_string(s: str) -> "TlsKeyType": + for key_type in TlsKeyType: + if key_type.value == s: + return key_type + raise ValueError(f"Unsupported TLS key type {s}") + @dataclass class SystemInfo: @@ -202,7 +216,7 @@ class User: @dataclass class Key: key_id: str - mechanisms: list[str] + mechanisms: list[KeyMechanism] type: KeyType operations: int tags: Optional[list[str]] @@ -527,7 +541,7 @@ def get_user(self, user_id: str) -> User: def add_user( self, real_name: str, - role: Literal["Administrator", "Operator", "Metrics", "Backup"], + role: Role, passphrase: str, user_id: Optional[str] = None, ) -> str: @@ -536,7 +550,7 @@ def add_user( body = UserPostDataDict( realName=real_name, - role=role, + role=role.value, passphrase=passphrase, ) try: @@ -775,7 +789,9 @@ def get_key(self, key_id: str) -> Key: ) return Key( key_id=key_id, - mechanisms=[mechanism for mechanism in key.mechanisms], + mechanisms=[ + KeyMechanism.from_string(mechanism) for mechanism in key.mechanisms + ], type=KeyType.from_string(key.type), operations=key.operations, tags=[str(tag) for tag in cast(list[str], key.restrictions["tags"])] @@ -811,14 +827,15 @@ def get_key_public_key(self, key_id: str) -> str: def add_key( self, key_id: str, - type: KeyTypeLitteral, - mechanisms: list[KeyMechanismLiteral], + type: KeyType, + mechanisms: list[KeyMechanism], tags: list[str], prime_p: Optional[str], prime_q: Optional[str], public_exponent: Optional[str], data: Optional[str], ) -> str: + from .client.components.schema.key_mechanisms import KeyMechanismsTupleInput from .client.components.schema.key_private_data import KeyPrivateDataDict from .client.components.schema.key_restrictions import KeyRestrictionsDict from .client.components.schema.private_key import PrivateKeyDict @@ -827,7 +844,7 @@ def add_key( # To do: split into different methods for RSA and other key types, or # at least change typing accordingly - if type == "RSA": + if type == KeyType.RSA: assert prime_p assert prime_q assert public_exponent @@ -840,10 +857,14 @@ def add_key( assert data key_data = KeyPrivateDataDict(data=data) + mechanism_tuple: KeyMechanismsTupleInput = [ + mechanism.value for mechanism in mechanisms + ] + if tags: body = PrivateKeyDict( - type=type, - mechanisms=mechanisms, + type=type.value, + mechanisms=mechanism_tuple, key=key_data, restrictions=KeyRestrictionsDict( tags=TagListTuple([tag for tag in tags]) @@ -851,8 +872,8 @@ def add_key( ) else: body = PrivateKeyDict( - type=type, - mechanisms=mechanisms, + type=type.value, + mechanisms=mechanism_tuple, key=key_data, ) @@ -905,27 +926,31 @@ def delete_key(self, key_id: str) -> None: def generate_key( self, - type: KeyTypeLitteral, - mechanisms: tuple[KeyMechanismLiteral], + type: KeyType, + mechanisms: list[KeyMechanism], length: int, key_id: Optional[str] = None, ) -> str: from .client.components.schema.key_generate_request_data import ( KeyGenerateRequestDataDict, ) - from .client.components.schema.key_mechanisms import KeyMechanismsTuple + from .client.components.schema.key_mechanisms import KeyMechanismsTupleInput + + mechanism_tuple: KeyMechanismsTupleInput = [ + mechanism.value for mechanism in mechanisms + ] if key_id: body = KeyGenerateRequestDataDict( - type=type, - mechanisms=KeyMechanismsTuple(mechanisms), + type=type.value, + mechanisms=mechanism_tuple, length=length, id=key_id, ) else: body = KeyGenerateRequestDataDict( - type=type, - mechanisms=KeyMechanismsTuple(mechanisms), + type=type.value, + mechanisms=mechanism_tuple, length=length, ) try: @@ -1107,7 +1132,7 @@ def csr( def generate_tls_key( self, - type: Literal["RSA", "Curve25519", "EC_P224", "EC_P256", "EC_P384", "EC_P521"], + type: TlsKeyType, length: Optional[int] = None, ) -> None: from .client.components.schema.tls_key_generate_request_data import ( @@ -1116,7 +1141,7 @@ def generate_tls_key( from .client.schemas import Unset body = TlsKeyGenerateRequestDataDict( - type=type, + type=type.value, length=length if length is not None else Unset(), ) @@ -1213,11 +1238,13 @@ def set_logging_config( self, ip_address: str, port: int, - log_level: Literal["debug", "info", "warning", "error"], + log_level: LogLevel, ) -> None: from .client.components.schema.logging_config import LoggingConfigDict - body = LoggingConfigDict(ipAddress=ip_address, port=port, logLevel=log_level) + body = LoggingConfigDict( + ipAddress=ip_address, port=port, logLevel=log_level.value + ) try: self.get_api().config_logging_put(body=body) except Exception as e: @@ -1262,12 +1289,12 @@ def set_time(self, time: Union[str, datetime]) -> None: }, ) - def set_unattended_boot(self, status: Literal["on", "off"]) -> None: + def set_unattended_boot(self, status: UnattendedBootStatus) -> None: from .client.components.schema.unattended_boot_config import ( UnattendedBootConfigDict, ) - body = UnattendedBootConfigDict(status=status) + body = UnattendedBootConfigDict(status=status.value) try: self.get_api().config_unattended_boot_put(body=body) except Exception as e: @@ -1408,7 +1435,7 @@ def factory_reset(self) -> None: ) def encrypt( - self, key_id: str, data: str, mode: Literal["AES_CBC"], iv: str + self, key_id: str, data: str, mode: EncryptMode, iv: str ) -> tuple[str, str]: from .client.components.schema.encrypt_request_data import ( EncryptRequestDataDict, @@ -1418,7 +1445,7 @@ def encrypt( ) path_params = PathParametersDict(KeyID=key_id) - body = EncryptRequestDataDict(message=data, mode=mode, iv=iv) + body = EncryptRequestDataDict(message=data, mode=mode.value, iv=iv) try: response = self.get_api().keys_key_id_encrypt_post( path_params=path_params, body=body @@ -1439,17 +1466,7 @@ def decrypt( self, key_id: str, data: str, - mode: Literal[ - "RAW", - "PKCS1", - "OAEP_MD5", - "OAEP_SHA1", - "OAEP_SHA224", - "OAEP_SHA256", - "OAEP_SHA384", - "OAEP_SHA512", - "AES_CBC", - ], + mode: DecryptMode, iv: str, ) -> str: from .client.components.schema.decrypt_request_data import ( @@ -1459,10 +1476,10 @@ def decrypt( PathParametersDict, ) - body = DecryptRequestDataDict(encrypted=data, mode=mode, iv=iv) + body = DecryptRequestDataDict(encrypted=data, mode=mode.value, iv=iv) if len(iv) == 0: - body = DecryptRequestDataDict(encrypted=data, mode=mode) + body = DecryptRequestDataDict(encrypted=data, mode=mode.value) path_params = PathParametersDict(KeyID=key_id) try: @@ -1485,17 +1502,7 @@ def sign( self, key_id: str, data: str, - mode: Literal[ - "PKCS1", - "PSS_MD5", - "PSS_SHA1", - "PSS_SHA224", - "PSS_SHA256", - "PSS_SHA384", - "PSS_SHA512", - "EdDSA", - "ECDSA", - ], + mode: SignMode, ) -> str: from .client.components.schema.sign_request_data import SignRequestDataDict from .client.paths.keys_key_id_sign.post.path_parameters import ( @@ -1503,7 +1510,7 @@ def sign( ) path_params = PathParametersDict(KeyID=key_id) - body = SignRequestDataDict(message=data, mode=mode) + body = SignRequestDataDict(message=data, mode=mode.value) try: response = self.get_api().keys_key_id_sign_post( path_params=path_params, body=body diff --git a/tests/conftest.py b/tests/conftest.py index 172aa6d..c60a6e6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,12 +2,21 @@ from os import environ from typing import Literal +from nethsm import ( + DecryptMode, + KeyMechanism, + KeyType, + LogLevel, + Role, + UnattendedBootStatus, +) + @dataclass class UserData: user_id: str real_name: str - role: Literal["Administrator", "Operator", "Metrics", "Backup"] + role: Role class Constants: @@ -69,24 +78,24 @@ class Constants: PORT = 514 NETMASK = "255.255.255.0" GATEWAY = "0.0.0.0" - UNATTENDED_BOOT_OFF: Literal["off"] = "off" - UNATTENDED_BOOT_ON: Literal["on"] = "on" - LOG_LEVEL: Literal["info"] = "info" + UNATTENDED_BOOT_OFF = UnattendedBootStatus.OFF + UNATTENDED_BOOT_ON = UnattendedBootStatus.ON + LOG_LEVEL = LogLevel.INFO # test_nethsm_keys - TYPE: Literal["RSA"] = "RSA" + TYPE = KeyType.RSA MECHANISM = [ - "RSA_Signature_PKCS1", - "RSA_Decryption_PKCS1", - "RSA_Signature_PSS_SHA256", - "RSA_Decryption_OAEP_SHA256", + KeyMechanism.RSA_SIGNATURE_PKCS1, + KeyMechanism.RSA_DECRYPTION_PKCS1, + KeyMechanism.RSA_SIGNATURE_PSS_SHA256, + KeyMechanism.RSA_DECRYPTION_OAEP_SHA256, ] LENGTH = 1024 KEY_ID_ADDED = "KeyIdAdded" KEY_ID_GENERATED = "KeyIdGenerated" KEY_ID_AES = "KeyIdAES" DATA = "Test data 123456" - MODE: Literal["PKCS1"] = "PKCS1" + MODE = DecryptMode.PKCS1 # 'PKCS1', 'PSS_MD5', 'PSS_SHA1', 'PSS_SHA224', 'PSS_SHA256', 'PSS_SHA384', 'PSS_SHA512', 'EdDSA', 'ECDSA' # test_nethsm_users, test_nethsm_keys TAG1 = "Frankfurt" @@ -94,15 +103,17 @@ class Constants: TAG3 = "Teltow" TAGS = [TAG1, TAG2, TAG3] - ADMIN_USER = UserData(user_id="admin", real_name="admin", role="Administrator") + ADMIN_USER = UserData(user_id="admin", real_name="admin", role=Role.ADMINISTRATOR) ADMINISTRATOR_USER = UserData( - user_id="UIAdministrator", real_name="RNAdministrator", role="Administrator" + user_id="UIAdministrator", real_name="RNAdministrator", role=Role.ADMINISTRATOR ) OPERATOR_USER = UserData( - user_id="UIOperator", real_name="RNOperator", role="Operator" + user_id="UIOperator", real_name="RNOperator", role=Role.OPERATOR + ) + METRICS_USER = UserData( + user_id="UIMetrics", real_name="RNMetrics", role=Role.METRICS ) - METRICS_USER = UserData(user_id="UIMetrics", real_name="RNMetrics", role="Metrics") - BACKUP_USER = UserData(user_id="UIBackup", real_name="RNBackup", role="Backup") + BACKUP_USER = UserData(user_id="UIBackup", real_name="RNBackup", role=Role.BACKUP) DETAILS = "" USERS_LIST = [ diff --git a/tests/test_nethsm_config.py b/tests/test_nethsm_config.py index ca0eb3d..4f23dc8 100644 --- a/tests/test_nethsm_config.py +++ b/tests/test_nethsm_config.py @@ -6,7 +6,7 @@ from utilities import lock, nethsm, self_sign_csr, unlock # noqa: F401 import nethsm as nethsm_module -from nethsm import NetHSM +from nethsm import NetHSM, TlsKeyType """########## Preparation for the Tests ########## @@ -21,7 +21,7 @@ def get_config_logging(nethsm: NetHSM) -> None: data = nethsm.get_config_logging() assert data.ip_address == C.IP_ADDRESS_LOGGING assert data.port == C.PORT - assert data.log_level.value == C.LOG_LEVEL + assert data.log_level == C.LOG_LEVEL def get_config_network(nethsm: NetHSM) -> None: @@ -81,7 +81,7 @@ def test_set_certificate(nethsm: NetHSM) -> None: def generate_tls_key(nethsm: NetHSM) -> None: - nethsm.generate_tls_key("RSA", 2048) + nethsm.generate_tls_key(TlsKeyType.RSA, 2048) def test_get_config_logging(nethsm: NetHSM) -> None: @@ -124,8 +124,8 @@ def test_get_config_unattended_boot(nethsm: NetHSM) -> None: role.""" unattended_boot = nethsm.get_config_unattended_boot() assert ( - str(unattended_boot) == C.UNATTENDED_BOOT_OFF - or str(unattended_boot) == C.UNATTENDED_BOOT_ON + unattended_boot == C.UNATTENDED_BOOT_OFF.value + or unattended_boot == C.UNATTENDED_BOOT_ON.value ) @@ -219,19 +219,19 @@ def test_set_get_unattended_boot(nethsm: NetHSM) -> None: This command requires authentication as a user with the Administrator role.""" unattended_boot = nethsm.get_config_unattended_boot() - if str(unattended_boot) == C.UNATTENDED_BOOT_OFF: + if unattended_boot == C.UNATTENDED_BOOT_OFF.value: nethsm.set_unattended_boot(C.UNATTENDED_BOOT_ON) - assert str(nethsm.get_config_unattended_boot()) == C.UNATTENDED_BOOT_ON + assert nethsm.get_config_unattended_boot() == C.UNATTENDED_BOOT_ON.value nethsm.set_unattended_boot(C.UNATTENDED_BOOT_OFF) - assert str(nethsm.get_config_unattended_boot()) == C.UNATTENDED_BOOT_OFF + assert nethsm.get_config_unattended_boot() == C.UNATTENDED_BOOT_OFF.value - if str(unattended_boot) == C.UNATTENDED_BOOT_ON: + if unattended_boot == C.UNATTENDED_BOOT_ON.value: nethsm.set_unattended_boot(C.UNATTENDED_BOOT_OFF) - assert str(nethsm.get_config_unattended_boot()) == C.UNATTENDED_BOOT_OFF + assert nethsm.get_config_unattended_boot() == C.UNATTENDED_BOOT_OFF.value nethsm.set_unattended_boot(C.UNATTENDED_BOOT_ON) - assert str(nethsm.get_config_unattended_boot()) == C.UNATTENDED_BOOT_ON + assert nethsm.get_config_unattended_boot() == C.UNATTENDED_BOOT_ON.value def test_set_unlock_passphrase_lock_unlock(nethsm: NetHSM) -> None: diff --git a/tests/test_nethsm_keys.py b/tests/test_nethsm_keys.py index 9dee281..974408b 100644 --- a/tests/test_nethsm_keys.py +++ b/tests/test_nethsm_keys.py @@ -15,7 +15,7 @@ ) import nethsm as nethsm_module -from nethsm import NetHSM +from nethsm import DecryptMode, EncryptMode, KeyMechanism, KeyType, NetHSM, SignMode """########## Preparation for the Tests ########## @@ -42,7 +42,7 @@ def add_key(nethsm: NetHSM) -> None: nethsm.add_key( key_id=C.KEY_ID_ADDED, type=C.TYPE, - mechanisms=C.MECHANISM, # type: ignore + mechanisms=C.MECHANISM, prime_p=p, prime_q=q, public_exponent=e, @@ -64,8 +64,8 @@ def generate_key_aes(nethsm: NetHSM) -> None: nethsm.generate_key( key_id=C.KEY_ID_AES, - type="Generic", - mechanisms=["AES_Encryption_CBC", "AES_Decryption_CBC"], # type: ignore + type=KeyType.GENERIC, + mechanisms=[KeyMechanism.AES_ENCRYPTION_CBC, KeyMechanism.AES_DECRYPTION_CBC], length=256, ) @@ -76,7 +76,7 @@ def generate_key(nethsm: NetHSM) -> None: This command requires authentication as a user with the Administrator or Operator role.""" try: - nethsm.generate_key(C.TYPE, C.MECHANISM, C.LENGTH, C.KEY_ID_GENERATED) # type: ignore + nethsm.generate_key(C.TYPE, C.MECHANISM, C.LENGTH, C.KEY_ID_GENERATED) except nethsm_module.NetHSMError: pass @@ -137,7 +137,7 @@ def test_generate_get_key_by_id(nethsm: nethsm_module.NetHSM) -> None: key = nethsm.get_key(C.KEY_ID_GENERATED) # mechanisms = ", ".join(key.mechanisms) Todo: test with multiple mech. - assert key.type.value == C.TYPE + assert key.type == C.TYPE for mechanism in key.mechanisms: assert mechanism in C.MECHANISM assert key.operations >= 0 @@ -194,7 +194,7 @@ def test_list_get_keys(nethsm: nethsm_module.NetHSM) -> None: key_ids = nethsm.list_keys(None) for key_id in key_ids: key = nethsm.get_key(key_id=key_id) - assert key.type.value == C.TYPE + assert key.type == C.TYPE for mechanism in key.mechanisms: assert mechanism in C.MECHANISM assert key.operations >= 0 @@ -276,7 +276,7 @@ def test_sign(nethsm: NetHSM) -> None: # mit dem privaten schlüssel signieren signature = nethsm.sign( C.KEY_ID_GENERATED, base64.b64encode(hash_object.digest()).decode(), - "PSS_SHA256", + SignMode.PSS_SHA256, ) print(signature) verify_rsa_signature(key, hash_object, base64.b64decode(signature)) @@ -317,13 +317,13 @@ def test_encrypt_decrypt(nethsm: NetHSM) -> None: encrypted = nethsm.encrypt( C.KEY_ID_AES, data_b64, - "AES_CBC", + EncryptMode.AES_CBC, iv_b64, ) decrypt = nethsm.decrypt( C.KEY_ID_AES, encrypted[0], - "AES_CBC", + DecryptMode.AES_CBC, iv_b64, ) assert decrypt == data_b64 diff --git a/tests/test_nethsm_users.py b/tests/test_nethsm_users.py index f5a998a..ba9dd74 100644 --- a/tests/test_nethsm_users.py +++ b/tests/test_nethsm_users.py @@ -76,7 +76,7 @@ def test_list_get_delete_add_users(nethsm: NetHSM) -> None: for i in range(len(remaining)): if user.user_id == remaining[i].user_id: assert user.real_name == remaining[i].real_name - assert user.role.value == remaining[i].role + assert user.role.value == remaining[i].role.value remaining.pop(i) break @@ -91,7 +91,7 @@ def test_get_user_admin(nethsm: NetHSM) -> None: user = nethsm.get_user(user_id=C.ADMIN_USER.user_id) assert user.user_id == C.ADMIN_USER.user_id assert user.real_name == C.ADMIN_USER.real_name - assert user.role.value == C.ADMIN_USER.role + assert user.role.value == C.ADMIN_USER.role.value # @pytest.mark.xfail(reason="connect() doesn't require correct passphrase yet")