Skip to content

Commit

Permalink
Add a feature to allow IP range when whitelisting
Browse files Browse the repository at this point in the history
Add docker stuff
  • Loading branch information
dormant-user committed Sep 18, 2024
1 parent ff03fd3 commit b29e28a
Show file tree
Hide file tree
Showing 11 changed files with 316 additions and 15 deletions.
6 changes: 6 additions & 0 deletions .dockerignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
*.env
*.db

docs/
doc_gen/
.github/
91 changes: 91 additions & 0 deletions .github/workflows/docker-description.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
name: Update Docker Hub Description

on:
push:
branches:
- main
paths:
- README.md
- .github/workflows/docker-description.yml

env:
DOCKER_REGISTRY: "https://hub.docker.com/v2"
DOCKER_REPOSITORY: "${{ github.event.repository.name }}"
DESCRIPTION_LIMIT: 100

jobs:
update-docker-hub:
runs-on: thevickypedia-lite
steps:
- uses: actions/checkout@v4

- name: Fetch API Token
run: |
payload=$(jq -n \
--arg username "${{ secrets.DOCKER_USERNAME }}" \
--arg password "${{ secrets.DOCKER_PASSWORD }}" \
'{username: $username, password: $password}')
token=$(curl -s -X POST "${{ env.DOCKER_REGISTRY }}/users/login/" \
-H "Content-Type: application/json" \
-d "$payload" | jq -r '.token')
if [[ -n "${token}" ]]; then
echo "::debug title=Token Retriever::Retrieved token successfully"
echo "API_TOKEN=${token}" >> $GITHUB_ENV
else
echo "::error title=Token Retriever::Failed to get auth token"
exit 1
fi
shell: bash

- name: Get Description
run: |
warn="Description exceeds DockerHub's limit and has been truncated to ${{ env.DESCRIPTION_LIMIT }} characters."
description="${{ github.event.repository.description }}"
description_length=${#description}
if [[ "$description_length" -gt "${{ env.DESCRIPTION_LIMIT }}" ]]; then
echo "::warning title=Description Too Long::${warn}"
shortened_description="${description:0:97}..."
else
shortened_description="$description"
fi
echo "SHORT_DESCRIPTION=${shortened_description}" >> $GITHUB_ENV
shell: bash

- name: Update description
run: |
full_description="$(cat README.md)"
payload=$(jq -n \
--arg description "${{ env.SHORT_DESCRIPTION }}" \
--arg full_description "$full_description" \
'{description: $description, full_description: $full_description}')
response=$(curl -s -o /tmp/desc -w "%{http_code}" -X PATCH \
"${{ env.DOCKER_REGISTRY }}/repositories/${{ secrets.DOCKER_USERNAME }}/${{ env.DOCKER_REPOSITORY }}/" \
-H "Authorization: Bearer ${{ env.API_TOKEN }}" \
-H "Content-Type: application/json" \
-d "$payload")
status_code="${response: -3}"
if [[ "${status_code}" -eq 200 ]]; then
echo "::notice title=Updater::Updated description successfully"
exit 0
elif [[ -f "/tmp/desc" ]]; then
echo "::error title=Updater::Failed to update description"
response_payload="$(cat /tmp/desc)"
reason=$(echo "${response_payload}" | jq '.message')
info=$(echo "${response_payload}" | jq '.errinfo')
if [[ "$reason" != "null" ]]; then
echo "::error title=Updater::[${status_code}]: $reason"
else
echo "::error title=Updater::[${status_code}]: $(cat /tmp/desc)"
fi
if [[ "$info" != "null" ]]; then
echo "::error title=Updater::${info}"
fi
else
echo "::error title=Updater::Failed to update description - ${status_code}"
fi
exit 1
shell: bash
30 changes: 30 additions & 0 deletions .github/workflows/docker-publish.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
---
name: Build and Publish

on:
release:
types:
- published

jobs:
release:
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Set up QEMU
uses: docker/setup-qemu-action@v3
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_PASSWORD }}
- name: Build and push Docker image
uses: docker/build-push-action@v6
with:
context: .
push: true
platforms: linux/amd64,linux/arm64
tags: ${{ github.repository }}:${{ github.event.release.tag_name }},${{ github.repository }}:latest
22 changes: 22 additions & 0 deletions Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
FROM python:3.11-alpine

WORKDIR /app

ADD LICENSE /app
ADD README.md /app
ADD pyproject.toml /app
ADD requirements.txt /app
ADD log_config.yml /app
ADD entrypoint.py /app
ADD vaultapi /app/vaultapi

RUN pwd && ls -ltrh

RUN python -m venv venv && \
source venv/bin/activate && \
python -m pip install .

# Add PATH env var, so the CLI is accessible
ENV PATH="/app/venv/bin:$PATH"

ENTRYPOINT [ "python", "entrypoint.py" ]
14 changes: 14 additions & 0 deletions docker-compose.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
services:
app:
image: thevickypedia/VaultAPI
container_name: vaultapi
build:
context: .
volumes:
- ./logs:/app/logs
- ./data:/app/data
env_file:
- .env
ports:
# host_port:container_port
- "8080:9010"
67 changes: 67 additions & 0 deletions entrypoint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
"""This is an entrypoint specific for docker containers."""

import os
import pathlib
from datetime import datetime

import vaultapi

logs_dir = os.path.join(pathlib.Path(__file__).parent, "logs")
db_file = os.environ.get("database") or os.environ.get("DATABASE") or "secrets.db"
db_path = os.path.join(pathlib.Path(__file__).parent, "data", db_file)

DEFAULT_LOG_FILENAME: str = datetime.now().strftime(
os.path.join(logs_dir, "vaultapi_%d-%m-%Y.log")
)

os.makedirs(logs_dir, exist_ok=True)

log_config = {
"version": 1,
"disable_existing_loggers": True,
"formatters": {
"default": {
"()": "uvicorn.logging.DefaultFormatter",
"fmt": "%(asctime)s %(levelprefix)-9s %(name)s -: %(message)s",
"use_colors": False,
},
"access": {
"()": "uvicorn.logging.AccessFormatter",
"fmt": '%(asctime)s %(levelprefix)-9s %(name)s -: %(client_addr)s - "%(request_line)s" %(status_code)s',
"use_colors": False,
},
"error": {
"()": "uvicorn.logging.DefaultFormatter",
"fmt": "%(asctime)s %(levelprefix)-9s %(name)s -: %(message)s",
"use_colors": False,
},
},
"handlers": {
"default": {
"class": "logging.FileHandler",
"formatter": "default",
"filename": DEFAULT_LOG_FILENAME,
},
"access": {
"class": "logging.FileHandler",
"formatter": "access",
"filename": DEFAULT_LOG_FILENAME,
},
"error": {
"class": "logging.FileHandler",
"formatter": "error",
"filename": DEFAULT_LOG_FILENAME,
},
},
"loggers": {
"uvicorn": {"propagate": True, "level": "INFO", "handlers": ["default"]},
"uvicorn.error": {"propagate": True, "level": "INFO", "handlers": ["error"]},
"uvicorn.access": {"propagate": True, "level": "INFO", "handlers": ["access"]},
},
}

if __name__ == '__main__':
vaultapi.start(
log_config=log_config,
database=db_path
)
48 changes: 48 additions & 0 deletions log_config.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
#### This is a sample logging configuration for the API. ####
version: 1
disable_existing_loggers: True
formatters:
default:
(): 'uvicorn.logging.DefaultFormatter'
datefmt: '%b-%d-%Y %I:%M:%S %p'
fmt: '%(asctime)s %(levelprefix)-9s [%(module)s:%(lineno)d] - %(message)s'
use_colors: False
access:
(): 'uvicorn.logging.AccessFormatter'
datefmt: '%b-%d-%Y %I:%M:%S %p'
fmt: '%(asctime)s %(levelprefix)-9s [%(module)s:%(lineno)d] %(client_addr)s - %(status_code)s'
use_colors: False
error:
(): 'uvicorn.logging.DefaultFormatter'
datefmt: '%b-%d-%Y %I:%M:%S %p'
fmt: '%(asctime)s %(levelprefix)-9s [%(module)s:%(lineno)d] - %(message)s'
use_colors: False
handlers:
default:
class: logging.FileHandler # Can be changed to StreamHandler for stdout logging
formatter: default
filename: default.log
access:
class: logging.FileHandler # Can be changed to StreamHandler for stdout logging
formatter: access
filename: access.log
error:
class: logging.FileHandler # Can be changed to StreamHandler for stdout logging
formatter: error
filename: default.log
loggers:
uvicorn:
propagate: True
level: INFO
handlers:
- default
uvicorn.error:
propagate: True
level: INFO
handlers:
- error
uvicorn.access:
propagate: True
level: INFO
handlers:
- access
4 changes: 2 additions & 2 deletions vaultapi/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,14 @@ async def validate(request: Request, apikey: HTTPAuthorizationCredentials) -> No
else:
auth = apikey.credentials
if secrets.compare_digest(auth, models.env.apikey):
LOGGER.info(
LOGGER.debug(
"Connection received from client-host: %s, host-header: %s, x-fwd-host: %s",
request.client.host,
request.headers.get("host"),
request.headers.get("x-forwarded-host"),
)
if user_agent := request.headers.get("user-agent"):
LOGGER.info("User agent: %s", user_agent)
LOGGER.debug("User agent: %s", user_agent)
return
raise exceptions.APIResponse(
status_code=HTTPStatus.UNAUTHORIZED.real, detail=HTTPStatus.UNAUTHORIZED.phrase
Expand Down
7 changes: 7 additions & 0 deletions vaultapi/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,13 @@ def __init__(**kwargs) -> None:
models.session.allowed_origins.add(models.env.host)
for allowed in models.env.allowed_origins:
models.session.allowed_origins.add(allowed.host)
for cidr_range in models.env.allowed_ip_range:
LOGGER.info("Adding the IP range: %s to allowed_origins", cidr_range)
ip_notion = '.'.join(cidr_range.split('.')[0:-1])
start_ip, end_ip = cidr_range.split('.')[-1].split('-')
start_ip, end_ip = int(start_ip), int(end_ip) + 1
for i in range(start_ip, end_ip):
models.session.allowed_origins.add(f"{ip_notion}.{i}")
LOGGER.info("Allowed origins: %s", models.session.allowed_origins)


Expand Down
36 changes: 26 additions & 10 deletions vaultapi/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,20 +118,23 @@ class EnvConfig(BaseSettings):
secret: str
database: FilePath | NewPath | str = Field("secrets.db", pattern=".*.db$")
host: str = socket.gethostbyname("localhost") or "0.0.0.0"
port: PositiveInt = 8080
port: PositiveInt = 9010
workers: PositiveInt = 1
log_config: FilePath | Dict[str, Any] | None = None
allowed_origins: HttpUrl | List[HttpUrl] = []
allowed_ip_range: List[str] = []
# This is a base rate limit configuration
rate_limit: RateLimit | List[RateLimit] = [
# Burst limit: Prevents excessive load on the server
{
"max_requests": 5,
"seconds": 2,
}, # Burst limit: Prevents excessive load on the server
},
# Sustained limit: Prevents too many trial and errors
{
"max_requests": 10,
"seconds": 30,
}, # Sustained limit: Prevents too many trial and errors
},
]

@field_validator("allowed_origins", mode="after", check_fields=True)
Expand All @@ -143,15 +146,28 @@ def parse_allowed_origins(
return value
return [value]

@field_validator("apikey", mode="after")
def parse_apikey(cls, value: str | None) -> str | None: # noqa: PyMethodParameters
"""Parse API key to validate complexity."""
if value:
@field_validator("allowed_ip_range", mode="after", check_fields=True)
def parse_allowed_ip_range(
cls, value: List[str] # noqa: PyMethodParameters
) -> List[str]:
"""Validate allowed IP range to whitelist."""
for ip_range in value:
try:
complexity_checker(value)
assert len(ip_range.split('.')) > 1, f"Expected a valid IP address, received {ip_range}"
assert len(ip_range.split('.')[-1].split('-')) == 2, f"Expected a valid IP range, received {ip_range}"
except AssertionError as error:
raise ValueError(error.__str__())
return value
exc = f"{error}\n\tInput should be a list of IP range (eg: ['192.168.1.10-19', '10.120.1.5-35'])"
raise ValueError(exc)
return value

@field_validator("apikey", mode="after")
def parse_apikey(cls, value: str) -> str | None: # noqa: PyMethodParameters
"""Parse API key to validate complexity."""
try:
complexity_checker(value)
except AssertionError as error:
raise ValueError(error.__str__())
return value

@field_validator("secret", mode="after")
def parse_api_secret(cls, value: str) -> str: # noqa: PyMethodParameters
Expand Down
6 changes: 3 additions & 3 deletions vaultapi/rate_limit.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import collections
import math
import time
from collections import defaultdict
from http import HTTPStatus
from threading import Lock

Expand Down Expand Up @@ -36,8 +36,8 @@ def __init__(self, rps: models.RateLimit):
"""
self.max_requests = rps.max_requests
self.seconds = rps.seconds
self.locks = defaultdict(Lock) # For thread-safe access
self.requests = defaultdict(list)
self.locks = collections.defaultdict(Lock) # For thread-safe access
self.requests = collections.defaultdict(list)

def init(self, request: Request) -> None:
"""Checks if the number of calls exceeds the rate limit for the given identifier.
Expand Down

0 comments on commit b29e28a

Please sign in to comment.