diff --git a/requirements.txt b/requirements.txt index 4a17116..a94916c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -27,3 +27,5 @@ coverage==6.3.2 # Utilities httpie==3.2.1 +flask-talisman +Flask-Cors diff --git a/service/__init__.py b/service/__init__.py index a62a9b3..350a3a3 100644 --- a/service/__init__.py +++ b/service/__init__.py @@ -6,6 +6,8 @@ """ import sys from flask import Flask +from flask_talisman import Talisman +from flask_cors import CORS from service import config from service.common import log_handlers @@ -13,6 +15,9 @@ app = Flask(__name__) app.config.from_object(config) +talisman = Talisman(app) +CORS(app) + # Import the routes After the Flask app is created # pylint: disable=wrong-import-position, cyclic-import, wrong-import-order from service import routes, models # noqa: F401 E402 diff --git a/tests/test_routes.py b/tests/test_routes.py index f1412e8..36eeb72 100644 --- a/tests/test_routes.py +++ b/tests/test_routes.py @@ -12,12 +12,14 @@ from service.common import status # HTTP Status Codes from service.models import db, Account, init_db from service.routes import app +from service import talisman DATABASE_URI = os.getenv( "DATABASE_URI", "postgresql://postgres:postgres@localhost:5432/postgres" ) BASE_URL = "/accounts" +HTTPS_ENVIRON = {'wsgi.url_scheme': 'https'} ###################################################################### @@ -34,6 +36,7 @@ def setUpClass(cls): app.config["SQLALCHEMY_DATABASE_URI"] = DATABASE_URI app.logger.setLevel(logging.CRITICAL) init_db(app) + talisman.force_https = False @classmethod def tearDownClass(cls): @@ -171,3 +174,23 @@ def test_get_account_list(self): self.assertEqual(resp.status_code, status.HTTP_200_OK) data = resp.get_json() self.assertEqual(len(data), 5) + + def test_security_headers(self): + """It should return security headers""" + response = self.client.get('/', environ_overrides=HTTPS_ENVIRON) + self.assertEqual(response.status_code, status.HTTP_200_OK) + headers = { + 'X-Frame-Options': 'SAMEORIGIN', + 'X-Content-Type-Options': 'nosniff', + 'Content-Security-Policy': 'default-src \'self\'; object-src \'none\'', + 'Referrer-Policy': 'strict-origin-when-cross-origin' + } + for key, value in headers.items(): + self.assertEqual(response.headers.get(key), value) + + def test_cors_security(self): + """It should return a CORS header""" + response = self.client.get('/', environ_overrides=HTTPS_ENVIRON) + self.assertEqual(response.status_code, status.HTTP_200_OK) + # Check for the CORS header + self.assertEqual(response.headers.get('Access-Control-Allow-Origin'), '*')