Skip to content

Commit

Permalink
Introduce custom type for Base64-encoded data
Browse files Browse the repository at this point in the history
This patch introduces a Base64 class that is used for base64-encoded
data as input or output.  This makes the encoding more explicit and
easier to use.

Fixes: #50
  • Loading branch information
robin-nitrokey committed Nov 26, 2023
1 parent 17a04fa commit 4d9736d
Show file tree
Hide file tree
Showing 5 changed files with 94 additions and 73 deletions.
96 changes: 63 additions & 33 deletions nethsm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,11 @@

__version__ = "0.5.0"

import binascii
import contextlib
import enum
import json
from base64 import b64encode
from base64 import b64decode, b64encode
from dataclasses import dataclass
from datetime import datetime, timezone
from io import BufferedReader, FileIO
Expand Down Expand Up @@ -196,6 +197,28 @@ def from_string(s: str) -> "TlsKeyType":
raise ValueError(f"Unsupported TLS key type {s}")


@dataclass
class Base64:
data: str

def decode(self) -> bytes:
return b64decode(self.data)

@classmethod
def from_encoded(cls, data: Union[bytes, str]) -> "Base64":
try:
b64decode(data, validate=True)
if isinstance(data, bytes):
data = data.decode()

Check warning on line 212 in nethsm/__init__.py

View check run for this annotation

Codecov / codecov/patch

nethsm/__init__.py#L212

Added line #L212 was not covered by tests
return cls(data=data)
except binascii.Error:
raise ValueError(f"Invalid base64 data: {data!r}")

Check warning on line 215 in nethsm/__init__.py

View check run for this annotation

Codecov / codecov/patch

nethsm/__init__.py#L214-L215

Added lines #L214 - L215 were not covered by tests

@classmethod
def encode(cls, data: bytes) -> "Base64":
return cls(data=b64encode(data).decode())


@dataclass
class Authentication:
username: str
Expand Down Expand Up @@ -225,28 +248,28 @@ class User:

@dataclass
class RsaPublicKey:
modulus: str
public_exponent: str
modulus: Base64
public_exponent: Base64


@dataclass
class EcPublicKey:
data: str
data: Base64


PublicKey = Union[RsaPublicKey, EcPublicKey, None]


@dataclass
class RsaPrivateKey:
prime_p: str
prime_q: str
public_exponent: str
prime_p: Base64
prime_q: Base64
public_exponent: Base64


@dataclass
class GenericPrivateKey:
data: str
data: Base64


PrivateKey = Union[RsaPrivateKey, GenericPrivateKey]
Expand All @@ -264,8 +287,8 @@ class Key:

@dataclass
class EncryptionResult:
encrypted: str
iv: str
encrypted: Base64
iv: Base64


@dataclass
Expand Down Expand Up @@ -765,7 +788,7 @@ def get_state(self) -> State:
_handle_exception(e)
return State.from_string(response.body.state)

def get_random_data(self, n: int) -> str:
def get_random_data(self, n: int) -> Base64:
from .client.components.schema.random_request_data import RandomRequestDataDict

body = RandomRequestDataDict(length=n)
Expand All @@ -780,7 +803,7 @@ def get_random_data(self, n: int) -> str:
400: "Invalid length. Must be between 1 and 1024",
},
)
return response.body.random
return Base64.from_encoded(response.body.random)

def get_metrics(self) -> Mapping[str, Any]:
try:
Expand Down Expand Up @@ -840,7 +863,8 @@ def get_key(self, key_id: str) -> Key:
assert not isinstance(key.public.modulus, Unset)
assert not isinstance(key.public.publicExponent, Unset)
public_key = RsaPublicKey(
modulus=key.public.modulus, public_exponent=key.public.publicExponent
modulus=Base64.from_encoded(key.public.modulus),
public_exponent=Base64.from_encoded(key.public.publicExponent),
)
elif key_type == KeyType.GENERIC:
if not isinstance(key.public, Unset):
Expand All @@ -853,7 +877,7 @@ def get_key(self, key_id: str) -> Key:
assert not isinstance(key.public.data, Unset)
assert isinstance(key.public.modulus, Unset)
assert isinstance(key.public.publicExponent, Unset)
public_key = EcPublicKey(data=key.public.data)
public_key = EcPublicKey(data=Base64.from_encoded(key.public.data))

Check warning on line 880 in nethsm/__init__.py

View check run for this annotation

Codecov / codecov/patch

nethsm/__init__.py#L880

Added line #L880 was not covered by tests

return Key(
key_id=key_id,
Expand Down Expand Up @@ -903,13 +927,13 @@ def add_key(
if type == KeyType.RSA:
assert isinstance(private_key, RsaPrivateKey)
key_data = KeyPrivateDataDict(
primeP=private_key.prime_p,
primeQ=private_key.prime_q,
publicExponent=private_key.public_exponent,
primeP=private_key.prime_p.data,
primeQ=private_key.prime_q.data,
publicExponent=private_key.public_exponent.data,
)
else:
assert isinstance(private_key, GenericPrivateKey)
key_data = KeyPrivateDataDict(data=private_key.data)
key_data = KeyPrivateDataDict(data=private_key.data.data)

Check warning on line 936 in nethsm/__init__.py

View check run for this annotation

Codecov / codecov/patch

nethsm/__init__.py#L936

Added line #L936 was not covered by tests

mechanism_tuple: KeyMechanismsTupleInput = [
mechanism.value for mechanism in mechanisms
Expand Down Expand Up @@ -1485,17 +1509,20 @@ def factory_reset(self) -> None:
)

def encrypt(
self, key_id: str, data: str, mode: EncryptMode, iv: str
self, key_id: str, data: Base64, mode: EncryptMode, iv: Optional[Base64] = None
) -> EncryptionResult:
from .client.components.schema.encrypt_request_data import (
EncryptRequestDataDict,
)
from .client.paths.keys_key_id_encrypt.post.path_parameters import (
PathParametersDict,
)
from .client.schemas import Unset

path_params = PathParametersDict(KeyID=key_id)
body = EncryptRequestDataDict(message=data, mode=mode.value, iv=iv)
body = EncryptRequestDataDict(
message=data.data, mode=mode.value, iv=iv.data if iv else Unset()
)
try:
response = self._get_api().keys_key_id_encrypt_post(
path_params=path_params, body=body
Expand All @@ -1510,26 +1537,29 @@ def encrypt(
404: f"Key {key_id} not found",
},
)
return EncryptionResult(encrypted=response.body.encrypted, iv=response.body.iv)
return EncryptionResult(
encrypted=Base64.from_encoded(response.body.encrypted),
iv=Base64.from_encoded(response.body.iv),
)

def decrypt(
self,
key_id: str,
data: str,
data: Base64,
mode: DecryptMode,
iv: str,
) -> str:
iv: Optional[Base64] = None,
) -> Base64:
from .client.components.schema.decrypt_request_data import (
DecryptRequestDataDict,
)
from .client.paths.keys_key_id_decrypt.post.path_parameters import (
PathParametersDict,
)
from .client.schemas import Unset

body = DecryptRequestDataDict(encrypted=data, mode=mode.value, iv=iv)

if len(iv) == 0:
body = DecryptRequestDataDict(encrypted=data, mode=mode.value)
body = DecryptRequestDataDict(
encrypted=data.data, mode=mode.value, iv=iv.data if iv else Unset()
)

path_params = PathParametersDict(KeyID=key_id)
try:
Expand All @@ -1546,21 +1576,21 @@ def decrypt(
404: f"Key {key_id} not found",
},
)
return response.body.decrypted
return Base64.from_encoded(response.body.decrypted)

def sign(
self,
key_id: str,
data: str,
data: Base64,
mode: SignMode,
) -> str:
) -> Base64:
from .client.components.schema.sign_request_data import SignRequestDataDict
from .client.paths.keys_key_id_sign.post.path_parameters import (
PathParametersDict,
)

path_params = PathParametersDict(KeyID=key_id)
body = SignRequestDataDict(message=data, mode=mode.value)
body = SignRequestDataDict(message=data.data, mode=mode.value)
try:
response = self._get_api().keys_key_id_sign_post(
path_params=path_params, body=body
Expand All @@ -1575,7 +1605,7 @@ def sign(
404: f"Key {key_id} not found",
},
)
return response.body.signature
return Base64.from_encoded(response.body.signature)


@contextlib.contextmanager
Expand Down
22 changes: 10 additions & 12 deletions tests/test_nethsm_keys.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import base64

import pytest
from conftest import Constants as C
from Crypto import Random
Expand All @@ -16,6 +14,7 @@

import nethsm as nethsm_module
from nethsm import (
Base64,
DecryptMode,
EncryptMode,
KeyMechanism,
Expand Down Expand Up @@ -46,13 +45,13 @@ def add_key(nethsm: NetHSM) -> None:
if C.KEY_ID_ADDED in nethsm.list_keys(None):
nethsm.delete_key(C.KEY_ID_ADDED)

p, q, e = generate_rsa_key_pair(1024)
private_key = generate_rsa_key_pair(1024)

nethsm.add_key(
key_id=C.KEY_ID_ADDED,
type=C.TYPE,
mechanisms=C.MECHANISM,
private_key=RsaPrivateKey(prime_p=p, prime_q=q, public_exponent=e),
private_key=private_key,
)


Expand Down Expand Up @@ -272,11 +271,11 @@ def test_sign(nethsm: NetHSM) -> None: # mit dem privaten schlüssel signieren
with connect(C.OPERATOR_USER) as nethsm:
signature = nethsm.sign(
C.KEY_ID_GENERATED,
base64.b64encode(hash_object.digest()).decode(),
Base64.encode(hash_object.digest()),
SignMode.PSS_SHA256,
)
print(signature)
verify_rsa_signature(key, hash_object, base64.b64decode(signature))
print(signature.data)
verify_rsa_signature(key, hash_object, signature.decode())


def test_decrypt(nethsm: NetHSM) -> None:
Expand All @@ -292,11 +291,10 @@ def test_decrypt(nethsm: NetHSM) -> None:
with connect(C.OPERATOR_USER) as nethsm:
decrypt = nethsm.decrypt(
C.KEY_ID_GENERATED,
base64.b64encode(encrypted).decode(),
Base64.encode(encrypted),
C.MODE,
"arstasrta",
)
assert base64.b64decode(decrypt).decode() == C.DATA
assert decrypt.decode().decode() == C.DATA


def test_encrypt_decrypt(nethsm: NetHSM) -> None:
Expand All @@ -305,9 +303,9 @@ def test_encrypt_decrypt(nethsm: NetHSM) -> None:
add_user(nethsm, C.OPERATOR_USER)
IV = Random.new().read(AES.block_size)

iv_b64 = base64.b64encode(IV).decode()
iv_b64 = Base64.encode(IV)

data_b64 = base64.b64encode(C.DATA.encode()).decode()
data_b64 = Base64.encode(C.DATA.encode())

with connect(C.OPERATOR_USER) as nethsm:

Expand Down
21 changes: 8 additions & 13 deletions tests/test_nethsm_other.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,16 +147,11 @@ def test_state_provision_add_user_get_random_data(nethsm_no_provision: NetHSM) -
random_data1 = nethsm.get_random_data(100)
random_data2 = nethsm.get_random_data(100)
random_data3 = nethsm.get_random_data(100)
assert (
len(str(random_data1)) == 136
and len(str(random_data2)) == 136
and len(str(random_data3)) == 136
)
assert (
random_data1 != random_data2
and random_data1 != random_data3
and random_data2 != random_data3
)
# Todo: check if decoded function is the same length as given
# assert len(base64.b64decode(bytes(nethsm.get_random_data(100)))) ==
# 100

assert len(random_data1.decode()) == 100
assert len(random_data2.decode()) == 100
assert len(random_data3.decode()) == 100

assert random_data1 != random_data2
assert random_data1 != random_data3
assert random_data2 != random_data3
8 changes: 3 additions & 5 deletions tests/test_nethsm_system.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import base64
import datetime
import os

Expand All @@ -16,7 +15,7 @@
update,
)

from nethsm import NetHSM, NetHSMError
from nethsm import Base64, NetHSM, NetHSMError
from nethsm.backup import Backup, EncryptedBackup

"""######################### Preparation for the Tests #########################
Expand Down Expand Up @@ -131,11 +130,10 @@ def test_state_restore(nethsm: NetHSM) -> None:
with connect(C.OPERATOR_USER) as nethsm:
decrypt = nethsm.decrypt(
C.KEY_ID_GENERATED,
base64.b64encode(encrypted).decode(),
Base64.encode(encrypted),
C.MODE,
"arstasrta",
)
assert base64.b64decode(decrypt).decode() == C.DATA
assert decrypt.decode().decode() == C.DATA


def test_state_provision_update(nethsm: NetHSM) -> None:
Expand Down
Loading

0 comments on commit 4d9736d

Please sign in to comment.