diff --git a/hc/api/schemas.py b/hc/api/schemas.py index 3db4282d..54787071 100644 --- a/hc/api/schemas.py +++ b/hc/api/schemas.py @@ -1,9 +1,14 @@ check = { + "type": "object", "properties": { "name": {"type": "string"}, "tags": {"type": "string"}, "timeout": {"type": "number", "minimum": 60, "maximum": 604800}, "grace": {"type": "number", "minimum": 60, "maximum": 604800}, - "channels": {"type": "string"} + "channels": {"type": "string"}, + "unique": { + "type": "array", + "items": {"enum": ["name", "tags", "timeout", "grace"]} + } } } diff --git a/hc/api/tests/test_create_check.py b/hc/api/tests/test_create_check.py index fcc815a7..e787ee12 100644 --- a/hc/api/tests/test_create_check.py +++ b/hc/api/tests/test_create_check.py @@ -7,10 +7,7 @@ from hc.test import BaseTestCase class CreateCheckTestCase(BaseTestCase): URL = "/api/v1/checks/" - def setUp(self): - super(CreateCheckTestCase, self).setUp() - - def post(self, data, expected_error=None): + def post(self, data, expected_error=None, expected_fragment=None): r = self.client.post(self.URL, json.dumps(data), content_type="application/json") @@ -18,6 +15,10 @@ class CreateCheckTestCase(BaseTestCase): self.assertEqual(r.status_code, 400) self.assertEqual(r.json()["error"], expected_error) + if expected_fragment: + self.assertEqual(r.status_code, 400) + self.assertIn(expected_fragment, r.json()["error"]) + return r def test_it_works(self): @@ -63,6 +64,22 @@ class CreateCheckTestCase(BaseTestCase): check = Check.objects.get() self.assertEqual(check.channel_set.get(), channel) + def test_it_supports_unique(self): + existing = Check(user=self.alice, name="Foo") + existing.save() + + r = self.post({ + "api_key": "abc", + "name": "Foo", + "unique": ["name"] + }) + + # Expect 200 instead of 201 + self.assertEqual(r.status_code, 200) + + # And there should be only one check in the database: + self.assertEqual(Check.objects.count(), 1) + def test_it_handles_missing_request_body(self): r = self.client.post(self.URL, content_type="application/json") self.assertEqual(r.status_code, 400) @@ -80,16 +97,33 @@ class CreateCheckTestCase(BaseTestCase): def test_it_rejects_small_timeout(self): self.post({"api_key": "abc", "timeout": 0}, - expected_error="timeout is too small") + expected_fragment="timeout is too small") def test_it_rejects_large_timeout(self): self.post({"api_key": "abc", "timeout": 604801}, - expected_error="timeout is too large") + expected_fragment="timeout is too large") def test_it_rejects_non_number_timeout(self): self.post({"api_key": "abc", "timeout": "oops"}, - expected_error="timeout is not a number") + expected_fragment="timeout is not a number") def test_it_rejects_non_string_name(self): self.post({"api_key": "abc", "name": False}, - expected_error="name is not a string") + expected_fragment="name is not a string") + + def test_unique_accepts_only_whitelisted_values(self): + existing = Check(user=self.alice, name="Foo") + existing.save() + + self.post({ + "api_key": "abc", + "name": "Foo", + "unique": ["status"] + }, expected_fragment="unexpected value") + + def test_it_rejects_bad_unique_values(self): + self.post({ + "api_key": "abc", + "name": "Foo", + "unique": "not a list" + }, expected_fragment="not an array") diff --git a/hc/api/views.py b/hc/api/views.py index c81ea2d1..5b7ca270 100644 --- a/hc/api/views.py +++ b/hc/api/views.py @@ -9,7 +9,7 @@ from django.views.decorators.csrf import csrf_exempt from hc.api import schemas from hc.api.decorators import check_api_key, uuid_or_400, validate_json -from hc.api.models import Check, Ping +from hc.api.models import Check, Ping, DEFAULT_TIMEOUT, DEFAULT_GRACE from hc.lib.badges import check_signature, get_badge_svg @@ -56,37 +56,36 @@ def checks(request): return JsonResponse(doc) elif request.method == "POST": - - unique_fields = request.json.get("unique", []) name = str(request.json.get("name", "")) + tags = str(request.json.get("tags", "")) - if len(unique_fields) > 0: - existing_checks = Check.objects.filter(user=request.user) - - for unique_field in unique_fields: - - field_value = request.json.get(unique_field) + timeout = DEFAULT_TIMEOUT + if "timeout" in request.json: + timeout = td(seconds=request.json["timeout"]) - if unique_field == "timeout" or unique_field == "grace": - field_value = td(seconds=field_value) + grace = DEFAULT_GRACE + if "grace" in request.json: + grace = td(seconds=request.json["grace"]) - try: - existing_checks = existing_checks.filter(**{unique_field: field_value}) - except FieldError: - return HttpResponse(status=400) + unique_fields = request.json.get("unique", []) + if unique_fields: + existing_checks = Check.objects.filter(user=request.user) + if "name" in unique_fields: + existing_checks = existing_checks.filter(name=name) + if "tags" in unique_fields: + existing_checks = existing_checks.filter(tags=tags) + if "timeout" in unique_fields: + existing_checks = existing_checks.filter(timeout=timeout) + if "grace" in unique_fields: + existing_checks = existing_checks.filter(grace=grace) if existing_checks.count() > 0: - # There might be more than one check with the same name since name - # uniqueness isn't enforced in the model - return JsonResponse(existing_checks.first().to_dict(), status=200) + # There might be more than one matching check, return first + first_match = existing_checks.first() + return JsonResponse(first_match.to_dict(), status=200) - check = Check(user=request.user) - check.name = name - check.tags = str(request.json.get("tags", "")) - if "timeout" in request.json: - check.timeout = td(seconds=request.json["timeout"]) - if "grace" in request.json: - check.grace = td(seconds=request.json["grace"]) + check = Check(user=request.user, name=name, tags=tags, + timeout=timeout, grace=grace) check.save()