diff --git a/CHANGELOG.md b/CHANGELOG.md index 5060cf61..23047a9f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,7 @@ All notable changes to this project will be documented in this file. - API: don't let SuspiciousOperation bubble up when validating channel ids - API security: check channel ownership when setting check's channels - API: update check's "alert_after" field when changing schedule +- API: validate channel identifiers before creating/updating a check (#335) ## v1.13.0 - 2020-02-13 diff --git a/hc/api/tests/test_create_check.py b/hc/api/tests/test_create_check.py index 9dd80b59..9f8e3d96 100644 --- a/hc/api/tests/test_create_check.py +++ b/hc/api/tests/test_create_check.py @@ -84,6 +84,14 @@ class CreateCheckTestCase(BaseTestCase): check = Check.objects.get() self.assertEqual(check.channel_set.get(), channel) + def test_it_rejects_bad_channel_code(self): + r = self.post({"api_key": "X" * 32, "channels": "abc"}) + self.assertEqual(r.status_code, 400) + self.assertEqual(r.json()["error"], "invalid channel identifier: abc") + + # The check should not have been saved + self.assertFalse(Check.objects.exists()) + def test_it_supports_unique(self): Check.objects.create(project=self.project, name="Foo") diff --git a/hc/api/tests/test_update_check.py b/hc/api/tests/test_update_check.py index f689f193..a46eec39 100644 --- a/hc/api/tests/test_update_check.py +++ b/hc/api/tests/test_update_check.py @@ -119,6 +119,15 @@ class UpdateCheckTestCase(BaseTestCase): self.assertEqual(self.check.channel_set.count(), 1) self.assertEqual(self.check.channel_set.first().code, channel.code) + def test_it_sets_the_channel_only_once(self): + channel = Channel.objects.create(project=self.project) + duplicates = "%s,%s" % (channel.code, channel.code) + r = self.post(self.check.code, {"api_key": "X" * 32, "channels": duplicates}) + self.assertEqual(r.status_code, 200) + + self.check.refresh_from_db() + self.assertEqual(self.check.channel_set.count(), 1) + def test_it_handles_comma_separated_channel_codes(self): c1 = Channel.objects.create(project=self.project) c2 = Channel.objects.create(project=self.project) @@ -152,10 +161,15 @@ class UpdateCheckTestCase(BaseTestCase): self.assertEqual(check.channel_set.count(), 1) def test_it_rejects_bad_channel_code(self): - r = self.post(self.check.code, {"api_key": "X" * 32, "channels": "abc"}) + payload = {"api_key": "X" * 32, "channels": "abc", "name": "New Name"} + r = self.post(self.check.code, payload,) self.assertEqual(r.status_code, 400) self.assertEqual(r.json()["error"], "invalid channel identifier: abc") + # The name should be unchanged + self.check.refresh_from_db() + self.assertEqual(self.check.name, "") + def test_it_rejects_missing_channel(self): code = str(uuid.uuid4()) r = self.post(self.check.code, {"api_key": "X" * 32, "channels": code}) diff --git a/hc/api/views.py b/hc/api/views.py index abcb67a8..bc675d57 100644 --- a/hc/api/views.py +++ b/hc/api/views.py @@ -67,6 +67,21 @@ def _lookup(project, spec): def _update(check, spec): + channels = set() + # First, validate the supplied channel codes + if "channels" in spec and spec["channels"] not in ("*", ""): + q = Channel.objects.filter(project=check.project) + for s in spec["channels"].split(","): + try: + code = uuid.UUID(s) + except ValueError: + raise BadChannelException("invalid channel identifier: %s" % s) + + try: + channels.add(q.get(code=code)) + except Channel.DoesNotExist: + raise BadChannelException("invalid channel identifier: %s" % s) + if "name" in spec: check.name = spec["name"] @@ -94,26 +109,12 @@ def _update(check, spec): # This needs to be done after saving the check, because of # the M2M relation between checks and channels: - if "channels" in spec: - if spec["channels"] == "*": - check.assign_all_channels() - elif spec["channels"] == "": - check.channel_set.clear() - else: - channels = [] - channel_query = Channel.objects.filter(project=check.project) - for chunk in spec["channels"].split(","): - try: - chunk = uuid.UUID(chunk) - except ValueError: - raise BadChannelException("invalid channel identifier: %s" % chunk) - - try: - channels.append(channel_query.get(code=chunk)) - except Channel.DoesNotExist: - raise BadChannelException("invalid channel identifier: %s" % chunk) - - check.channel_set.set(channels) + if spec.get("channels") == "*": + check.assign_all_channels() + elif spec.get("channels") == "": + check.channel_set.clear() + elif channels: + check.channel_set.set(channels) return check