diff --git a/nethsm/__init__.py b/nethsm/__init__.py index e225417..556090c 100644 --- a/nethsm/__init__.py +++ b/nethsm/__init__.py @@ -10,10 +10,11 @@ __version__ = "0.5.0" +import binascii import contextlib import enum import json -from base64 import b64encode +from base64 import b64decode, b64encode from dataclasses import dataclass from datetime import datetime, timezone from io import BufferedReader, FileIO @@ -196,6 +197,28 @@ def from_string(s: str) -> "TlsKeyType": raise ValueError(f"Unsupported TLS key type {s}") +@dataclass +class Base64: + data: str + + def decode(self) -> bytes: + return b64decode(self.data) + + @classmethod + def from_encoded(cls, data: Union[bytes, str]) -> "Base64": + try: + b64decode(data, validate=True) + if isinstance(data, bytes): + data = data.decode() + return cls(data=data) + except binascii.Error: + raise ValueError(f"Invalid base64 data: {data!r}") + + @classmethod + def encode(cls, data: bytes) -> "Base64": + return cls(data=b64encode(data).decode()) + + @dataclass class Authentication: username: str @@ -225,13 +248,13 @@ class User: @dataclass class RsaPublicKey: - modulus: str - public_exponent: str + modulus: Base64 + public_exponent: Base64 @dataclass class EcPublicKey: - data: str + data: Base64 PublicKey = Union[RsaPublicKey, EcPublicKey, None] @@ -239,14 +262,14 @@ class EcPublicKey: @dataclass class RsaPrivateKey: - prime_p: str - prime_q: str - public_exponent: str + prime_p: Base64 + prime_q: Base64 + public_exponent: Base64 @dataclass class GenericPrivateKey: - data: str + data: Base64 PrivateKey = Union[RsaPrivateKey, GenericPrivateKey] @@ -264,8 +287,8 @@ class Key: @dataclass class EncryptionResult: - encrypted: str - iv: str + encrypted: Base64 + iv: Base64 @dataclass @@ -765,7 +788,7 @@ def get_state(self) -> State: _handle_exception(e) return State.from_string(response.body.state) - def get_random_data(self, n: int) -> str: + def get_random_data(self, n: int) -> Base64: from .client.components.schema.random_request_data import RandomRequestDataDict body = RandomRequestDataDict(length=n) @@ -780,7 +803,7 @@ def get_random_data(self, n: int) -> str: 400: "Invalid length. Must be between 1 and 1024", }, ) - return response.body.random + return Base64.from_encoded(response.body.random) def get_metrics(self) -> Mapping[str, Any]: try: @@ -840,7 +863,8 @@ def get_key(self, key_id: str) -> Key: assert not isinstance(key.public.modulus, Unset) assert not isinstance(key.public.publicExponent, Unset) public_key = RsaPublicKey( - modulus=key.public.modulus, public_exponent=key.public.publicExponent + modulus=Base64.from_encoded(key.public.modulus), + public_exponent=Base64.from_encoded(key.public.publicExponent), ) elif key_type == KeyType.GENERIC: if not isinstance(key.public, Unset): @@ -853,7 +877,7 @@ def get_key(self, key_id: str) -> Key: assert not isinstance(key.public.data, Unset) assert isinstance(key.public.modulus, Unset) assert isinstance(key.public.publicExponent, Unset) - public_key = EcPublicKey(data=key.public.data) + public_key = EcPublicKey(data=Base64.from_encoded(key.public.data)) return Key( key_id=key_id, @@ -903,13 +927,13 @@ def add_key( if type == KeyType.RSA: assert isinstance(private_key, RsaPrivateKey) key_data = KeyPrivateDataDict( - primeP=private_key.prime_p, - primeQ=private_key.prime_q, - publicExponent=private_key.public_exponent, + primeP=private_key.prime_p.data, + primeQ=private_key.prime_q.data, + publicExponent=private_key.public_exponent.data, ) else: assert isinstance(private_key, GenericPrivateKey) - key_data = KeyPrivateDataDict(data=private_key.data) + key_data = KeyPrivateDataDict(data=private_key.data.data) mechanism_tuple: KeyMechanismsTupleInput = [ mechanism.value for mechanism in mechanisms @@ -1485,7 +1509,7 @@ def factory_reset(self) -> None: ) def encrypt( - self, key_id: str, data: str, mode: EncryptMode, iv: str + self, key_id: str, data: Base64, mode: EncryptMode, iv: Optional[Base64] = None ) -> EncryptionResult: from .client.components.schema.encrypt_request_data import ( EncryptRequestDataDict, @@ -1493,9 +1517,12 @@ def encrypt( from .client.paths.keys_key_id_encrypt.post.path_parameters import ( PathParametersDict, ) + from .client.schemas import Unset path_params = PathParametersDict(KeyID=key_id) - body = EncryptRequestDataDict(message=data, mode=mode.value, iv=iv) + body = EncryptRequestDataDict( + message=data.data, mode=mode.value, iv=iv.data if iv else Unset() + ) try: response = self._get_api().keys_key_id_encrypt_post( path_params=path_params, body=body @@ -1510,26 +1537,29 @@ def encrypt( 404: f"Key {key_id} not found", }, ) - return EncryptionResult(encrypted=response.body.encrypted, iv=response.body.iv) + return EncryptionResult( + encrypted=Base64.from_encoded(response.body.encrypted), + iv=Base64.from_encoded(response.body.iv), + ) def decrypt( self, key_id: str, - data: str, + data: Base64, mode: DecryptMode, - iv: str, - ) -> str: + iv: Optional[Base64] = None, + ) -> Base64: from .client.components.schema.decrypt_request_data import ( DecryptRequestDataDict, ) from .client.paths.keys_key_id_decrypt.post.path_parameters import ( PathParametersDict, ) + from .client.schemas import Unset - body = DecryptRequestDataDict(encrypted=data, mode=mode.value, iv=iv) - - if len(iv) == 0: - body = DecryptRequestDataDict(encrypted=data, mode=mode.value) + body = DecryptRequestDataDict( + encrypted=data.data, mode=mode.value, iv=iv.data if iv else Unset() + ) path_params = PathParametersDict(KeyID=key_id) try: @@ -1546,21 +1576,21 @@ def decrypt( 404: f"Key {key_id} not found", }, ) - return response.body.decrypted + return Base64.from_encoded(response.body.decrypted) def sign( self, key_id: str, - data: str, + data: Base64, mode: SignMode, - ) -> str: + ) -> Base64: from .client.components.schema.sign_request_data import SignRequestDataDict from .client.paths.keys_key_id_sign.post.path_parameters import ( PathParametersDict, ) path_params = PathParametersDict(KeyID=key_id) - body = SignRequestDataDict(message=data, mode=mode.value) + body = SignRequestDataDict(message=data.data, mode=mode.value) try: response = self._get_api().keys_key_id_sign_post( path_params=path_params, body=body @@ -1575,7 +1605,7 @@ def sign( 404: f"Key {key_id} not found", }, ) - return response.body.signature + return Base64.from_encoded(response.body.signature) @contextlib.contextmanager diff --git a/tests/test_nethsm_keys.py b/tests/test_nethsm_keys.py index dea91cc..331408e 100644 --- a/tests/test_nethsm_keys.py +++ b/tests/test_nethsm_keys.py @@ -1,5 +1,3 @@ -import base64 - import pytest from conftest import Constants as C from Crypto import Random @@ -16,6 +14,7 @@ import nethsm as nethsm_module from nethsm import ( + Base64, DecryptMode, EncryptMode, KeyMechanism, @@ -46,13 +45,13 @@ def add_key(nethsm: NetHSM) -> None: if C.KEY_ID_ADDED in nethsm.list_keys(None): nethsm.delete_key(C.KEY_ID_ADDED) - p, q, e = generate_rsa_key_pair(1024) + private_key = generate_rsa_key_pair(1024) nethsm.add_key( key_id=C.KEY_ID_ADDED, type=C.TYPE, mechanisms=C.MECHANISM, - private_key=RsaPrivateKey(prime_p=p, prime_q=q, public_exponent=e), + private_key=private_key, ) @@ -272,11 +271,11 @@ def test_sign(nethsm: NetHSM) -> None: # mit dem privaten schlüssel signieren with connect(C.OPERATOR_USER) as nethsm: signature = nethsm.sign( C.KEY_ID_GENERATED, - base64.b64encode(hash_object.digest()).decode(), + Base64.encode(hash_object.digest()), SignMode.PSS_SHA256, ) - print(signature) - verify_rsa_signature(key, hash_object, base64.b64decode(signature)) + print(signature.data) + verify_rsa_signature(key, hash_object, signature.decode()) def test_decrypt(nethsm: NetHSM) -> None: @@ -292,11 +291,10 @@ def test_decrypt(nethsm: NetHSM) -> None: with connect(C.OPERATOR_USER) as nethsm: decrypt = nethsm.decrypt( C.KEY_ID_GENERATED, - base64.b64encode(encrypted).decode(), + Base64.encode(encrypted), C.MODE, - "arstasrta", ) - assert base64.b64decode(decrypt).decode() == C.DATA + assert decrypt.decode().decode() == C.DATA def test_encrypt_decrypt(nethsm: NetHSM) -> None: @@ -305,9 +303,9 @@ def test_encrypt_decrypt(nethsm: NetHSM) -> None: add_user(nethsm, C.OPERATOR_USER) IV = Random.new().read(AES.block_size) - iv_b64 = base64.b64encode(IV).decode() + iv_b64 = Base64.encode(IV) - data_b64 = base64.b64encode(C.DATA.encode()).decode() + data_b64 = Base64.encode(C.DATA.encode()) with connect(C.OPERATOR_USER) as nethsm: diff --git a/tests/test_nethsm_other.py b/tests/test_nethsm_other.py index 56bec94..dd673c0 100644 --- a/tests/test_nethsm_other.py +++ b/tests/test_nethsm_other.py @@ -147,16 +147,11 @@ def test_state_provision_add_user_get_random_data(nethsm_no_provision: NetHSM) - random_data1 = nethsm.get_random_data(100) random_data2 = nethsm.get_random_data(100) random_data3 = nethsm.get_random_data(100) - assert ( - len(str(random_data1)) == 136 - and len(str(random_data2)) == 136 - and len(str(random_data3)) == 136 - ) - assert ( - random_data1 != random_data2 - and random_data1 != random_data3 - and random_data2 != random_data3 - ) - # Todo: check if decoded function is the same length as given - # assert len(base64.b64decode(bytes(nethsm.get_random_data(100)))) == - # 100 + + assert len(random_data1.decode()) == 100 + assert len(random_data2.decode()) == 100 + assert len(random_data3.decode()) == 100 + + assert random_data1 != random_data2 + assert random_data1 != random_data3 + assert random_data2 != random_data3 diff --git a/tests/test_nethsm_system.py b/tests/test_nethsm_system.py index b359671..a0e647b 100644 --- a/tests/test_nethsm_system.py +++ b/tests/test_nethsm_system.py @@ -1,4 +1,3 @@ -import base64 import datetime import os @@ -16,7 +15,7 @@ update, ) -from nethsm import NetHSM, NetHSMError +from nethsm import Base64, NetHSM, NetHSMError from nethsm.backup import Backup, EncryptedBackup """######################### Preparation for the Tests ######################### @@ -131,11 +130,10 @@ def test_state_restore(nethsm: NetHSM) -> None: with connect(C.OPERATOR_USER) as nethsm: decrypt = nethsm.decrypt( C.KEY_ID_GENERATED, - base64.b64encode(encrypted).decode(), + Base64.encode(encrypted), C.MODE, - "arstasrta", ) - assert base64.b64decode(decrypt).decode() == C.DATA + assert decrypt.decode().decode() == C.DATA def test_state_provision_update(nethsm: NetHSM) -> None: diff --git a/tests/utilities.py b/tests/utilities.py index 39a32a6..e933304 100644 --- a/tests/utilities.py +++ b/tests/utilities.py @@ -1,4 +1,3 @@ -import base64 import contextlib import datetime import os @@ -22,7 +21,7 @@ from cryptography.x509.oid import NameOID import nethsm as nethsm_module -from nethsm import Authentication, NetHSM +from nethsm import Authentication, Base64, NetHSM, RsaPrivateKey @pytest.fixture(scope="module") @@ -172,17 +171,18 @@ def add_user(nethsm: NetHSM, user: UserData) -> None: nethsm.add_user(user.real_name, user.role, C.PASSPHRASE, user.user_id) -def generate_rsa_key_pair(length_in_bit: int) -> tuple[str, str, str]: +def generate_rsa_key_pair(length_in_bit: int) -> RsaPrivateKey: key_pair = RSA.generate(length_in_bit) length_in_byte = int(length_in_bit / 8) # "big" byteorder is needed, it's the dominant order in networking - p = base64.b64encode(key_pair.p.to_bytes(length_in_byte, "big")) - q = base64.b64encode(key_pair.q.to_bytes(length_in_byte, "big")) - e = base64.b64encode(key_pair.e.to_bytes(length_in_byte, "big")) - ps = str(p, "utf-8").strip() - qs = str(q, "utf-8").strip() - es = str(e, "utf-8").strip() - return ps, qs, es + p = key_pair.p.to_bytes(length_in_byte, "big") + q = key_pair.q.to_bytes(length_in_byte, "big") + e = key_pair.e.to_bytes(length_in_byte, "big") + return RsaPrivateKey( + prime_p=Base64.encode(p), + prime_q=Base64.encode(q), + public_exponent=Base64.encode(e), + ) def verify_rsa_signature(