Browse Source

Refactor API key checking code

pull/199/head
Pēteris Caune 6 years ago
parent
commit
182f9e1109
No known key found for this signature in database GPG Key ID: E28D7679E9A9EDE2
9 changed files with 71 additions and 67 deletions
  1. +1
    -0
      CHANGELOG.md
  2. +25
    -21
      hc/api/decorators.py
  3. +21
    -21
      hc/api/tests/test_create_check.py
  4. +2
    -2
      hc/api/tests/test_delete_check.py
  5. +6
    -6
      hc/api/tests/test_list_checks.py
  6. +5
    -5
      hc/api/tests/test_pause.py
  7. +7
    -9
      hc/api/tests/test_update_check.py
  8. +3
    -2
      hc/api/views.py
  9. +1
    -1
      hc/test.py

+ 1
- 0
CHANGELOG.md View File

@ -8,6 +8,7 @@ All notable changes to this project will be documented in this file.
- Add "List-Unsubscribe" header to alert and report emails - Add "List-Unsubscribe" header to alert and report emails
- Don't send monthly reports to inactive accounts (no pings in 6 months) - Don't send monthly reports to inactive accounts (no pings in 6 months)
- Add search box in the "My Checks" page - Add search box in the "My Checks" page
- Refactor API key checking code
### Bug Fixes ### Bug Fixes
- During DST transition, handle ambiguous dates as pre-transition - During DST transition, handle ambiguous dates as pre-transition


+ 25
- 21
hc/api/decorators.py View File

@ -2,56 +2,60 @@ import json
from functools import wraps from functools import wraps
from django.contrib.auth.models import User from django.contrib.auth.models import User
from django.http import HttpResponseForbidden, JsonResponse
from django.http import JsonResponse
from hc.lib.jsonschema import ValidationError, validate from hc.lib.jsonschema import ValidationError, validate
def make_error(msg):
return JsonResponse({"error": msg}, status=400)
def error(msg, status=400):
return JsonResponse({"error": msg}, status=status)
def check_api_key(f): def check_api_key(f):
@wraps(f) @wraps(f)
def wrapper(request, *args, **kwds): def wrapper(request, *args, **kwds):
request.json = {}
if request.body:
try:
request.json = json.loads(request.body.decode())
except ValueError:
return make_error("could not parse request body")
if "HTTP_X_API_KEY" in request.META: if "HTTP_X_API_KEY" in request.META:
api_key = request.META["HTTP_X_API_KEY"] api_key = request.META["HTTP_X_API_KEY"]
else: else:
api_key = request.json.get("api_key", "")
api_key = str(request.json.get("api_key", ""))
if api_key == "":
return make_error("wrong api_key")
if len(api_key) != 32:
return error("missing api key", 401)
try: try:
request.user = User.objects.get(profile__api_key=api_key) request.user = User.objects.get(profile__api_key=api_key)
except User.DoesNotExist: except User.DoesNotExist:
return HttpResponseForbidden()
return error("wrong api key", 401)
return f(request, *args, **kwds) return f(request, *args, **kwds)
return wrapper return wrapper
def validate_json(schema):
""" Validate request.json contents against `schema`.
def validate_json(schema=None):
""" Parse request json and validate it against `schema`.
Supports a tiny subset of JSON schema spec.
Put the parsed result in `request.json`.
If schema is None then only parse and don't validate.
Supports a limited subset of JSON schema spec.
""" """
def decorator(f): def decorator(f):
@wraps(f) @wraps(f)
def wrapper(request, *args, **kwds): def wrapper(request, *args, **kwds):
try:
validate(request.json, schema)
except ValidationError as e:
return make_error("json validation error: %s" % e)
if request.body:
try:
request.json = json.loads(request.body.decode())
except ValueError:
return error("could not parse request body")
else:
request.json = {}
if schema:
try:
validate(request.json, schema)
except ValidationError as e:
return error("json validation error: %s" % e)
return f(request, *args, **kwds) return f(request, *args, **kwds)
return wrapper return wrapper


+ 21
- 21
hc/api/tests/test_create_check.py View File

@ -22,7 +22,7 @@ class CreateCheckTestCase(BaseTestCase):
def test_it_works(self): def test_it_works(self):
r = self.post({ r = self.post({
"api_key": "abc",
"api_key": "X" * 32,
"name": "Foo", "name": "Foo",
"tags": "bar,baz", "tags": "bar,baz",
"timeout": 3600, "timeout": 3600,
@ -49,7 +49,7 @@ class CreateCheckTestCase(BaseTestCase):
def test_30_days_works(self): def test_30_days_works(self):
r = self.post({ r = self.post({
"api_key": "abc",
"api_key": "X" * 32,
"name": "Foo", "name": "Foo",
"timeout": 2592000, "timeout": 2592000,
"grace": 2592000 "grace": 2592000
@ -65,7 +65,7 @@ class CreateCheckTestCase(BaseTestCase):
payload = json.dumps({"name": "Foo"}) payload = json.dumps({"name": "Foo"})
r = self.client.post(self.URL, payload, r = self.client.post(self.URL, payload,
content_type="application/json", content_type="application/json",
HTTP_X_API_KEY="abc")
HTTP_X_API_KEY="X" * 32)
self.assertEqual(r.status_code, 201) self.assertEqual(r.status_code, 201)
@ -73,7 +73,7 @@ class CreateCheckTestCase(BaseTestCase):
channel = Channel(user=self.alice) channel = Channel(user=self.alice)
channel.save() channel.save()
r = self.post({"api_key": "abc", "channels": "*"})
r = self.post({"api_key": "X" * 32, "channels": "*"})
self.assertEqual(r.status_code, 201) self.assertEqual(r.status_code, 201)
check = Check.objects.get() check = Check.objects.get()
@ -84,7 +84,7 @@ class CreateCheckTestCase(BaseTestCase):
existing.save() existing.save()
r = self.post({ r = self.post({
"api_key": "abc",
"api_key": "X" * 32,
"name": "Foo", "name": "Foo",
"unique": ["name"] "unique": ["name"]
}) })
@ -97,8 +97,8 @@ class CreateCheckTestCase(BaseTestCase):
def test_it_handles_missing_request_body(self): def test_it_handles_missing_request_body(self):
r = self.client.post(self.URL, content_type="application/json") r = self.client.post(self.URL, content_type="application/json")
self.assertEqual(r.status_code, 400)
self.assertEqual(r.json()["error"], "wrong api_key")
self.assertEqual(r.status_code, 401)
self.assertEqual(r.json()["error"], "missing api key")
def test_it_handles_invalid_json(self): def test_it_handles_invalid_json(self):
r = self.client.post(self.URL, "this is not json", r = self.client.post(self.URL, "this is not json",
@ -107,27 +107,27 @@ class CreateCheckTestCase(BaseTestCase):
self.assertEqual(r.json()["error"], "could not parse request body") self.assertEqual(r.json()["error"], "could not parse request body")
def test_it_rejects_wrong_api_key(self): def test_it_rejects_wrong_api_key(self):
r = self.post({"api_key": "wrong"})
self.assertEqual(r.status_code, 403)
r = self.post({"api_key": "Y" * 32})
self.assertEqual(r.status_code, 401)
def test_it_rejects_small_timeout(self): def test_it_rejects_small_timeout(self):
self.post({"api_key": "abc", "timeout": 0},
self.post({"api_key": "X" * 32, "timeout": 0},
expected_fragment="timeout is too small") expected_fragment="timeout is too small")
def test_it_rejects_large_timeout(self): def test_it_rejects_large_timeout(self):
self.post({"api_key": "abc", "timeout": 2592001},
self.post({"api_key": "X" * 32, "timeout": 2592001},
expected_fragment="timeout is too large") expected_fragment="timeout is too large")
def test_it_rejects_non_number_timeout(self): def test_it_rejects_non_number_timeout(self):
self.post({"api_key": "abc", "timeout": "oops"},
self.post({"api_key": "X" * 32, "timeout": "oops"},
expected_fragment="timeout is not a number") expected_fragment="timeout is not a number")
def test_it_rejects_non_string_name(self): def test_it_rejects_non_string_name(self):
self.post({"api_key": "abc", "name": False},
self.post({"api_key": "X" * 32, "name": False},
expected_fragment="name is not a string") expected_fragment="name is not a string")
def test_it_rejects_long_name(self): def test_it_rejects_long_name(self):
self.post({"api_key": "abc", "name": "01234567890" * 20},
self.post({"api_key": "X" * 32, "name": "01234567890" * 20},
expected_fragment="name is too long") expected_fragment="name is too long")
def test_unique_accepts_only_whitelisted_values(self): def test_unique_accepts_only_whitelisted_values(self):
@ -135,21 +135,21 @@ class CreateCheckTestCase(BaseTestCase):
existing.save() existing.save()
self.post({ self.post({
"api_key": "abc",
"api_key": "X" * 32,
"name": "Foo", "name": "Foo",
"unique": ["status"] "unique": ["status"]
}, expected_fragment="unexpected value") }, expected_fragment="unexpected value")
def test_it_rejects_bad_unique_values(self): def test_it_rejects_bad_unique_values(self):
self.post({ self.post({
"api_key": "abc",
"api_key": "X" * 32,
"name": "Foo", "name": "Foo",
"unique": "not a list" "unique": "not a list"
}, expected_fragment="not an array") }, expected_fragment="not an array")
def test_it_supports_cron_syntax(self): def test_it_supports_cron_syntax(self):
r = self.post({ r = self.post({
"api_key": "abc",
"api_key": "X" * 32,
"schedule": "5 * * * *", "schedule": "5 * * * *",
"tz": "Europe/Riga", "tz": "Europe/Riga",
"grace": 60 "grace": 60
@ -166,7 +166,7 @@ class CreateCheckTestCase(BaseTestCase):
def test_it_validates_cron_expression(self): def test_it_validates_cron_expression(self):
r = self.post({ r = self.post({
"api_key": "abc",
"api_key": "X" * 32,
"schedule": "not-a-cron-expression", "schedule": "not-a-cron-expression",
"tz": "Europe/Riga", "tz": "Europe/Riga",
"grace": 60 "grace": 60
@ -176,7 +176,7 @@ class CreateCheckTestCase(BaseTestCase):
def test_it_validates_timezone(self): def test_it_validates_timezone(self):
r = self.post({ r = self.post({
"api_key": "abc",
"api_key": "X" * 32,
"schedule": "* * * * *", "schedule": "* * * * *",
"tz": "not-a-timezone", "tz": "not-a-timezone",
"grace": 60 "grace": 60
@ -185,7 +185,7 @@ class CreateCheckTestCase(BaseTestCase):
self.assertEqual(r.status_code, 400) self.assertEqual(r.status_code, 400)
def test_it_sets_default_timeout(self): def test_it_sets_default_timeout(self):
r = self.post({"api_key": "abc"})
r = self.post({"api_key": "X" * 32})
self.assertEqual(r.status_code, 201) self.assertEqual(r.status_code, 201)
@ -196,5 +196,5 @@ class CreateCheckTestCase(BaseTestCase):
self.profile.check_limit = 0 self.profile.check_limit = 0
self.profile.save() self.profile.save()
r = self.post({"api_key": "abc"})
r = self.post({"api_key": "X" * 32})
self.assertEqual(r.status_code, 403) self.assertEqual(r.status_code, 403)

+ 2
- 2
hc/api/tests/test_delete_check.py View File

@ -11,7 +11,7 @@ class DeleteCheckTestCase(BaseTestCase):
def test_it_works(self): def test_it_works(self):
r = self.client.delete("/api/v1/checks/%s" % self.check.code, r = self.client.delete("/api/v1/checks/%s" % self.check.code,
HTTP_X_API_KEY="abc")
HTTP_X_API_KEY="X" * 32)
self.assertEqual(r.status_code, 200) self.assertEqual(r.status_code, 200)
# It should be gone-- # It should be gone--
@ -19,5 +19,5 @@ class DeleteCheckTestCase(BaseTestCase):
def test_it_handles_missing_check(self): def test_it_handles_missing_check(self):
url = "/api/v1/checks/07c2f548-9850-4b27-af5d-6c9dc157ec02" url = "/api/v1/checks/07c2f548-9850-4b27-af5d-6c9dc157ec02"
r = self.client.delete(url, HTTP_X_API_KEY="abc")
r = self.client.delete(url, HTTP_X_API_KEY="X" * 32)
self.assertEqual(r.status_code, 404) self.assertEqual(r.status_code, 404)

+ 6
- 6
hc/api/tests/test_list_checks.py View File

@ -32,7 +32,7 @@ class ListChecksTestCase(BaseTestCase):
self.a2.save() self.a2.save()
def get(self): def get(self):
return self.client.get("/api/v1/checks/", HTTP_X_API_KEY="abc")
return self.client.get("/api/v1/checks/", HTTP_X_API_KEY="X" * 32)
def test_it_works(self): def test_it_works(self):
r = self.get() r = self.get()
@ -75,7 +75,7 @@ class ListChecksTestCase(BaseTestCase):
self.assertNotEqual(check["name"], "Bob 1") self.assertNotEqual(check["name"], "Bob 1")
def test_it_accepts_api_key_from_request_body(self): def test_it_accepts_api_key_from_request_body(self):
payload = json.dumps({"api_key": "abc"})
payload = json.dumps({"api_key": "X" * 32})
r = self.client.generic("GET", "/api/v1/checks/", payload, r = self.client.generic("GET", "/api/v1/checks/", payload,
content_type="application/json") content_type="application/json")
@ -83,7 +83,7 @@ class ListChecksTestCase(BaseTestCase):
self.assertContains(r, "Alice") self.assertContains(r, "Alice")
def test_it_works_with_tags_param(self): def test_it_works_with_tags_param(self):
r = self.client.get("/api/v1/checks/?tag=a2-tag", HTTP_X_API_KEY="abc")
r = self.client.get("/api/v1/checks/?tag=a2-tag", HTTP_X_API_KEY="X" * 32)
self.assertEqual(r.status_code, 200) self.assertEqual(r.status_code, 200)
doc = r.json() doc = r.json()
@ -96,7 +96,7 @@ class ListChecksTestCase(BaseTestCase):
self.assertEqual(check["tags"], "a2-tag") self.assertEqual(check["tags"], "a2-tag")
def test_it_filters_with_multiple_tags_param(self): def test_it_filters_with_multiple_tags_param(self):
r = self.client.get("/api/v1/checks/?tag=a1-tag&tag=a1-additional-tag", HTTP_X_API_KEY="abc")
r = self.client.get("/api/v1/checks/?tag=a1-tag&tag=a1-additional-tag", HTTP_X_API_KEY="X" * 32)
self.assertEqual(r.status_code, 200) self.assertEqual(r.status_code, 200)
doc = r.json() doc = r.json()
@ -109,7 +109,7 @@ class ListChecksTestCase(BaseTestCase):
self.assertEqual(check["tags"], "a1-tag a1-additional-tag") self.assertEqual(check["tags"], "a1-tag a1-additional-tag")
def test_it_does_not_match_tag_partially(self): def test_it_does_not_match_tag_partially(self):
r = self.client.get("/api/v1/checks/?tag=tag", HTTP_X_API_KEY="abc")
r = self.client.get("/api/v1/checks/?tag=tag", HTTP_X_API_KEY="X" * 32)
self.assertEqual(r.status_code, 200) self.assertEqual(r.status_code, 200)
doc = r.json() doc = r.json()
@ -117,7 +117,7 @@ class ListChecksTestCase(BaseTestCase):
self.assertEqual(len(doc["checks"]), 0) self.assertEqual(len(doc["checks"]), 0)
def test_non_existing_tags_filter_returns_empty_result(self): def test_non_existing_tags_filter_returns_empty_result(self):
r = self.client.get("/api/v1/checks/?tag=non_existing_tag_with_no_checks", HTTP_X_API_KEY="abc")
r = self.client.get("/api/v1/checks/?tag=non_existing_tag_with_no_checks", HTTP_X_API_KEY="X" * 32)
self.assertEqual(r.status_code, 200) self.assertEqual(r.status_code, 200)
doc = r.json() doc = r.json()


+ 5
- 5
hc/api/tests/test_pause.py View File

@ -10,7 +10,7 @@ class PauseTestCase(BaseTestCase):
url = "/api/v1/checks/%s/pause" % check.code url = "/api/v1/checks/%s/pause" % check.code
r = self.client.post(url, "", content_type="application/json", r = self.client.post(url, "", content_type="application/json",
HTTP_X_API_KEY="abc")
HTTP_X_API_KEY="X" * 32)
self.assertEqual(r.status_code, 200) self.assertEqual(r.status_code, 200)
@ -20,7 +20,7 @@ class PauseTestCase(BaseTestCase):
def test_it_only_allows_post(self): def test_it_only_allows_post(self):
url = "/api/v1/checks/1659718b-21ad-4ed1-8740-43afc6c41524/pause" url = "/api/v1/checks/1659718b-21ad-4ed1-8740-43afc6c41524/pause"
r = self.client.get(url, HTTP_X_API_KEY="abc")
r = self.client.get(url, HTTP_X_API_KEY="X" * 32)
self.assertEqual(r.status_code, 405) self.assertEqual(r.status_code, 405)
def test_it_validates_ownership(self): def test_it_validates_ownership(self):
@ -29,20 +29,20 @@ class PauseTestCase(BaseTestCase):
url = "/api/v1/checks/%s/pause" % check.code url = "/api/v1/checks/%s/pause" % check.code
r = self.client.post(url, "", content_type="application/json", r = self.client.post(url, "", content_type="application/json",
HTTP_X_API_KEY="abc")
HTTP_X_API_KEY="X" * 32)
self.assertEqual(r.status_code, 403) self.assertEqual(r.status_code, 403)
def test_it_validates_uuid(self): def test_it_validates_uuid(self):
url = "/api/v1/checks/not-uuid/pause" url = "/api/v1/checks/not-uuid/pause"
r = self.client.post(url, "", content_type="application/json", r = self.client.post(url, "", content_type="application/json",
HTTP_X_API_KEY="abc")
HTTP_X_API_KEY="X" * 32)
self.assertEqual(r.status_code, 404) self.assertEqual(r.status_code, 404)
def test_it_handles_missing_check(self): def test_it_handles_missing_check(self):
url = "/api/v1/checks/07c2f548-9850-4b27-af5d-6c9dc157ec02/pause" url = "/api/v1/checks/07c2f548-9850-4b27-af5d-6c9dc157ec02/pause"
r = self.client.post(url, "", content_type="application/json", r = self.client.post(url, "", content_type="application/json",
HTTP_X_API_KEY="abc")
HTTP_X_API_KEY="X" * 32)
self.assertEqual(r.status_code, 404) self.assertEqual(r.status_code, 404)

+ 7
- 9
hc/api/tests/test_update_check.py View File

@ -1,5 +1,3 @@
import json
from hc.api.models import Channel, Check from hc.api.models import Channel, Check
from hc.test import BaseTestCase from hc.test import BaseTestCase
@ -17,7 +15,7 @@ class UpdateCheckTestCase(BaseTestCase):
def test_it_works(self): def test_it_works(self):
r = self.post(self.check.code, { r = self.post(self.check.code, {
"api_key": "abc",
"api_key": "X" * 32,
"name": "Foo", "name": "Foo",
"tags": "bar,baz", "tags": "bar,baz",
"timeout": 3600, "timeout": 3600,
@ -51,7 +49,7 @@ class UpdateCheckTestCase(BaseTestCase):
self.check.assign_all_channels() self.check.assign_all_channels()
r = self.post(self.check.code, { r = self.post(self.check.code, {
"api_key": "abc",
"api_key": "X" * 32,
"channels": "" "channels": ""
}) })
@ -61,23 +59,23 @@ class UpdateCheckTestCase(BaseTestCase):
def test_it_requires_post(self): def test_it_requires_post(self):
url = "/api/v1/checks/%s" % self.check.code url = "/api/v1/checks/%s" % self.check.code
r = self.client.get(url, HTTP_X_API_KEY="abc")
r = self.client.get(url, HTTP_X_API_KEY="X" * 32)
self.assertEqual(r.status_code, 405) self.assertEqual(r.status_code, 405)
def test_it_handles_invalid_uuid(self): def test_it_handles_invalid_uuid(self):
r = self.post("not-an-uuid", {"api_key": "abc"})
r = self.post("not-an-uuid", {"api_key": "X" * 32})
self.assertEqual(r.status_code, 404) self.assertEqual(r.status_code, 404)
def test_it_handles_missing_check(self): def test_it_handles_missing_check(self):
made_up_code = "07c2f548-9850-4b27-af5d-6c9dc157ec02" made_up_code = "07c2f548-9850-4b27-af5d-6c9dc157ec02"
r = self.post(made_up_code, {"api_key": "abc"})
r = self.post(made_up_code, {"api_key": "X" * 32})
self.assertEqual(r.status_code, 404) self.assertEqual(r.status_code, 404)
def test_it_validates_ownership(self): def test_it_validates_ownership(self):
check = Check(user=self.bob, status="up") check = Check(user=self.bob, status="up")
check.save() check.save()
r = self.post(check.code, {"api_key": "abc"})
r = self.post(check.code, {"api_key": "X" * 32})
self.assertEqual(r.status_code, 403) self.assertEqual(r.status_code, 403)
def test_it_updates_cron_to_simple(self): def test_it_updates_cron_to_simple(self):
@ -85,7 +83,7 @@ class UpdateCheckTestCase(BaseTestCase):
self.check.schedule = "5 * * * *" self.check.schedule = "5 * * * *"
self.check.save() self.check.save()
r = self.post(self.check.code, {"api_key": "abc", "timeout": 3600})
r = self.post(self.check.code, {"api_key": "X" * 32, "timeout": 3600})
self.assertEqual(r.status_code, 200) self.assertEqual(r.status_code, 200)
self.check.refresh_from_db() self.check.refresh_from_db()


+ 3
- 2
hc/api/views.py View File

@ -88,8 +88,8 @@ def _update(check, spec):
@csrf_exempt @csrf_exempt
@check_api_key
@validate_json(schemas.check) @validate_json(schemas.check)
@check_api_key
def checks(request): def checks(request):
if request.method == "GET": if request.method == "GET":
q = Check.objects.filter(user=request.user) q = Check.objects.filter(user=request.user)
@ -127,8 +127,8 @@ def checks(request):
@csrf_exempt @csrf_exempt
@check_api_key
@validate_json(schemas.check) @validate_json(schemas.check)
@check_api_key
def update(request, code): def update(request, code):
check = get_object_or_404(Check, code=code) check = get_object_or_404(Check, code=code)
if check.user != request.user: if check.user != request.user:
@ -149,6 +149,7 @@ def update(request, code):
@csrf_exempt @csrf_exempt
@require_POST @require_POST
@validate_json()
@check_api_key @check_api_key
def pause(request, code): def pause(request, code):
check = get_object_or_404(Check, code=code) check = get_object_or_404(Check, code=code)


+ 1
- 1
hc/test.py View File

@ -14,7 +14,7 @@ class BaseTestCase(TestCase):
self.alice.set_password("password") self.alice.set_password("password")
self.alice.save() self.alice.save()
self.profile = Profile(user=self.alice, api_key="abc")
self.profile = Profile(user=self.alice, api_key="X" * 32)
self.profile.sms_limit = 50 self.profile.sms_limit = 50
self.profile.save() self.profile.save()


Loading…
Cancel
Save