From 9fdc2faecd18dee7a9164171a75ac9c31f66d44c Mon Sep 17 00:00:00 2001 From: Livio Bieri Date: Mon, 4 Dec 2023 15:00:51 +0100 Subject: [PATCH] chore: adds test for sso flows --- .../course/creators/test_utils.py | 4 +- .../vbv_lernwelt/sso/tests/test_sso_flow.py | 338 ++++++++++++++++-- server/vbv_lernwelt/sso/views.py | 66 ++-- 3 files changed, 366 insertions(+), 42 deletions(-) diff --git a/server/vbv_lernwelt/course/creators/test_utils.py b/server/vbv_lernwelt/course/creators/test_utils.py index 44e97bff..9aca28d4 100644 --- a/server/vbv_lernwelt/course/creators/test_utils.py +++ b/server/vbv_lernwelt/course/creators/test_utils.py @@ -57,8 +57,8 @@ from vbv_lernwelt.learnpath.tests.learning_path_factories import ( ) -def create_course(title: str) -> Tuple[Course, CoursePage]: - course = Course.objects.create(title=title, category_name="Handlungsfeld") +def create_course(title: str, _id=None) -> Tuple[Course, CoursePage]: + course = Course.objects.create(id=_id, title=title, category_name="Handlungsfeld") course_page = CoursePageFactory( title="Test Lehrgang", diff --git a/server/vbv_lernwelt/sso/tests/test_sso_flow.py b/server/vbv_lernwelt/sso/tests/test_sso_flow.py index fd0b140f..425d8844 100644 --- a/server/vbv_lernwelt/sso/tests/test_sso_flow.py +++ b/server/vbv_lernwelt/sso/tests/test_sso_flow.py @@ -1,53 +1,351 @@ +import base64 +import json import uuid -from unittest.mock import Mock, patch +from unittest.mock import ANY, MagicMock, patch from authlib.integrations.base_client import OAuthError -from django.conf import settings -from django.test import TestCase +from django.shortcuts import redirect +from django.test import override_settings, TestCase from django.urls import reverse from vbv_lernwelt.core.models import User +from vbv_lernwelt.course.consts import COURSE_VERSICHERUNGSVERMITTLERIN_ID +from vbv_lernwelt.course.creators.test_utils import ( + add_course_session_user, + create_course, + create_course_session, +) +from vbv_lernwelt.course.models import CourseSession, CourseSessionUser +from vbv_lernwelt.importer.services import create_or_update_user def decoded_token(email, oid=None, given_name="Bobby", family_name="Table"): return { - "emails": [email], + "email": email, "oid": oid or uuid.uuid4(), "given_name": given_name, "family_name": family_name, } -class TestSSOFlow(TestCase): +class TestSignInAuthorizeSSO(TestCase): + def setUp(self): + CourseSession.objects.all().delete() + User.objects.all().delete() + + @override_settings(OAUTH={"client_name": "mock"}) @patch("vbv_lernwelt.sso.views.oauth") @patch("vbv_lernwelt.sso.views.decode_jwt") - def test_authorize_redirects_on_success(self, mock_decode_jwt, _): + def test_authorize_first_time_uk(self, mock_decode_jwt, mock_oauth): # GIVEN - email = "bobby@drop.table" - mock_decode_jwt.return_value = decoded_token(email) + email = "user@example.com" + token = decoded_token(email) + + mock_decode_jwt.return_value = token + mock_oauth.signin.authorize_access_token.return_value = { + "id_token": "" + } + + state = base64.urlsafe_b64encode( + json.dumps({"course": "uk", "next": "/shall-be-ignored"}).encode() + ).decode() # WHEN - response = self.client.get(reverse("sso:authorize")) + response = self.client.get(reverse("sso:authorize"), {"state": state}) # THEN - self.assertTrue(User.objects.filter(email=email).exists()) - self.assertEqual(response.status_code, 302) - self.assertEqual(response.url, "/") # noqa + self.assertEqual(302, response.status_code) + self.assertEqual("/onboarding/uk/account/create", response.url) # noqa + + user = User.objects.get(email=email) # noqa + self.assertIsNotNone(user) + + self.assertEqual(user.email, email) + self.assertEqual(user.first_name, token["given_name"]) + self.assertEqual(user.last_name, token["family_name"]) + self.assertEqual(user.sso_id, token["oid"]) + + @override_settings(OAUTH={"client_name": "mock"}) + @patch("vbv_lernwelt.sso.views.oauth") + @patch("vbv_lernwelt.sso.views.decode_jwt") + def test_authorize_first_time_vv(self, mock_decode_jwt, mock_oauth): + # GIVEN + email = "user@example.com" + token = decoded_token(email) + + mock_decode_jwt.return_value = token + mock_oauth.signin.authorize_access_token.return_value = { + "id_token": "" + } + + state = base64.urlsafe_b64encode( + json.dumps({"course": "vv", "next": "/shall-be-ignored"}).encode() + ).decode() + + # WHEN + response = self.client.get(reverse("sso:authorize"), {"state": state}) + + # THEN + self.assertEqual(302, response.status_code) + self.assertEqual("/onboarding/vv/account/create", response.url) # noqa + + user = User.objects.get(email=email) # noqa + self.assertIsNotNone(user) + + self.assertEqual(user.email, email) + self.assertEqual(user.first_name, token["given_name"]) + self.assertEqual(user.last_name, token["family_name"]) + self.assertEqual(user.sso_id, token["oid"]) @patch("vbv_lernwelt.sso.views.oauth") def test_authorize_on_tampered_token(self, mock_oauth_service): # GIVEN - client_name = settings.OAUTH["client_name"] - client_mock = Mock() - client_mock.authorize_access_token.side_effect = OAuthError() - setattr(mock_oauth_service, client_name, client_mock) + mock_oauth_service.signin.authorize_access_token.side_effect = OAuthError() # WHEN response = self.client.get(reverse("sso:authorize")) # THEN - # sanity check that the mock was called (-> setup is correct) - self.assertEqual(client_mock.authorize_access_token.call_count, 1) + self.assertEqual(302, response.status_code) + self.assertEqual("/", response.url) # noqa - self.assertEqual(response.status_code, 302) - self.assertEqual(response.url, "/login-error?state=someerror") # noqa + @override_settings(OAUTH={"client_name": "mock"}) + @patch("vbv_lernwelt.sso.views.oauth") + @patch("vbv_lernwelt.sso.views.decode_jwt") + def test_authorize_onboarded_uk(self, mock_decode_jwt, mock_oauth): + # GIVEN + email = "some@email.com" + token = decoded_token(email) + mock_decode_jwt.return_value = token + mock_oauth.signin.authorize_access_token.return_value = { + "id_token": "" + } + + # create a user that is already onboarded for UK + user = create_or_update_user( + email=email, + sso_id=str(token["oid"]), + first_name=token["given_name"], + last_name=token["family_name"], + ) + + course, _ = create_course("uk") + course_session = create_course_session(course, "UK", "2023") + add_course_session_user(course_session, user, CourseSessionUser.Role.MEMBER) + + self.assertIsNotNone(User.objects.get(email=email)) + + # WHEN + state = base64.urlsafe_b64encode(json.dumps({"course": "uk"}).encode()).decode() + response = self.client.get(reverse("sso:authorize"), {"state": state}) + + # THEN + self.assertEqual(302, response.status_code) + self.assertEqual("/", response.url) # noqa + + @override_settings(OAUTH={"client_name": "mock"}) + @patch("vbv_lernwelt.sso.views.oauth") + @patch("vbv_lernwelt.sso.views.decode_jwt") + def test_authorize_onboarded_vv(self, mock_decode_jwt, mock_oauth): + # GIVEN + email = "some@email.com" + token = decoded_token(email) + mock_decode_jwt.return_value = token + mock_oauth.signin.authorize_access_token.return_value = { + "id_token": "" + } + + # create a user that is already onboarded for UK + user = create_or_update_user( + email=email, + sso_id=str(token["oid"]), + first_name=token["given_name"], + last_name=token["family_name"], + ) + + course, _ = create_course(_id=COURSE_VERSICHERUNGSVERMITTLERIN_ID, title="VV") + course_session = create_course_session(course, "VV", "VV") + add_course_session_user(course_session, user, CourseSessionUser.Role.MEMBER) + + self.assertIsNotNone(User.objects.get(email=email)) + + # WHEN + state = base64.urlsafe_b64encode(json.dumps({"course": "vv"}).encode()).decode() + response = self.client.get(reverse("sso:authorize"), {"state": state}) + + # THEN + self.assertEqual(302, response.status_code) + self.assertEqual("/", response.url) # noqa + + @override_settings(OAUTH={"client_name": "mock"}) + @patch("vbv_lernwelt.sso.views.oauth") + @patch("vbv_lernwelt.sso.views.decode_jwt") + def test_authorize_next_url(self, mock_decode_jwt, mock_oauth): + # GIVEN + next_url = "/some/next/url" + + mock_decode_jwt.return_value = decoded_token("whatever@example.com") + mock_oauth.signin.authorize_access_token.return_value = { + "id_token": "" + } + + # WHEN + state = base64.urlsafe_b64encode( + json.dumps({"next": next_url}).encode() + ).decode() + + response = self.client.get(reverse("sso:authorize"), {"state": state}) + + # THEN + self.assertEqual(302, response.status_code) + self.assertEqual(next_url, response.url) # noqa + + +class TestSignIn(TestCase): + @override_settings(OAUTH_SIGNIN_REDIRECT_URI="/sso/callback") + @patch("vbv_lernwelt.sso.views.oauth") + def test_signin_with_course_param(self, mock_oauth): + # GIVEN + course_param = "vv" + + expected_state = {"course": course_param, "next": None} + expected_state_encoded = base64.urlsafe_b64encode( + json.dumps(expected_state).encode() + ).decode() + + mock_oauth.signin.authorize_redirect = MagicMock() + mock_oauth.signin.authorize_redirect.return_value = redirect( + "/just/here/to/return/a/redirect/object" + ) + + # WHEN + self.client.get( + reverse("sso:login"), + {"course": course_param}, + ) + + # THEN + mock_oauth.signin.authorize_redirect.assert_called_once_with( + ANY, + "/sso/callback", + state=expected_state_encoded, + lang="de", + ) + + @override_settings(OAUTH_SIGNIN_REDIRECT_URI="/sso/callback") + @patch("vbv_lernwelt.sso.views.oauth") + def test_signin_with_state_param(self, mock_oauth): + # GIVEN + state_param = "vv" + + expected_state = {"course": state_param, "next": None} + expected_state_encoded = base64.urlsafe_b64encode( + json.dumps(expected_state).encode() + ).decode() + + mock_oauth.signin.authorize_redirect = MagicMock() + mock_oauth.signin.authorize_redirect.return_value = redirect( + "/just/here/to/return/a/redirect/object" + ) + + # WHEN + self.client.get( + reverse("sso:login"), + {"state": state_param}, + ) + + # THEN + mock_oauth.signin.authorize_redirect.assert_called_once_with( + ANY, + "/sso/callback", + state=expected_state_encoded, + lang="de", + ) + + @override_settings(OAUTH_SIGNIN_REDIRECT_URI="/sso/callback") + @patch("vbv_lernwelt.sso.views.oauth") + def test_signin_next_url(self, mock_oauth): + # GIVEN + next_url = "/some/next/url" + + expected_state = {"course": None, "next": next_url} + expected_state_encoded = base64.urlsafe_b64encode( + json.dumps(expected_state).encode() + ).decode() + + mock_oauth.signin.authorize_redirect = MagicMock() + mock_oauth.signin.authorize_redirect.return_value = redirect( + "/just/here/to/return/a/redirect/object" + ) + + # WHEN + self.client.get( + reverse("sso:login"), + {"next": next_url}, + ) + + # THEN + mock_oauth.signin.authorize_redirect.assert_called_once_with( + ANY, + "/sso/callback", + state=expected_state_encoded, + lang=ANY, + ) + + @override_settings(OAUTH_SIGNIN_REDIRECT_URI="/sso/callback") + @patch("vbv_lernwelt.sso.views.oauth") + def test_signin_language(self, mock_oauth): + # GIVEN + language = "fr" + + mock_oauth.signin.authorize_redirect = MagicMock() + mock_oauth.signin.authorize_redirect.return_value = redirect( + "/just/here/to/return/a/redirect/object" + ) + + # WHEN + self.client.get( + reverse("sso:login"), + {"lang": language}, + ) + + # THEN + mock_oauth.signin.authorize_redirect.assert_called_once_with( + ANY, + "/sso/callback", + state=ANY, + lang=language, + ) + + +class TestSignUp(TestCase): + @override_settings(OAUTH_SIGNUP_REDIRECT_URI="/sso/login") + @patch("vbv_lernwelt.sso.views.oauth") + def test_signup_with_course_param(self, mock_oauth): + # GIVEN + course_param = "vv" + language = "fr" + + expected_state = {"course": course_param, "next": None} + expected_state_encoded = base64.urlsafe_b64encode( + json.dumps(expected_state).encode() + ).decode() + + mock_oauth.signup.authorize_redirect = MagicMock() + mock_oauth.signup.authorize_redirect.return_value = redirect( + "/just/here/to/return/a/redirect/object" + ) + + # WHEN + self.client.get( + reverse("sso:signup"), + {"course": course_param, "lang": language}, + ) + + # THEN + mock_oauth.signup.authorize_redirect.assert_called_once_with( + ANY, + "/sso/login", + state=course_param, + lang=language, + ) diff --git a/server/vbv_lernwelt/sso/views.py b/server/vbv_lernwelt/sso/views.py index adb47d82..dc15459c 100644 --- a/server/vbv_lernwelt/sso/views.py +++ b/server/vbv_lernwelt/sso/views.py @@ -1,3 +1,6 @@ +import base64 +import json + import structlog as structlog from authlib.integrations.base_client import OAuthError from django.conf import settings @@ -5,6 +8,7 @@ from django.contrib.auth import login as dj_login from django.shortcuts import redirect from sentry_sdk import capture_exception +from vbv_lernwelt.core.models import User from vbv_lernwelt.course.models import CourseSession from vbv_lernwelt.course_session.utils import has_course_session_user_vv from vbv_lernwelt.importer.services import create_or_update_user @@ -13,24 +17,49 @@ from vbv_lernwelt.sso.jwt import decode_jwt logger = structlog.get_logger(__name__) -OAUTH_FAIL_REDIRECT = "login-error" - def signup(request): course = request.GET.get("course") + redirect_uri = settings.OAUTH_SIGNUP_REDIRECT_URI logger.debug(f"SSO Signup (course={course})", sso_signup_redirect_uri=redirect_uri) - return oauth.signup.authorize_redirect(request, redirect_uri, state=course) + + return oauth.signup.authorize_redirect( + request, redirect_uri, state=course, lang=request.GET.get("lang", "de") + ) def signin(request): # course query OR state when coming from signup (oauth) - course = request.GET.get("course", request.GET.get("state")) + course_param = request.GET.get("course", request.GET.get("state")) + next_param = request.GET.get("next") + + state_json = json.dumps({"course": course_param, "next": next_param}) + state_encoded = base64.urlsafe_b64encode(state_json.encode()).decode() + redirect_uri = settings.OAUTH_SIGNIN_REDIRECT_URI - logger.info(f"SSO Login (course={course})", sso_login_redirect_uri=redirect_uri) - return oauth.signin.authorize_redirect( - request, redirect_uri, state=course, lang=request.GET.get("lang", "de") + + logger.info( + f"SSO Login (course={course_param}, next={next_param})", + sso_login_redirect_uri=redirect_uri, ) + return oauth.signin.authorize_redirect( + request, redirect_uri, state=state_encoded, lang=request.GET.get("lang", "de") + ) + + +def get_redirect_uri(user: User, course: str, next_url: str): + if course == "vv" and not has_course_session_user_vv(user): + return redirect("/onboarding/vv/account/create") + elif ( + course == "uk" + and not CourseSession.objects.filter(coursesessionuser__user=user).exists() + ): + return redirect("/onboarding/uk/account/create") + elif next_url: + return redirect(next_url) + + return redirect("/") def authorize_signin(request): @@ -40,13 +69,19 @@ def authorize_signin(request): logger.error(e, exc_info=True, label="sso") if not settings.DEBUG: capture_exception(e) - return redirect(f"/{OAUTH_FAIL_REDIRECT}?state=oautherror") + return redirect("/") id_token = decode_jwt(jwt_token["id_token"]) - course = request.GET.get("state") + + state = json.loads( + base64.urlsafe_b64decode(request.GET.get("state").encode()).decode() + ) + + course = state.get("course") + next_url = state.get("next") logger.debug( - f"SSO Authorize (course={course})", + f"SSO Authorize (course={course}, next={next_url}", sso_authorize_id_token=id_token, ) @@ -59,13 +94,4 @@ def authorize_signin(request): dj_login(request, user) - # figure out where to redirect to (onboarding or home) - if course == "vv" and not has_course_session_user_vv(user): - return redirect("/onboarding/vv/account/create") - elif ( - course == "uk" - and not CourseSession.objects.filter(coursesessionuser__user=user).exists() - ): - return redirect("/onboarding/uk/account/create") - else: - return redirect("/") + return get_redirect_uri(user=user, course=course, next_url=next_url)