diff --git a/hc/front/forms.py b/hc/front/forms.py index 554ec5ee..ad9ba9d0 100644 --- a/hc/front/forms.py +++ b/hc/front/forms.py @@ -1,5 +1,6 @@ -import json from datetime import timedelta as td +import json +import re from django import forms from django.core.validators import RegexValidator @@ -55,28 +56,42 @@ class AddUrlForm(forms.Form): value = forms.URLField(max_length=1000, validators=[WebhookValidator()]) +_valid_header_name = re.compile(r'\A[^:\s][^:\r\n]*\Z').match + + class AddWebhookForm(forms.Form): error_css_class = "has-error" url_down = forms.URLField(max_length=1000, required=False, - validators=[WebhookValidator()]) + validators=[WebhookValidator()]) url_up = forms.URLField(max_length=1000, required=False, - validators=[WebhookValidator()]) + validators=[WebhookValidator()]) post_data = forms.CharField(max_length=1000, required=False) def __init__(self, *args, **kwargs): super(AddWebhookForm, self).__init__(*args, **kwargs) + self.invalid_header_names = set() self.headers = {} if "header_key[]" in self.data and "header_value[]" in self.data: keys = self.data.getlist("header_key[]") values = self.data.getlist("header_value[]") for key, value in zip(keys, values): - if key: - self.headers[key] = value + if not key: + continue + + if not _valid_header_name(key): + self.invalid_header_names.add(key) + + self.headers[key] = value + + def clean(self): + if self.invalid_header_names: + raise forms.ValidationError("Invalid header names") + return self.cleaned_data def get_value(self): val = dict(self.cleaned_data) diff --git a/hc/front/tests/test_add_webhook.py b/hc/front/tests/test_add_webhook.py index 744f5686..763cf642 100644 --- a/hc/front/tests/test_add_webhook.py +++ b/hc/front/tests/test_add_webhook.py @@ -72,12 +72,27 @@ class AddWebhookTestCase(BaseTestCase): self.assertEqual(c.value, '{"headers": {}, "post_data": "hello", "url_down": "http://foo.com", "url_up": ""}') def test_it_adds_headers(self): - form = {"url_down": "http://foo.com", "header_key[]": ["test", "test2"], "header_value[]": ["123", "abc"]} + form = { + "url_down": "http://foo.com", + "header_key[]": ["test", "test2"], + "header_value[]": ["123", "abc"] + } self.client.login(username="alice@example.org", password="password") r = self.client.post(self.url, form) self.assertRedirects(r, "/integrations/") c = Channel.objects.get() - self.assertEqual(c.value, '{"headers": {"test": "123", "test2": "abc"}, "post_data": "", "url_down": "http://foo.com", "url_up": ""}') + self.assertEqual(c.headers, {"test": "123", "test2": "abc"}) + def test_it_rejects_bad_header_names(self): + self.client.login(username="alice@example.org", password="password") + form = { + "url_down": "http://example.org", + "header_key[]": ["ill:egal"], + "header_value[]": ["123"] + } + + r = self.client.post(self.url, form) + self.assertContains(r, "Please use valid HTTP header names.") + self.assertEqual(Channel.objects.count(), 0) diff --git a/static/css/add_webhook.css b/static/css/add_webhook.css new file mode 100644 index 00000000..16b00681 --- /dev/null +++ b/static/css/add_webhook.css @@ -0,0 +1,4 @@ +.webhook-header .error { + border-color: #a94442; +} + diff --git a/templates/base.html b/templates/base.html index f6b2bec9..4d9cf4e9 100644 --- a/templates/base.html +++ b/templates/base.html @@ -34,6 +34,7 @@ + diff --git a/templates/integrations/add_webhook.html b/templates/integrations/add_webhook.html index 6e952ad8..e4bcb12b 100644 --- a/templates/integrations/add_webhook.html +++ b/templates/integrations/add_webhook.html @@ -105,32 +105,36 @@ {% endif %} -