diff --git a/hc/accounts/forms.py b/hc/accounts/forms.py index fd85c836..a03443f7 100644 --- a/hc/accounts/forms.py +++ b/hc/accounts/forms.py @@ -1,7 +1,9 @@ import base64 +import binascii from datetime import timedelta as td from django import forms +from django.core.exceptions import ValidationError from django.contrib.auth import authenticate from django.contrib.auth.models import User from hc.api.models import TokenBucket @@ -13,6 +15,17 @@ class LowercaseEmailField(forms.EmailField): return value.lower() +class Base64Field(forms.CharField): + def to_python(self, value): + if value is None: + return None + + try: + return base64.b64decode(value.encode()) + except binascii.Error: + raise ValidationError(message="Cannot decode base64") + + class AvailableEmailForm(forms.Form): # Call it "identity" instead of "email" # to avoid some of the dumber bots @@ -109,7 +122,7 @@ class RemoveTeamMemberForm(forms.Form): class ProjectNameForm(forms.Form): - name = forms.CharField(max_length=60, required=True) + name = forms.CharField(max_length=60) class TransferForm(forms.Form): @@ -118,36 +131,12 @@ class TransferForm(forms.Form): class AddCredentialForm(forms.Form): name = forms.CharField(max_length=100) - client_data_json = forms.CharField(required=True) - attestation_object = forms.CharField(required=True) - - def clean_client_data_json(self): - v = self.cleaned_data["client_data_json"] - return base64.b64decode(v.encode()) - - def clean_attestation_object(self): - v = self.cleaned_data["attestation_object"] - return base64.b64decode(v.encode()) + client_data_json = Base64Field() + attestation_object = Base64Field() class LoginTfaForm(forms.Form): - credential_id = forms.CharField(required=True) - client_data_json = forms.CharField(required=True) - authenticator_data = forms.CharField(required=True) - signature = forms.CharField(required=True) - - def clean_credential_id(self): - v = self.cleaned_data["credential_id"] - return base64.b64decode(v.encode()) - - def clean_client_data_json(self): - v = self.cleaned_data["client_data_json"] - return base64.b64decode(v.encode()) - - def clean_authenticator_data(self): - v = self.cleaned_data["authenticator_data"] - return base64.b64decode(v.encode()) - - def clean_signature(self): - v = self.cleaned_data["signature"] - return base64.b64decode(v.encode()) + credential_id = Base64Field() + client_data_json = Base64Field() + authenticator_data = Base64Field() + signature = Base64Field() diff --git a/hc/accounts/tests/test_add_credential.py b/hc/accounts/tests/test_add_credential.py index 062118e0..aa9f9f59 100644 --- a/hc/accounts/tests/test_add_credential.py +++ b/hc/accounts/tests/test_add_credential.py @@ -52,4 +52,27 @@ class AddCredentialTestCase(BaseTestCase): c = Credential.objects.get() self.assertEqual(c.name, "My New Key") - # FIXME: test authentication failure + def test_it_rejects_bad_base64(self): + self.client.login(username="alice@example.org", password="password") + self._set_sudo_flag() + + payload = { + "name": "My New Key", + "client_data_json": "not valid base64", + "attestation_object": "not valid base64", + } + + r = self.client.post(self.url, payload) + self.assertEqual(r.status_code, 400) + + def test_it_requires_client_data_json(self): + self.client.login(username="alice@example.org", password="password") + self._set_sudo_flag() + + payload = { + "name": "My New Key", + "attestation_object": "e30=", + } + + r = self.client.post(self.url, payload) + self.assertEqual(r.status_code, 400)