diff --git a/CHANGELOG.md b/CHANGELOG.md index ea0e6587..801be26f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ All notable changes to this project will be documented in this file. - Set Pushover alert priorities for "down" and "up" events separately - Additional python usage examples - Allow simultaneous access to checks from different teams +- Add CORS support to API endpoints ### Bug Fixes - Fix after-login redirects (the "?next=" query parameter) diff --git a/hc/api/decorators.py b/hc/api/decorators.py index 71dd4130..bb412582 100644 --- a/hc/api/decorators.py +++ b/hc/api/decorators.py @@ -3,7 +3,7 @@ from functools import wraps from django.contrib.auth.models import User from django.db.models import Q -from django.http import JsonResponse +from django.http import HttpResponse, JsonResponse from hc.lib.jsonschema import ValidationError, validate @@ -82,3 +82,28 @@ def validate_json(schema=None): return f(request, *args, **kwds) return wrapper return decorator + + +def cors(*methods): + methods = set(methods) + methods.add("OPTIONS") + methods_str = ", ".join(methods) + + def decorator(f): + @wraps(f) + def wrapper(request, *args, **kwds): + if request.method == "OPTIONS": + # Handle OPTIONS here + response = HttpResponse(status=204) + elif request.method in methods: + response = f(request, *args, **kwds) + else: + response = HttpResponse(status=405) + + response["Access-Control-Allow-Origin"] = "*" + response["Access-Control-Allow-Headers"] = "X-Api-Key" + response["Access-Control-Allow-Methods"] = methods_str + return response + + return wrapper + return decorator diff --git a/hc/api/tests/test_badge.py b/hc/api/tests/test_badge.py index 955478fa..394d9bcb 100644 --- a/hc/api/tests/test_badge.py +++ b/hc/api/tests/test_badge.py @@ -21,4 +21,14 @@ class BadgeTestCase(BaseTestCase): url = "/badge/%s/%s/foo.svg" % (self.alice.username, sig) r = self.client.get(url) + self.assertEqual(r["Access-Control-Allow-Origin"], "*") self.assertContains(r, "#4c1") + + def test_it_handles_options(self): + sig = base64_hmac(str(self.alice.username), "foo", settings.SECRET_KEY) + sig = sig[:8] + url = "/badge/%s/%s/foo.svg" % (self.alice.username, sig) + + r = self.client.options(url) + self.assertEqual(r.status_code, 204) + self.assertEqual(r["Access-Control-Allow-Origin"], "*") diff --git a/hc/api/tests/test_create_check.py b/hc/api/tests/test_create_check.py index 00bf10e7..79376b20 100644 --- a/hc/api/tests/test_create_check.py +++ b/hc/api/tests/test_create_check.py @@ -30,6 +30,7 @@ class CreateCheckTestCase(BaseTestCase): }) self.assertEqual(r.status_code, 201) + self.assertEqual(r["Access-Control-Allow-Origin"], "*") doc = r.json() assert "ping_url" in doc @@ -47,6 +48,11 @@ class CreateCheckTestCase(BaseTestCase): self.assertEqual(check.timeout.total_seconds(), 3600) self.assertEqual(check.grace.total_seconds(), 60) + def test_it_handles_options(self): + r = self.client.options(self.URL) + self.assertEqual(r.status_code, 204) + self.assertIn("POST", r["Access-Control-Allow-Methods"]) + def test_30_days_works(self): r = self.post({ "api_key": "X" * 32, diff --git a/hc/api/tests/test_delete_check.py b/hc/api/tests/test_delete_check.py index 0531ae3e..d6680ac4 100644 --- a/hc/api/tests/test_delete_check.py +++ b/hc/api/tests/test_delete_check.py @@ -13,6 +13,7 @@ class DeleteCheckTestCase(BaseTestCase): r = self.client.delete("/api/v1/checks/%s" % self.check.code, HTTP_X_API_KEY="X" * 32) self.assertEqual(r.status_code, 200) + self.assertEqual(r["Access-Control-Allow-Origin"], "*") # It should be gone-- self.assertFalse(Check.objects.filter(code=self.check.code).exists()) @@ -21,3 +22,8 @@ class DeleteCheckTestCase(BaseTestCase): url = "/api/v1/checks/07c2f548-9850-4b27-af5d-6c9dc157ec02" r = self.client.delete(url, HTTP_X_API_KEY="X" * 32) self.assertEqual(r.status_code, 404) + + def test_it_handles_options(self): + r = self.client.options("/api/v1/checks/%s" % self.check.code) + self.assertEqual(r.status_code, 204) + self.assertIn("DELETE", r["Access-Control-Allow-Methods"]) diff --git a/hc/api/tests/test_list_channels.py b/hc/api/tests/test_list_channels.py index 4f884b02..7d755097 100644 --- a/hc/api/tests/test_list_channels.py +++ b/hc/api/tests/test_list_channels.py @@ -20,6 +20,7 @@ class ListChannelsTestCase(BaseTestCase): def test_it_works(self): r = self.get() self.assertEqual(r.status_code, 200) + self.assertEqual(r["Access-Control-Allow-Origin"], "*") doc = r.json() self.assertEqual(len(doc["channels"]), 1) @@ -29,6 +30,11 @@ class ListChannelsTestCase(BaseTestCase): self.assertEqual(c["kind"], "email") self.assertEqual(c["name"], "Email to Alice") + def test_it_handles_options(self): + r = self.client.options("/api/v1/channels/") + self.assertEqual(r.status_code, 204) + self.assertIn("GET", r["Access-Control-Allow-Methods"]) + def test_it_shows_only_users_channels(self): Channel.objects.create(user=self.bob, kind="email", name="Bob") diff --git a/hc/api/tests/test_list_checks.py b/hc/api/tests/test_list_checks.py index fac54368..0b289954 100644 --- a/hc/api/tests/test_list_checks.py +++ b/hc/api/tests/test_list_checks.py @@ -40,6 +40,7 @@ class ListChecksTestCase(BaseTestCase): def test_it_works(self): r = self.get() self.assertEqual(r.status_code, 200) + self.assertEqual(r["Access-Control-Allow-Origin"], "*") doc = r.json() self.assertEqual(len(doc["checks"]), 2) @@ -73,6 +74,11 @@ class ListChecksTestCase(BaseTestCase): self.assertEqual(a2["ping_url"], self.a2.url()) self.assertEqual(a2["status"], "up") + def test_it_handles_options(self): + r = self.client.options("/api/v1/checks/") + self.assertEqual(r.status_code, 204) + self.assertIn("GET", r["Access-Control-Allow-Methods"]) + def test_it_shows_only_users_checks(self): bobs_check = Check(user=self.bob, name="Bob 1") bobs_check.save() diff --git a/hc/api/tests/test_pause.py b/hc/api/tests/test_pause.py index abe2a7b1..e6c3d414 100644 --- a/hc/api/tests/test_pause.py +++ b/hc/api/tests/test_pause.py @@ -13,10 +13,19 @@ class PauseTestCase(BaseTestCase): HTTP_X_API_KEY="X" * 32) self.assertEqual(r.status_code, 200) + self.assertEqual(r["Access-Control-Allow-Origin"], "*") check.refresh_from_db() self.assertEqual(check.status, "paused") + def test_it_handles_options(self): + check = Check(user=self.alice, status="up") + check.save() + + r = self.client.options("/api/v1/checks/%s/pause" % check.code) + self.assertEqual(r.status_code, 204) + self.assertIn("POST", r["Access-Control-Allow-Methods"]) + def test_it_only_allows_post(self): url = "/api/v1/checks/1659718b-21ad-4ed1-8740-43afc6c41524/pause" diff --git a/hc/api/tests/test_update_check.py b/hc/api/tests/test_update_check.py index 70ae0399..24b41c2d 100644 --- a/hc/api/tests/test_update_check.py +++ b/hc/api/tests/test_update_check.py @@ -25,6 +25,7 @@ class UpdateCheckTestCase(BaseTestCase): }) self.assertEqual(r.status_code, 200) + self.assertEqual(r["Access-Control-Allow-Origin"], "*") doc = r.json() assert "ping_url" in doc @@ -44,6 +45,11 @@ class UpdateCheckTestCase(BaseTestCase): self.assertEqual(self.check.timeout.total_seconds(), 3600) self.assertEqual(self.check.grace.total_seconds(), 60) + def test_it_handles_options(self): + r = self.client.options("/api/v1/checks/%s" % self.check.code) + self.assertEqual(r.status_code, 204) + self.assertIn("POST", r["Access-Control-Allow-Methods"]) + def test_it_unassigns_channels(self): Channel.objects.create(user=self.alice) self.check.assign_all_channels() diff --git a/hc/api/views.py b/hc/api/views.py index efa2a452..bcc32cee 100644 --- a/hc/api/views.py +++ b/hc/api/views.py @@ -10,10 +10,9 @@ from django.shortcuts import get_object_or_404 from django.utils import timezone from django.views.decorators.cache import never_cache from django.views.decorators.csrf import csrf_exempt -from django.views.decorators.http import require_GET, require_POST from hc.api import schemas -from hc.api.decorators import authorize, authorize_read, validate_json +from hc.api.decorators import authorize, authorize_read, cors, validate_json from hc.api.models import Check, Notification, Channel from hc.lib.badges import check_signature, get_badge_svg @@ -140,17 +139,15 @@ def create_check(request): @csrf_exempt +@cors("GET", "POST") def checks(request): - if request.method == "GET": - return get_checks(request) - - elif request.method == "POST": + if request.method == "POST": return create_check(request) - return HttpResponse(status=405) + return get_checks(request) -@require_GET +@cors("GET") @validate_json() @authorize_read def channels(request): @@ -160,6 +157,7 @@ def channels(request): @csrf_exempt +@cors("POST", "DELETE") @validate_json(schemas.check) @authorize def update(request, code): @@ -180,8 +178,8 @@ def update(request, code): return HttpResponse(status=405) +@cors("POST") @csrf_exempt -@require_POST @validate_json() @authorize def pause(request, code): @@ -195,6 +193,7 @@ def pause(request, code): @never_cache +@cors("GET") def badge(request, username, signature, tag, format="svg"): if not check_signature(username, tag, signature): return HttpResponseNotFound()