Skip to content

Commit

Permalink
Retrieve multiple secrets at a time or an entire table
Browse files Browse the repository at this point in the history
  • Loading branch information
dormant-user committed Sep 17, 2024
1 parent 1f5c811 commit c10f673
Show file tree
Hide file tree
Showing 2 changed files with 148 additions and 6 deletions.
18 changes: 18 additions & 0 deletions vaultapi/database.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import List, Tuple

from . import models


Expand All @@ -21,6 +23,22 @@ def get_secret(key: str, table_name: str) -> str | None:
return state[0]


def get_table(table_name: str) -> List[Tuple[str, str]]:
"""Function to retrieve all key-value pairs from a particular table in the database.
Args:
table_name: Name of the table where the secrets are stored.
Returns:
str:
Returns the secret value.
"""
with models.database.connection:
cursor = models.database.connection.cursor()
state = cursor.execute(f'SELECT * FROM "{table_name}"').fetchall()
return state


def put_secret(key: str, value: str, table_name: str) -> None:
"""Function to add secret to the database.
Expand Down
136 changes: 130 additions & 6 deletions vaultapi/routers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
security = HTTPBearer()


async def retrieve_existing(key: str, table_name: str) -> str | None:
"""Retrieve existing secret from database.
async def retrieve_secret(key: str, table_name: str) -> str | None:
"""Retrieve an existing secret from a table in the database.
Args:
key: Name of the secret to retrieve.
Expand All @@ -34,13 +34,40 @@ async def retrieve_existing(key: str, table_name: str) -> str | None:
)


async def retrieve_secrets(table_name: str, keys: List[str] = None) -> Dict[str, str]:
"""Retrieve multiple secrets from a table or retrieve the table as a whole.
Args:
table_name: Name of the table where the secret is stored.
keys: List of keys for which the values have to be retrieved.
Returns:
Dict[str, str]:
Returns the key-value pairs for secret key and it's value.
"""
if keys:
values = {}
for key in keys:
if value := await retrieve_secret(key, table_name):
values[key] = value
return values
else:
try:
return dict(database.get_table(table_name))
except sqlite3.OperationalError as error:
LOGGER.error(error)
raise exceptions.APIResponse(
status_code=HTTPStatus.BAD_REQUEST.real, detail=error.args[0]
)


async def get_secret(
request: Request,
key: str,
table_name: str = "default",
apikey: HTTPAuthorizationCredentials = Depends(security),
):
"""**API function to retrieve secrets.**
"""**API function to retrieve a secret.**
**Args:**
Expand All @@ -55,7 +82,7 @@ async def get_secret(
Raises the HTTPStatus object with a status code and detail as response.
"""
await auth.validate(request, apikey)
if value := await retrieve_existing(key, table_name):
if value := await retrieve_secret(key, table_name):
LOGGER.info("Secret value for '%s' was retrieved", key)
decrypted = models.session.fernet.decrypt(value).decode(encoding="UTF-8")
raise exceptions.APIResponse(status_code=HTTPStatus.OK.real, detail=decrypted)
Expand All @@ -65,6 +92,91 @@ async def get_secret(
)


async def get_secrets(
request: Request,
keys: List[str],
table_name: str = "default",
apikey: HTTPAuthorizationCredentials = Depends(security),
):
"""**API function to retrieve multiple secrets at a time.**
**Args:**
request: Reference to the FastAPI request object.
key: List of secret names to be retrieved.
table_name: Name of the table where the secrets are stored.
apikey: API Key to authenticate the request.
**Raises:**
APIResponse:
Raises the HTTPStatus object with a status code and detail as response.
"""
await auth.validate(request, apikey)
keys_ct = len(keys)
try:
assert keys_ct >= 1, f"Expected at least one key, received {keys_ct}"
except AssertionError as error:
LOGGER.error(error)
raise exceptions.APIResponse(
status_code=HTTPStatus.BAD_REQUEST.real, detail=error.args[0]
)
if values := await retrieve_secrets(table_name, keys):
values_ct = len(values)
try:
assert (
values_ct == keys_ct
), f"Number of keys [{keys_ct}] requested didn't match the number of values [{values_ct}] retrieved."
LOGGER.info("Secret value for %d (%s) were retrieved", keys_ct, keys)
code = HTTPStatus.OK.real
except AssertionError as error:
LOGGER.warning(error)
code = HTTPStatus.PARTIAL_CONTENT.real
decrypted = {
key: models.session.fernet.decrypt(value).decode(encoding="UTF-8")
for key, value in values.items()
}
raise exceptions.APIResponse(status_code=code, detail=decrypted)
if keys_ct == 1:
LOGGER.info("Secret value for '%s' NOT found in the datastore", keys[0])
else:
LOGGER.info(
"Secret values for %d keys (%s) were NOT found in the datastore",
keys_ct,
keys,
)
raise exceptions.APIResponse(
status_code=HTTPStatus.NOT_FOUND.real, detail=HTTPStatus.NOT_FOUND.phrase
)


async def get_table(
request: Request,
table_name: str = "default",
apikey: HTTPAuthorizationCredentials = Depends(security),
):
"""**API function to retrieve ALL the key-value pairs stored in a particular table.**
**Args:**
request: Reference to the FastAPI request object.
table_name: Name of the table where the secrets are stored.
apikey: API Key to authenticate the request.
**Raises:**
APIResponse:
Raises the HTTPStatus object with a status code and detail as response.
"""
await auth.validate(request, apikey)
table_content = await retrieve_secrets(table_name)
decrypted = {
key: models.session.fernet.decrypt(value).decode(encoding="UTF-8")
for key, value in table_content.items()
}
raise exceptions.APIResponse(status_code=HTTPStatus.OK.real, detail=decrypted)


async def put_secret(
request: Request,
data: payload.PutSecret,
Expand All @@ -84,7 +196,7 @@ async def put_secret(
Raises the HTTPStatus object with a status code and detail as response.
"""
await auth.validate(request, apikey)
if await retrieve_existing(data.key, data.table_name):
if await retrieve_secret(data.key, data.table_name):
LOGGER.info("Secret value for '%s' will be overridden", data.key)
else:
LOGGER.info(
Expand Down Expand Up @@ -118,7 +230,7 @@ async def delete_secret(
Raises the HTTPStatus object with a status code and detail as response.
"""
await auth.validate(request, apikey)
if await retrieve_existing(data.key, data.table_name):
if await retrieve_secret(data.key, data.table_name):
LOGGER.info("Secret value for '%s' will be removed", data.key)
else:
LOGGER.warning("Secret value for '%s' NOT found", data.key)
Expand Down Expand Up @@ -204,6 +316,18 @@ def get_all_routes() -> List[APIRoute]:
methods=["GET"],
dependencies=dependencies,
),
APIRoute(
path="/get-secrets",
endpoint=get_secrets,
methods=["POST"],
dependencies=dependencies,
),
APIRoute(
path="/get-table",
endpoint=get_table,
methods=["GET"],
dependencies=dependencies,
),
APIRoute(
path="/put-secret",
endpoint=put_secret,
Expand Down

0 comments on commit c10f673

Please sign in to comment.