diff --git a/hc/accounts/tests/test_check_token.py b/hc/accounts/tests/test_check_token.py index 0f6750dd..9fab1860 100644 --- a/hc/accounts/tests/test_check_token.py +++ b/hc/accounts/tests/test_check_token.py @@ -1,15 +1,16 @@ from django.contrib.auth.models import User +from django.core.urlresolvers import reverse from django.test import TestCase class CheckTokenTestCase(TestCase): def setUp(self): - super(CheckTokenTestCase, self).setUp() + super(CheckTokenTestCase, self).setUp() - self.alice = User(username="alice") - self.alice.set_password("secret-token") - self.alice.save() + self.alice = User(username="alice") + self.alice.set_password("secret-token") + self.alice.save() def test_it_redirects(self): r = self.client.get("/accounts/check_token/alice/secret-token/") @@ -26,3 +27,10 @@ class CheckTokenTestCase(TestCase): # Login again, when already authenticated r = self.client.get("/accounts/check_token/alice/secret-token/") assert r.status_code == 302 + + def test_it_redirects_bad_login(self): + # Login with a bad token + r = self.client.get("/accounts/check_token/alice/invalid-token/") + assert r.status_code == 302 + assert r.url.endswith(reverse("hc-login")) + assert self.client.session["bad_link"] diff --git a/hc/accounts/tests/test_login.py b/hc/accounts/tests/test_login.py index 3dadf58a..fba8ca86 100644 --- a/hc/accounts/tests/test_login.py +++ b/hc/accounts/tests/test_login.py @@ -29,3 +29,8 @@ class LoginTestCase(TestCase): # And check should be associated with the new user check_again = Check.objects.get(code=check.code) assert check_again.user + + def test_it_pops_bad_link_from_session(self): + self.client.session["bad_link"] = True + self.client.get("/accounts/login/") + assert "bad_link" not in self.client.session diff --git a/hc/accounts/views.py b/hc/accounts/views.py index 12ea0bcd..a92b1833 100644 --- a/hc/accounts/views.py +++ b/hc/accounts/views.py @@ -80,7 +80,8 @@ def login(request): else: form = EmailForm() - ctx = {"form": form} + bad_link = request.session.pop("bad_link", None) + ctx = {"form": form, "bad_link": bad_link} return render(request, "accounts/login.html", ctx) @@ -110,8 +111,8 @@ def check_token(request, username, token): return redirect("hc-checks") - ctx = {"bad_link": True} - return render(request, "accounts/login.html", ctx) + request.session["bad_link"] = True + return redirect("hc-login") @login_required