diff --git a/hc/accounts/tests/test_check_token.py b/hc/accounts/tests/test_check_token.py index 2a670772..a5f1a465 100644 --- a/hc/accounts/tests/test_check_token.py +++ b/hc/accounts/tests/test_check_token.py @@ -1,4 +1,5 @@ from django.contrib.auth.hashers import make_password +from hc.accounts.models import Credential from hc.test import BaseTestCase @@ -48,3 +49,16 @@ class CheckTokenTestCase(BaseTestCase): url = "/accounts/check_token/alice/secret-token/?next=/evil/" r = self.client.post(url) self.assertRedirects(r, self.checks_url) + + def test_it_redirects_to_login_tfa(self): + Credential.objects.create(user=self.alice, name="Alices Key") + + r = self.client.post("/accounts/check_token/alice/secret-token/") + self.assertRedirects( + r, "/accounts/login/two_factor/", fetch_redirect_response=False + ) + + # It should not log the user in yet + self.assertNotIn("_auth_user_id", self.client.session) + # Instead, it should set 2fa_user_id in the session + self.assertEqual(self.client.session["2fa_user_id"], self.alice.id) diff --git a/hc/accounts/tests/test_login.py b/hc/accounts/tests/test_login.py index 7b683b1f..277d8425 100644 --- a/hc/accounts/tests/test_login.py +++ b/hc/accounts/tests/test_login.py @@ -1,6 +1,7 @@ from django.conf import settings from django.core import mail from django.test.utils import override_settings +from hc.accounts.models import Credential from hc.api.models import Check, TokenBucket from hc.test import BaseTestCase @@ -8,7 +9,7 @@ from hc.test import BaseTestCase class LoginTestCase(BaseTestCase): def setUp(self): super().setUp() - self.checks_url = "/projects/%s/checks/" % self.project.code + self.checks_url = f"/projects/{self.project.code}/checks/" def test_it_sends_link(self): form = {"identity": "alice@example.org"} @@ -111,3 +112,17 @@ class LoginTestCase(BaseTestCase): def test_it_obeys_registration_open(self): r = self.client.get("/accounts/login/") self.assertNotContains(r, "Create Your Account") + + def test_it_redirects_to_login_tfa(self): + Credential.objects.create(user=self.alice, name="Alices Key") + + form = {"action": "login", "email": "alice@example.org", "password": "password"} + r = self.client.post("/accounts/login/", form) + self.assertRedirects( + r, "/accounts/login/two_factor/", fetch_redirect_response=False + ) + + # It should not log the user in yet + self.assertNotIn("_auth_user_id", self.client.session) + # Instead, it should set 2fa_user_id in the session + self.assertEqual(self.client.session["2fa_user_id"], self.alice.id) diff --git a/hc/accounts/tests/test_login_tfa.py b/hc/accounts/tests/test_login_tfa.py new file mode 100644 index 00000000..d162793c --- /dev/null +++ b/hc/accounts/tests/test_login_tfa.py @@ -0,0 +1,103 @@ +from unittest.mock import patch + +from hc.test import BaseTestCase + + +class LoginTfaTestCase(BaseTestCase): + def setUp(self): + super().setUp() + + # This is the user we're trying to authenticate + session = self.client.session + session["2fa_user_id"] = self.alice.id + session.save() + + self.url = "/accounts/login/two_factor/" + self.checks_url = f"/projects/{self.project.code}/checks/" + + def test_it_shows_form(self): + r = self.client.get(self.url) + self.assertContains(r, "Waiting for security key") + + # It should put a "state" key in the session: + self.assertIn("state", self.client.session) + + @patch("hc.accounts.views._check_credential") + def test_it_logs_in(self, mock_check_credential): + mock_check_credential.return_value = True + + session = self.client.session + session["state"] = "dummy-state" + session.save() + + payload = { + "name": "My New Key", + "credential_id": "e30=", + "client_data_json": "e30=", + "authenticator_data": "e30=", + "signature": "e30=", + } + + r = self.client.post(self.url, payload) + self.assertRedirects(r, self.checks_url) + + self.assertNotIn("state", self.client.session) + self.assertNotIn("2fa_user_id", self.client.session) + + @patch("hc.accounts.views._check_credential") + def test_it_redirects_after_login(self, mock_check_credential): + mock_check_credential.return_value = True + + session = self.client.session + session["state"] = "dummy-state" + session.save() + + payload = { + "name": "My New Key", + "credential_id": "e30=", + "client_data_json": "e30=", + "authenticator_data": "e30=", + "signature": "e30=", + } + + url = self.url + "?next=" + self.channels_url + r = self.client.post(url, payload) + self.assertRedirects(r, self.channels_url) + + @patch("hc.accounts.views._check_credential") + def test_it_handles_bad_base64(self, mock_check_credential): + mock_check_credential.return_value = None + + session = self.client.session + session["state"] = "dummy-state" + session.save() + + payload = { + "name": "My New Key", + "credential_id": "this is not base64 data", + "client_data_json": "e30=", + "authenticator_data": "e30=", + "signature": "e30=", + } + + r = self.client.post(self.url, payload) + self.assertEqual(r.status_code, 400) + + @patch("hc.accounts.views._check_credential") + def test_it_handles_authentication_failure(self, mock_check_credential): + mock_check_credential.return_value = None + + session = self.client.session + session["state"] = "dummy-state" + session.save() + + payload = { + "name": "My New Key", + "credential_id": "e30=", + "client_data_json": "e30=", + "authenticator_data": "e30=", + "signature": "e30=", + } + + r = self.client.post(self.url, payload) + self.assertEqual(r.status_code, 400)