Browse Source

Add CORS support to API endpoints. Fixes #208

pull/211/head
Pēteris Caune 6 years ago
parent
commit
440a143dd6
No known key found for this signature in database GPG Key ID: E28D7679E9A9EDE2
10 changed files with 84 additions and 10 deletions
  1. +1
    -0
      CHANGELOG.md
  2. +26
    -1
      hc/api/decorators.py
  3. +10
    -0
      hc/api/tests/test_badge.py
  4. +6
    -0
      hc/api/tests/test_create_check.py
  5. +6
    -0
      hc/api/tests/test_delete_check.py
  6. +6
    -0
      hc/api/tests/test_list_channels.py
  7. +6
    -0
      hc/api/tests/test_list_checks.py
  8. +9
    -0
      hc/api/tests/test_pause.py
  9. +6
    -0
      hc/api/tests/test_update_check.py
  10. +8
    -9
      hc/api/views.py

+ 1
- 0
CHANGELOG.md View File

@ -7,6 +7,7 @@ All notable changes to this project will be documented in this file.
- Set Pushover alert priorities for "down" and "up" events separately - Set Pushover alert priorities for "down" and "up" events separately
- Additional python usage examples - Additional python usage examples
- Allow simultaneous access to checks from different teams - Allow simultaneous access to checks from different teams
- Add CORS support to API endpoints
### Bug Fixes ### Bug Fixes
- Fix after-login redirects (the "?next=" query parameter) - Fix after-login redirects (the "?next=" query parameter)


+ 26
- 1
hc/api/decorators.py View File

@ -3,7 +3,7 @@ from functools import wraps
from django.contrib.auth.models import User from django.contrib.auth.models import User
from django.db.models import Q from django.db.models import Q
from django.http import JsonResponse
from django.http import HttpResponse, JsonResponse
from hc.lib.jsonschema import ValidationError, validate from hc.lib.jsonschema import ValidationError, validate
@ -82,3 +82,28 @@ def validate_json(schema=None):
return f(request, *args, **kwds) return f(request, *args, **kwds)
return wrapper return wrapper
return decorator return decorator
def cors(*methods):
methods = set(methods)
methods.add("OPTIONS")
methods_str = ", ".join(methods)
def decorator(f):
@wraps(f)
def wrapper(request, *args, **kwds):
if request.method == "OPTIONS":
# Handle OPTIONS here
response = HttpResponse(status=204)
elif request.method in methods:
response = f(request, *args, **kwds)
else:
response = HttpResponse(status=405)
response["Access-Control-Allow-Origin"] = "*"
response["Access-Control-Allow-Headers"] = "X-Api-Key"
response["Access-Control-Allow-Methods"] = methods_str
return response
return wrapper
return decorator

+ 10
- 0
hc/api/tests/test_badge.py View File

@ -21,4 +21,14 @@ class BadgeTestCase(BaseTestCase):
url = "/badge/%s/%s/foo.svg" % (self.alice.username, sig) url = "/badge/%s/%s/foo.svg" % (self.alice.username, sig)
r = self.client.get(url) r = self.client.get(url)
self.assertEqual(r["Access-Control-Allow-Origin"], "*")
self.assertContains(r, "#4c1") self.assertContains(r, "#4c1")
def test_it_handles_options(self):
sig = base64_hmac(str(self.alice.username), "foo", settings.SECRET_KEY)
sig = sig[:8]
url = "/badge/%s/%s/foo.svg" % (self.alice.username, sig)
r = self.client.options(url)
self.assertEqual(r.status_code, 204)
self.assertEqual(r["Access-Control-Allow-Origin"], "*")

+ 6
- 0
hc/api/tests/test_create_check.py View File

@ -30,6 +30,7 @@ class CreateCheckTestCase(BaseTestCase):
}) })
self.assertEqual(r.status_code, 201) self.assertEqual(r.status_code, 201)
self.assertEqual(r["Access-Control-Allow-Origin"], "*")
doc = r.json() doc = r.json()
assert "ping_url" in doc assert "ping_url" in doc
@ -47,6 +48,11 @@ class CreateCheckTestCase(BaseTestCase):
self.assertEqual(check.timeout.total_seconds(), 3600) self.assertEqual(check.timeout.total_seconds(), 3600)
self.assertEqual(check.grace.total_seconds(), 60) self.assertEqual(check.grace.total_seconds(), 60)
def test_it_handles_options(self):
r = self.client.options(self.URL)
self.assertEqual(r.status_code, 204)
self.assertIn("POST", r["Access-Control-Allow-Methods"])
def test_30_days_works(self): def test_30_days_works(self):
r = self.post({ r = self.post({
"api_key": "X" * 32, "api_key": "X" * 32,


+ 6
- 0
hc/api/tests/test_delete_check.py View File

@ -13,6 +13,7 @@ class DeleteCheckTestCase(BaseTestCase):
r = self.client.delete("/api/v1/checks/%s" % self.check.code, r = self.client.delete("/api/v1/checks/%s" % self.check.code,
HTTP_X_API_KEY="X" * 32) HTTP_X_API_KEY="X" * 32)
self.assertEqual(r.status_code, 200) self.assertEqual(r.status_code, 200)
self.assertEqual(r["Access-Control-Allow-Origin"], "*")
# It should be gone-- # It should be gone--
self.assertFalse(Check.objects.filter(code=self.check.code).exists()) self.assertFalse(Check.objects.filter(code=self.check.code).exists())
@ -21,3 +22,8 @@ class DeleteCheckTestCase(BaseTestCase):
url = "/api/v1/checks/07c2f548-9850-4b27-af5d-6c9dc157ec02" url = "/api/v1/checks/07c2f548-9850-4b27-af5d-6c9dc157ec02"
r = self.client.delete(url, HTTP_X_API_KEY="X" * 32) r = self.client.delete(url, HTTP_X_API_KEY="X" * 32)
self.assertEqual(r.status_code, 404) self.assertEqual(r.status_code, 404)
def test_it_handles_options(self):
r = self.client.options("/api/v1/checks/%s" % self.check.code)
self.assertEqual(r.status_code, 204)
self.assertIn("DELETE", r["Access-Control-Allow-Methods"])

+ 6
- 0
hc/api/tests/test_list_channels.py View File

@ -20,6 +20,7 @@ class ListChannelsTestCase(BaseTestCase):
def test_it_works(self): def test_it_works(self):
r = self.get() r = self.get()
self.assertEqual(r.status_code, 200) self.assertEqual(r.status_code, 200)
self.assertEqual(r["Access-Control-Allow-Origin"], "*")
doc = r.json() doc = r.json()
self.assertEqual(len(doc["channels"]), 1) self.assertEqual(len(doc["channels"]), 1)
@ -29,6 +30,11 @@ class ListChannelsTestCase(BaseTestCase):
self.assertEqual(c["kind"], "email") self.assertEqual(c["kind"], "email")
self.assertEqual(c["name"], "Email to Alice") self.assertEqual(c["name"], "Email to Alice")
def test_it_handles_options(self):
r = self.client.options("/api/v1/channels/")
self.assertEqual(r.status_code, 204)
self.assertIn("GET", r["Access-Control-Allow-Methods"])
def test_it_shows_only_users_channels(self): def test_it_shows_only_users_channels(self):
Channel.objects.create(user=self.bob, kind="email", name="Bob") Channel.objects.create(user=self.bob, kind="email", name="Bob")


+ 6
- 0
hc/api/tests/test_list_checks.py View File

@ -40,6 +40,7 @@ class ListChecksTestCase(BaseTestCase):
def test_it_works(self): def test_it_works(self):
r = self.get() r = self.get()
self.assertEqual(r.status_code, 200) self.assertEqual(r.status_code, 200)
self.assertEqual(r["Access-Control-Allow-Origin"], "*")
doc = r.json() doc = r.json()
self.assertEqual(len(doc["checks"]), 2) self.assertEqual(len(doc["checks"]), 2)
@ -73,6 +74,11 @@ class ListChecksTestCase(BaseTestCase):
self.assertEqual(a2["ping_url"], self.a2.url()) self.assertEqual(a2["ping_url"], self.a2.url())
self.assertEqual(a2["status"], "up") self.assertEqual(a2["status"], "up")
def test_it_handles_options(self):
r = self.client.options("/api/v1/checks/")
self.assertEqual(r.status_code, 204)
self.assertIn("GET", r["Access-Control-Allow-Methods"])
def test_it_shows_only_users_checks(self): def test_it_shows_only_users_checks(self):
bobs_check = Check(user=self.bob, name="Bob 1") bobs_check = Check(user=self.bob, name="Bob 1")
bobs_check.save() bobs_check.save()


+ 9
- 0
hc/api/tests/test_pause.py View File

@ -13,10 +13,19 @@ class PauseTestCase(BaseTestCase):
HTTP_X_API_KEY="X" * 32) HTTP_X_API_KEY="X" * 32)
self.assertEqual(r.status_code, 200) self.assertEqual(r.status_code, 200)
self.assertEqual(r["Access-Control-Allow-Origin"], "*")
check.refresh_from_db() check.refresh_from_db()
self.assertEqual(check.status, "paused") self.assertEqual(check.status, "paused")
def test_it_handles_options(self):
check = Check(user=self.alice, status="up")
check.save()
r = self.client.options("/api/v1/checks/%s/pause" % check.code)
self.assertEqual(r.status_code, 204)
self.assertIn("POST", r["Access-Control-Allow-Methods"])
def test_it_only_allows_post(self): def test_it_only_allows_post(self):
url = "/api/v1/checks/1659718b-21ad-4ed1-8740-43afc6c41524/pause" url = "/api/v1/checks/1659718b-21ad-4ed1-8740-43afc6c41524/pause"


+ 6
- 0
hc/api/tests/test_update_check.py View File

@ -25,6 +25,7 @@ class UpdateCheckTestCase(BaseTestCase):
}) })
self.assertEqual(r.status_code, 200) self.assertEqual(r.status_code, 200)
self.assertEqual(r["Access-Control-Allow-Origin"], "*")
doc = r.json() doc = r.json()
assert "ping_url" in doc assert "ping_url" in doc
@ -44,6 +45,11 @@ class UpdateCheckTestCase(BaseTestCase):
self.assertEqual(self.check.timeout.total_seconds(), 3600) self.assertEqual(self.check.timeout.total_seconds(), 3600)
self.assertEqual(self.check.grace.total_seconds(), 60) self.assertEqual(self.check.grace.total_seconds(), 60)
def test_it_handles_options(self):
r = self.client.options("/api/v1/checks/%s" % self.check.code)
self.assertEqual(r.status_code, 204)
self.assertIn("POST", r["Access-Control-Allow-Methods"])
def test_it_unassigns_channels(self): def test_it_unassigns_channels(self):
Channel.objects.create(user=self.alice) Channel.objects.create(user=self.alice)
self.check.assign_all_channels() self.check.assign_all_channels()


+ 8
- 9
hc/api/views.py View File

@ -10,10 +10,9 @@ from django.shortcuts import get_object_or_404
from django.utils import timezone from django.utils import timezone
from django.views.decorators.cache import never_cache from django.views.decorators.cache import never_cache
from django.views.decorators.csrf import csrf_exempt from django.views.decorators.csrf import csrf_exempt
from django.views.decorators.http import require_GET, require_POST
from hc.api import schemas from hc.api import schemas
from hc.api.decorators import authorize, authorize_read, validate_json
from hc.api.decorators import authorize, authorize_read, cors, validate_json
from hc.api.models import Check, Notification, Channel from hc.api.models import Check, Notification, Channel
from hc.lib.badges import check_signature, get_badge_svg from hc.lib.badges import check_signature, get_badge_svg
@ -140,17 +139,15 @@ def create_check(request):
@csrf_exempt @csrf_exempt
@cors("GET", "POST")
def checks(request): def checks(request):
if request.method == "GET":
return get_checks(request)
elif request.method == "POST":
if request.method == "POST":
return create_check(request) return create_check(request)
return HttpResponse(status=405)
return get_checks(request)
@require_GET
@cors("GET")
@validate_json() @validate_json()
@authorize_read @authorize_read
def channels(request): def channels(request):
@ -160,6 +157,7 @@ def channels(request):
@csrf_exempt @csrf_exempt
@cors("POST", "DELETE")
@validate_json(schemas.check) @validate_json(schemas.check)
@authorize @authorize
def update(request, code): def update(request, code):
@ -180,8 +178,8 @@ def update(request, code):
return HttpResponse(status=405) return HttpResponse(status=405)
@cors("POST")
@csrf_exempt @csrf_exempt
@require_POST
@validate_json() @validate_json()
@authorize @authorize
def pause(request, code): def pause(request, code):
@ -195,6 +193,7 @@ def pause(request, code):
@never_cache @never_cache
@cors("GET")
def badge(request, username, signature, tag, format="svg"): def badge(request, username, signature, tag, format="svg"):
if not check_signature(username, tag, signature): if not check_signature(username, tag, signature):
return HttpResponseNotFound() return HttpResponseNotFound()


Loading…
Cancel
Save