Skip to content

Commit

Permalink
Use enums instead of literals
Browse files Browse the repository at this point in the history
Using enums makes it easier to use the functions in a typesafe way and
to validate user input.

Fixes: #69
  • Loading branch information
robin-nitrokey committed Nov 17, 2023
1 parent ff2e185 commit 23eb940
Show file tree
Hide file tree
Showing 5 changed files with 135 additions and 117 deletions.
165 changes: 86 additions & 79 deletions nethsm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from dataclasses import dataclass
from datetime import datetime
from io import BufferedReader, FileIO
from typing import TYPE_CHECKING, Any, Iterator, Literal, Mapping, Optional, Union, cast
from typing import TYPE_CHECKING, Any, Iterator, Mapping, Optional, Union, cast
from urllib.parse import urlencode

import urllib3
Expand Down Expand Up @@ -79,6 +79,13 @@ class UnattendedBootStatus(enum.Enum):
ON = "on"
OFF = "off"

@staticmethod
def from_string(s: str) -> "UnattendedBootStatus":
for status in UnattendedBootStatus:
if status.value == s:
return status
raise ValueError(f"Unsupported unattended boot status {s}")


class KeyType(enum.Enum):
RSA = "RSA"
Expand All @@ -97,11 +104,6 @@ def from_string(s: str) -> "KeyType":
raise ValueError(f"Unsupported key type {s}")


KeyTypeLitteral = Literal[
"RSA", "Curve25519", "EC_P224", "EC_P256", "EC_P384", "EC_P521", "Generic"
]


class KeyMechanism(enum.Enum):
RSA_DECRYPTION_RAW = "RSA_Decryption_RAW"
RSA_DECRYPTION_PKCS1 = "RSA_Decryption_PKCS1"
Expand All @@ -123,33 +125,24 @@ class KeyMechanism(enum.Enum):
AES_ENCRYPTION_CBC = "AES_Encryption_CBC"
AES_DECRYPTION_CBC = "AES_Decryption_CBC"


KeyMechanismLiteral = Literal[
"RSA_Decryption_RAW",
"RSA_Decryption_PKCS1",
"RSA_Decryption_OAEP_MD5",
"RSA_Decryption_OAEP_SHA1",
"RSA_Decryption_OAEP_SHA224",
"RSA_Decryption_OAEP_SHA256",
"RSA_Decryption_OAEP_SHA384",
"RSA_Decryption_OAEP_SHA512",
"RSA_Signature_PKCS1",
"RSA_Signature_PSS_MD5",
"RSA_Signature_PSS_SHA1",
"RSA_Signature_PSS_SHA224",
"RSA_Signature_PSS_SHA256",
"RSA_Signature_PSS_SHA384",
"RSA_Signature_PSS_SHA512",
"EdDSA_Signature",
"ECDSA_Signature",
"AES_Encryption_CBC",
"AES_Decryption_CBC",
]
@staticmethod
def from_string(s: str) -> "KeyMechanism":
for key_mechanism in KeyMechanism:
if key_mechanism.value == s:
return key_mechanism
raise ValueError(f"Unsupported key mechanism {s}")


class EncryptMode(enum.Enum):
AES_CBC = "AES_CBC"

@staticmethod
def from_string(s: str) -> "EncryptMode":
for mode in EncryptMode:
if mode.value == s:
return mode
raise ValueError(f"Unsupported encrypt mode {s}")


class DecryptMode(enum.Enum):
RAW = "RAW"
Expand All @@ -162,6 +155,13 @@ class DecryptMode(enum.Enum):
OAEP_SHA512 = "OAEP_SHA512"
AES_CBC = "AES_CBC"

@staticmethod
def from_string(s: str) -> "DecryptMode":
for mode in DecryptMode:
if mode.value == s:
return mode
raise ValueError(f"Unsupported decrypt mode {s}")


class SignMode(enum.Enum):
PKCS1 = "PKCS1"
Expand All @@ -174,6 +174,13 @@ class SignMode(enum.Enum):
EDDSA = "EdDSA"
ECDSA = "ECDSA"

@staticmethod
def from_string(s: str) -> "SignMode":
for mode in SignMode:
if mode.value == s:
return mode
raise ValueError(f"Unsupported sign mode {s}")


class TlsKeyType(enum.Enum):
RSA = "RSA"
Expand All @@ -183,6 +190,13 @@ class TlsKeyType(enum.Enum):
EC_P384 = "EC_P384"
EC_P521 = "EC_P521"

@staticmethod
def from_string(s: str) -> "TlsKeyType":
for key_type in TlsKeyType:
if key_type.value == s:
return key_type
raise ValueError(f"Unsupported TLS key type {s}")


@dataclass
class SystemInfo:
Expand All @@ -202,7 +216,7 @@ class User:
@dataclass
class Key:
key_id: str
mechanisms: list[str]
mechanisms: list[KeyMechanism]
type: KeyType
operations: int
tags: Optional[list[str]]
Expand Down Expand Up @@ -527,7 +541,7 @@ def get_user(self, user_id: str) -> User:
def add_user(
self,
real_name: str,
role: Literal["Administrator", "Operator", "Metrics", "Backup"],
role: Role,
passphrase: str,
user_id: Optional[str] = None,
) -> str:
Expand All @@ -536,7 +550,7 @@ def add_user(

body = UserPostDataDict(
realName=real_name,
role=role,
role=role.value,
passphrase=passphrase,
)
try:
Expand Down Expand Up @@ -775,7 +789,9 @@ def get_key(self, key_id: str) -> Key:
)
return Key(
key_id=key_id,
mechanisms=[mechanism for mechanism in key.mechanisms],
mechanisms=[
KeyMechanism.from_string(mechanism) for mechanism in key.mechanisms
],
type=KeyType.from_string(key.type),
operations=key.operations,
tags=[str(tag) for tag in cast(list[str], key.restrictions["tags"])]
Expand Down Expand Up @@ -811,14 +827,15 @@ def get_key_public_key(self, key_id: str) -> str:
def add_key(
self,
key_id: str,
type: KeyTypeLitteral,
mechanisms: list[KeyMechanismLiteral],
type: KeyType,
mechanisms: list[KeyMechanism],
tags: list[str],
prime_p: Optional[str],
prime_q: Optional[str],
public_exponent: Optional[str],
data: Optional[str],
) -> str:
from .client.components.schema.key_mechanisms import KeyMechanismsTupleInput
from .client.components.schema.key_private_data import KeyPrivateDataDict
from .client.components.schema.key_restrictions import KeyRestrictionsDict
from .client.components.schema.private_key import PrivateKeyDict
Expand All @@ -827,7 +844,7 @@ def add_key(
# To do: split into different methods for RSA and other key types, or
# at least change typing accordingly

if type == "RSA":
if type == KeyType.RSA:
assert prime_p
assert prime_q
assert public_exponent
Expand All @@ -840,19 +857,23 @@ def add_key(
assert data
key_data = KeyPrivateDataDict(data=data)

mechanism_tuple: KeyMechanismsTupleInput = [
mechanism.value for mechanism in mechanisms
]

if tags:
body = PrivateKeyDict(
type=type,
mechanisms=mechanisms,
type=type.value,
mechanisms=mechanism_tuple,
key=key_data,
restrictions=KeyRestrictionsDict(
tags=TagListTuple([tag for tag in tags])
),
)
else:
body = PrivateKeyDict(
type=type,
mechanisms=mechanisms,
type=type.value,
mechanisms=mechanism_tuple,
key=key_data,
)

Expand Down Expand Up @@ -905,27 +926,31 @@ def delete_key(self, key_id: str) -> None:

def generate_key(
self,
type: KeyTypeLitteral,
mechanisms: tuple[KeyMechanismLiteral],
type: KeyType,
mechanisms: list[KeyMechanism],
length: int,
key_id: Optional[str] = None,
) -> str:
from .client.components.schema.key_generate_request_data import (
KeyGenerateRequestDataDict,
)
from .client.components.schema.key_mechanisms import KeyMechanismsTuple
from .client.components.schema.key_mechanisms import KeyMechanismsTupleInput

mechanism_tuple: KeyMechanismsTupleInput = [
mechanism.value for mechanism in mechanisms
]

if key_id:
body = KeyGenerateRequestDataDict(
type=type,
mechanisms=KeyMechanismsTuple(mechanisms),
type=type.value,
mechanisms=mechanism_tuple,
length=length,
id=key_id,
)
else:
body = KeyGenerateRequestDataDict(
type=type,
mechanisms=KeyMechanismsTuple(mechanisms),
type=type.value,
mechanisms=mechanism_tuple,
length=length,
)
try:
Expand Down Expand Up @@ -1107,7 +1132,7 @@ def csr(

def generate_tls_key(
self,
type: Literal["RSA", "Curve25519", "EC_P224", "EC_P256", "EC_P384", "EC_P521"],
type: TlsKeyType,
length: Optional[int] = None,
) -> None:
from .client.components.schema.tls_key_generate_request_data import (
Expand All @@ -1116,7 +1141,7 @@ def generate_tls_key(
from .client.schemas import Unset

body = TlsKeyGenerateRequestDataDict(
type=type,
type=type.value,
length=length if length is not None else Unset(),
)

Expand Down Expand Up @@ -1213,11 +1238,13 @@ def set_logging_config(
self,
ip_address: str,
port: int,
log_level: Literal["debug", "info", "warning", "error"],
log_level: LogLevel,
) -> None:
from .client.components.schema.logging_config import LoggingConfigDict

body = LoggingConfigDict(ipAddress=ip_address, port=port, logLevel=log_level)
body = LoggingConfigDict(
ipAddress=ip_address, port=port, logLevel=log_level.value
)
try:
self.get_api().config_logging_put(body=body)
except Exception as e:
Expand Down Expand Up @@ -1262,12 +1289,12 @@ def set_time(self, time: Union[str, datetime]) -> None:
},
)

def set_unattended_boot(self, status: Literal["on", "off"]) -> None:
def set_unattended_boot(self, status: UnattendedBootStatus) -> None:
from .client.components.schema.unattended_boot_config import (
UnattendedBootConfigDict,
)

body = UnattendedBootConfigDict(status=status)
body = UnattendedBootConfigDict(status=status.value)
try:
self.get_api().config_unattended_boot_put(body=body)
except Exception as e:
Expand Down Expand Up @@ -1408,7 +1435,7 @@ def factory_reset(self) -> None:
)

def encrypt(
self, key_id: str, data: str, mode: Literal["AES_CBC"], iv: str
self, key_id: str, data: str, mode: EncryptMode, iv: str
) -> tuple[str, str]:
from .client.components.schema.encrypt_request_data import (
EncryptRequestDataDict,
Expand All @@ -1418,7 +1445,7 @@ def encrypt(
)

path_params = PathParametersDict(KeyID=key_id)
body = EncryptRequestDataDict(message=data, mode=mode, iv=iv)
body = EncryptRequestDataDict(message=data, mode=mode.value, iv=iv)
try:
response = self.get_api().keys_key_id_encrypt_post(
path_params=path_params, body=body
Expand All @@ -1439,17 +1466,7 @@ def decrypt(
self,
key_id: str,
data: str,
mode: Literal[
"RAW",
"PKCS1",
"OAEP_MD5",
"OAEP_SHA1",
"OAEP_SHA224",
"OAEP_SHA256",
"OAEP_SHA384",
"OAEP_SHA512",
"AES_CBC",
],
mode: DecryptMode,
iv: str,
) -> str:
from .client.components.schema.decrypt_request_data import (
Expand All @@ -1459,10 +1476,10 @@ def decrypt(
PathParametersDict,
)

body = DecryptRequestDataDict(encrypted=data, mode=mode, iv=iv)
body = DecryptRequestDataDict(encrypted=data, mode=mode.value, iv=iv)

if len(iv) == 0:
body = DecryptRequestDataDict(encrypted=data, mode=mode)
body = DecryptRequestDataDict(encrypted=data, mode=mode.value)

path_params = PathParametersDict(KeyID=key_id)
try:
Expand All @@ -1485,25 +1502,15 @@ def sign(
self,
key_id: str,
data: str,
mode: Literal[
"PKCS1",
"PSS_MD5",
"PSS_SHA1",
"PSS_SHA224",
"PSS_SHA256",
"PSS_SHA384",
"PSS_SHA512",
"EdDSA",
"ECDSA",
],
mode: SignMode,
) -> str:
from .client.components.schema.sign_request_data import SignRequestDataDict
from .client.paths.keys_key_id_sign.post.path_parameters import (
PathParametersDict,
)

path_params = PathParametersDict(KeyID=key_id)
body = SignRequestDataDict(message=data, mode=mode)
body = SignRequestDataDict(message=data, mode=mode.value)
try:
response = self.get_api().keys_key_id_sign_post(
path_params=path_params, body=body
Expand Down
Loading

0 comments on commit 23eb940

Please sign in to comment.