From 182f9e11098db29c192d026f6f555eda6138b7c6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=C4=93teris=20Caune?= Date: Mon, 29 Oct 2018 18:34:58 +0200 Subject: [PATCH] Refactor API key checking code --- CHANGELOG.md | 1 + hc/api/decorators.py | 46 +++++++++++++++++-------------- hc/api/tests/test_create_check.py | 42 ++++++++++++++-------------- hc/api/tests/test_delete_check.py | 4 +-- hc/api/tests/test_list_checks.py | 12 ++++---- hc/api/tests/test_pause.py | 10 +++---- hc/api/tests/test_update_check.py | 16 +++++------ hc/api/views.py | 5 ++-- hc/test.py | 2 +- 9 files changed, 71 insertions(+), 67 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 71e3d35c..72fe5c73 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ All notable changes to this project will be documented in this file. - Add "List-Unsubscribe" header to alert and report emails - Don't send monthly reports to inactive accounts (no pings in 6 months) - Add search box in the "My Checks" page +- Refactor API key checking code ### Bug Fixes - During DST transition, handle ambiguous dates as pre-transition diff --git a/hc/api/decorators.py b/hc/api/decorators.py index c93aa212..577bf48e 100644 --- a/hc/api/decorators.py +++ b/hc/api/decorators.py @@ -2,56 +2,60 @@ import json from functools import wraps from django.contrib.auth.models import User -from django.http import HttpResponseForbidden, JsonResponse +from django.http import JsonResponse from hc.lib.jsonschema import ValidationError, validate -def make_error(msg): - return JsonResponse({"error": msg}, status=400) +def error(msg, status=400): + return JsonResponse({"error": msg}, status=status) def check_api_key(f): @wraps(f) def wrapper(request, *args, **kwds): - request.json = {} - if request.body: - try: - request.json = json.loads(request.body.decode()) - except ValueError: - return make_error("could not parse request body") - if "HTTP_X_API_KEY" in request.META: api_key = request.META["HTTP_X_API_KEY"] else: - api_key = request.json.get("api_key", "") + api_key = str(request.json.get("api_key", "")) - if api_key == "": - return make_error("wrong api_key") + if len(api_key) != 32: + return error("missing api key", 401) try: request.user = User.objects.get(profile__api_key=api_key) except User.DoesNotExist: - return HttpResponseForbidden() + return error("wrong api key", 401) return f(request, *args, **kwds) return wrapper -def validate_json(schema): - """ Validate request.json contents against `schema`. +def validate_json(schema=None): + """ Parse request json and validate it against `schema`. - Supports a tiny subset of JSON schema spec. + Put the parsed result in `request.json`. + If schema is None then only parse and don't validate. + Supports a limited subset of JSON schema spec. """ def decorator(f): @wraps(f) def wrapper(request, *args, **kwds): - try: - validate(request.json, schema) - except ValidationError as e: - return make_error("json validation error: %s" % e) + if request.body: + try: + request.json = json.loads(request.body.decode()) + except ValueError: + return error("could not parse request body") + else: + request.json = {} + + if schema: + try: + validate(request.json, schema) + except ValidationError as e: + return error("json validation error: %s" % e) return f(request, *args, **kwds) return wrapper diff --git a/hc/api/tests/test_create_check.py b/hc/api/tests/test_create_check.py index 0bfb6245..f3bd03b7 100644 --- a/hc/api/tests/test_create_check.py +++ b/hc/api/tests/test_create_check.py @@ -22,7 +22,7 @@ class CreateCheckTestCase(BaseTestCase): def test_it_works(self): r = self.post({ - "api_key": "abc", + "api_key": "X" * 32, "name": "Foo", "tags": "bar,baz", "timeout": 3600, @@ -49,7 +49,7 @@ class CreateCheckTestCase(BaseTestCase): def test_30_days_works(self): r = self.post({ - "api_key": "abc", + "api_key": "X" * 32, "name": "Foo", "timeout": 2592000, "grace": 2592000 @@ -65,7 +65,7 @@ class CreateCheckTestCase(BaseTestCase): payload = json.dumps({"name": "Foo"}) r = self.client.post(self.URL, payload, content_type="application/json", - HTTP_X_API_KEY="abc") + HTTP_X_API_KEY="X" * 32) self.assertEqual(r.status_code, 201) @@ -73,7 +73,7 @@ class CreateCheckTestCase(BaseTestCase): channel = Channel(user=self.alice) channel.save() - r = self.post({"api_key": "abc", "channels": "*"}) + r = self.post({"api_key": "X" * 32, "channels": "*"}) self.assertEqual(r.status_code, 201) check = Check.objects.get() @@ -84,7 +84,7 @@ class CreateCheckTestCase(BaseTestCase): existing.save() r = self.post({ - "api_key": "abc", + "api_key": "X" * 32, "name": "Foo", "unique": ["name"] }) @@ -97,8 +97,8 @@ class CreateCheckTestCase(BaseTestCase): def test_it_handles_missing_request_body(self): r = self.client.post(self.URL, content_type="application/json") - self.assertEqual(r.status_code, 400) - self.assertEqual(r.json()["error"], "wrong api_key") + self.assertEqual(r.status_code, 401) + self.assertEqual(r.json()["error"], "missing api key") def test_it_handles_invalid_json(self): r = self.client.post(self.URL, "this is not json", @@ -107,27 +107,27 @@ class CreateCheckTestCase(BaseTestCase): self.assertEqual(r.json()["error"], "could not parse request body") def test_it_rejects_wrong_api_key(self): - r = self.post({"api_key": "wrong"}) - self.assertEqual(r.status_code, 403) + r = self.post({"api_key": "Y" * 32}) + self.assertEqual(r.status_code, 401) def test_it_rejects_small_timeout(self): - self.post({"api_key": "abc", "timeout": 0}, + self.post({"api_key": "X" * 32, "timeout": 0}, expected_fragment="timeout is too small") def test_it_rejects_large_timeout(self): - self.post({"api_key": "abc", "timeout": 2592001}, + self.post({"api_key": "X" * 32, "timeout": 2592001}, expected_fragment="timeout is too large") def test_it_rejects_non_number_timeout(self): - self.post({"api_key": "abc", "timeout": "oops"}, + self.post({"api_key": "X" * 32, "timeout": "oops"}, expected_fragment="timeout is not a number") def test_it_rejects_non_string_name(self): - self.post({"api_key": "abc", "name": False}, + self.post({"api_key": "X" * 32, "name": False}, expected_fragment="name is not a string") def test_it_rejects_long_name(self): - self.post({"api_key": "abc", "name": "01234567890" * 20}, + self.post({"api_key": "X" * 32, "name": "01234567890" * 20}, expected_fragment="name is too long") def test_unique_accepts_only_whitelisted_values(self): @@ -135,21 +135,21 @@ class CreateCheckTestCase(BaseTestCase): existing.save() self.post({ - "api_key": "abc", + "api_key": "X" * 32, "name": "Foo", "unique": ["status"] }, expected_fragment="unexpected value") def test_it_rejects_bad_unique_values(self): self.post({ - "api_key": "abc", + "api_key": "X" * 32, "name": "Foo", "unique": "not a list" }, expected_fragment="not an array") def test_it_supports_cron_syntax(self): r = self.post({ - "api_key": "abc", + "api_key": "X" * 32, "schedule": "5 * * * *", "tz": "Europe/Riga", "grace": 60 @@ -166,7 +166,7 @@ class CreateCheckTestCase(BaseTestCase): def test_it_validates_cron_expression(self): r = self.post({ - "api_key": "abc", + "api_key": "X" * 32, "schedule": "not-a-cron-expression", "tz": "Europe/Riga", "grace": 60 @@ -176,7 +176,7 @@ class CreateCheckTestCase(BaseTestCase): def test_it_validates_timezone(self): r = self.post({ - "api_key": "abc", + "api_key": "X" * 32, "schedule": "* * * * *", "tz": "not-a-timezone", "grace": 60 @@ -185,7 +185,7 @@ class CreateCheckTestCase(BaseTestCase): self.assertEqual(r.status_code, 400) def test_it_sets_default_timeout(self): - r = self.post({"api_key": "abc"}) + r = self.post({"api_key": "X" * 32}) self.assertEqual(r.status_code, 201) @@ -196,5 +196,5 @@ class CreateCheckTestCase(BaseTestCase): self.profile.check_limit = 0 self.profile.save() - r = self.post({"api_key": "abc"}) + r = self.post({"api_key": "X" * 32}) self.assertEqual(r.status_code, 403) diff --git a/hc/api/tests/test_delete_check.py b/hc/api/tests/test_delete_check.py index ad6fe72f..0531ae3e 100644 --- a/hc/api/tests/test_delete_check.py +++ b/hc/api/tests/test_delete_check.py @@ -11,7 +11,7 @@ class DeleteCheckTestCase(BaseTestCase): def test_it_works(self): r = self.client.delete("/api/v1/checks/%s" % self.check.code, - HTTP_X_API_KEY="abc") + HTTP_X_API_KEY="X" * 32) self.assertEqual(r.status_code, 200) # It should be gone-- @@ -19,5 +19,5 @@ class DeleteCheckTestCase(BaseTestCase): def test_it_handles_missing_check(self): url = "/api/v1/checks/07c2f548-9850-4b27-af5d-6c9dc157ec02" - r = self.client.delete(url, HTTP_X_API_KEY="abc") + r = self.client.delete(url, HTTP_X_API_KEY="X" * 32) self.assertEqual(r.status_code, 404) diff --git a/hc/api/tests/test_list_checks.py b/hc/api/tests/test_list_checks.py index db060379..d1557608 100644 --- a/hc/api/tests/test_list_checks.py +++ b/hc/api/tests/test_list_checks.py @@ -32,7 +32,7 @@ class ListChecksTestCase(BaseTestCase): self.a2.save() def get(self): - return self.client.get("/api/v1/checks/", HTTP_X_API_KEY="abc") + return self.client.get("/api/v1/checks/", HTTP_X_API_KEY="X" * 32) def test_it_works(self): r = self.get() @@ -75,7 +75,7 @@ class ListChecksTestCase(BaseTestCase): self.assertNotEqual(check["name"], "Bob 1") def test_it_accepts_api_key_from_request_body(self): - payload = json.dumps({"api_key": "abc"}) + payload = json.dumps({"api_key": "X" * 32}) r = self.client.generic("GET", "/api/v1/checks/", payload, content_type="application/json") @@ -83,7 +83,7 @@ class ListChecksTestCase(BaseTestCase): self.assertContains(r, "Alice") def test_it_works_with_tags_param(self): - r = self.client.get("/api/v1/checks/?tag=a2-tag", HTTP_X_API_KEY="abc") + r = self.client.get("/api/v1/checks/?tag=a2-tag", HTTP_X_API_KEY="X" * 32) self.assertEqual(r.status_code, 200) doc = r.json() @@ -96,7 +96,7 @@ class ListChecksTestCase(BaseTestCase): self.assertEqual(check["tags"], "a2-tag") def test_it_filters_with_multiple_tags_param(self): - r = self.client.get("/api/v1/checks/?tag=a1-tag&tag=a1-additional-tag", HTTP_X_API_KEY="abc") + r = self.client.get("/api/v1/checks/?tag=a1-tag&tag=a1-additional-tag", HTTP_X_API_KEY="X" * 32) self.assertEqual(r.status_code, 200) doc = r.json() @@ -109,7 +109,7 @@ class ListChecksTestCase(BaseTestCase): self.assertEqual(check["tags"], "a1-tag a1-additional-tag") def test_it_does_not_match_tag_partially(self): - r = self.client.get("/api/v1/checks/?tag=tag", HTTP_X_API_KEY="abc") + r = self.client.get("/api/v1/checks/?tag=tag", HTTP_X_API_KEY="X" * 32) self.assertEqual(r.status_code, 200) doc = r.json() @@ -117,7 +117,7 @@ class ListChecksTestCase(BaseTestCase): self.assertEqual(len(doc["checks"]), 0) def test_non_existing_tags_filter_returns_empty_result(self): - r = self.client.get("/api/v1/checks/?tag=non_existing_tag_with_no_checks", HTTP_X_API_KEY="abc") + r = self.client.get("/api/v1/checks/?tag=non_existing_tag_with_no_checks", HTTP_X_API_KEY="X" * 32) self.assertEqual(r.status_code, 200) doc = r.json() diff --git a/hc/api/tests/test_pause.py b/hc/api/tests/test_pause.py index 4ca67faf..abe2a7b1 100644 --- a/hc/api/tests/test_pause.py +++ b/hc/api/tests/test_pause.py @@ -10,7 +10,7 @@ class PauseTestCase(BaseTestCase): url = "/api/v1/checks/%s/pause" % check.code r = self.client.post(url, "", content_type="application/json", - HTTP_X_API_KEY="abc") + HTTP_X_API_KEY="X" * 32) self.assertEqual(r.status_code, 200) @@ -20,7 +20,7 @@ class PauseTestCase(BaseTestCase): def test_it_only_allows_post(self): url = "/api/v1/checks/1659718b-21ad-4ed1-8740-43afc6c41524/pause" - r = self.client.get(url, HTTP_X_API_KEY="abc") + r = self.client.get(url, HTTP_X_API_KEY="X" * 32) self.assertEqual(r.status_code, 405) def test_it_validates_ownership(self): @@ -29,20 +29,20 @@ class PauseTestCase(BaseTestCase): url = "/api/v1/checks/%s/pause" % check.code r = self.client.post(url, "", content_type="application/json", - HTTP_X_API_KEY="abc") + HTTP_X_API_KEY="X" * 32) self.assertEqual(r.status_code, 403) def test_it_validates_uuid(self): url = "/api/v1/checks/not-uuid/pause" r = self.client.post(url, "", content_type="application/json", - HTTP_X_API_KEY="abc") + HTTP_X_API_KEY="X" * 32) self.assertEqual(r.status_code, 404) def test_it_handles_missing_check(self): url = "/api/v1/checks/07c2f548-9850-4b27-af5d-6c9dc157ec02/pause" r = self.client.post(url, "", content_type="application/json", - HTTP_X_API_KEY="abc") + HTTP_X_API_KEY="X" * 32) self.assertEqual(r.status_code, 404) diff --git a/hc/api/tests/test_update_check.py b/hc/api/tests/test_update_check.py index 67629bc1..0b5e85e9 100644 --- a/hc/api/tests/test_update_check.py +++ b/hc/api/tests/test_update_check.py @@ -1,5 +1,3 @@ -import json - from hc.api.models import Channel, Check from hc.test import BaseTestCase @@ -17,7 +15,7 @@ class UpdateCheckTestCase(BaseTestCase): def test_it_works(self): r = self.post(self.check.code, { - "api_key": "abc", + "api_key": "X" * 32, "name": "Foo", "tags": "bar,baz", "timeout": 3600, @@ -51,7 +49,7 @@ class UpdateCheckTestCase(BaseTestCase): self.check.assign_all_channels() r = self.post(self.check.code, { - "api_key": "abc", + "api_key": "X" * 32, "channels": "" }) @@ -61,23 +59,23 @@ class UpdateCheckTestCase(BaseTestCase): def test_it_requires_post(self): url = "/api/v1/checks/%s" % self.check.code - r = self.client.get(url, HTTP_X_API_KEY="abc") + r = self.client.get(url, HTTP_X_API_KEY="X" * 32) self.assertEqual(r.status_code, 405) def test_it_handles_invalid_uuid(self): - r = self.post("not-an-uuid", {"api_key": "abc"}) + r = self.post("not-an-uuid", {"api_key": "X" * 32}) self.assertEqual(r.status_code, 404) def test_it_handles_missing_check(self): made_up_code = "07c2f548-9850-4b27-af5d-6c9dc157ec02" - r = self.post(made_up_code, {"api_key": "abc"}) + r = self.post(made_up_code, {"api_key": "X" * 32}) self.assertEqual(r.status_code, 404) def test_it_validates_ownership(self): check = Check(user=self.bob, status="up") check.save() - r = self.post(check.code, {"api_key": "abc"}) + r = self.post(check.code, {"api_key": "X" * 32}) self.assertEqual(r.status_code, 403) def test_it_updates_cron_to_simple(self): @@ -85,7 +83,7 @@ class UpdateCheckTestCase(BaseTestCase): self.check.schedule = "5 * * * *" self.check.save() - r = self.post(self.check.code, {"api_key": "abc", "timeout": 3600}) + r = self.post(self.check.code, {"api_key": "X" * 32, "timeout": 3600}) self.assertEqual(r.status_code, 200) self.check.refresh_from_db() diff --git a/hc/api/views.py b/hc/api/views.py index 0c0b4a51..b2277975 100644 --- a/hc/api/views.py +++ b/hc/api/views.py @@ -88,8 +88,8 @@ def _update(check, spec): @csrf_exempt -@check_api_key @validate_json(schemas.check) +@check_api_key def checks(request): if request.method == "GET": q = Check.objects.filter(user=request.user) @@ -127,8 +127,8 @@ def checks(request): @csrf_exempt -@check_api_key @validate_json(schemas.check) +@check_api_key def update(request, code): check = get_object_or_404(Check, code=code) if check.user != request.user: @@ -149,6 +149,7 @@ def update(request, code): @csrf_exempt @require_POST +@validate_json() @check_api_key def pause(request, code): check = get_object_or_404(Check, code=code) diff --git a/hc/test.py b/hc/test.py index d92b567a..0bad5ee0 100644 --- a/hc/test.py +++ b/hc/test.py @@ -14,7 +14,7 @@ class BaseTestCase(TestCase): self.alice.set_password("password") self.alice.save() - self.profile = Profile(user=self.alice, api_key="abc") + self.profile = Profile(user=self.alice, api_key="X" * 32) self.profile.sms_limit = 50 self.profile.save()