chore: adds test for sso flows

This commit is contained in:
Livio Bieri 2023-12-04 15:00:51 +01:00 committed by Christian Cueni
parent 6f0f4551bc
commit 9fdc2faecd
3 changed files with 366 additions and 42 deletions

View File

@ -57,8 +57,8 @@ from vbv_lernwelt.learnpath.tests.learning_path_factories import (
)
def create_course(title: str) -> Tuple[Course, CoursePage]:
course = Course.objects.create(title=title, category_name="Handlungsfeld")
def create_course(title: str, _id=None) -> Tuple[Course, CoursePage]:
course = Course.objects.create(id=_id, title=title, category_name="Handlungsfeld")
course_page = CoursePageFactory(
title="Test Lehrgang",

View File

@ -1,53 +1,351 @@
import base64
import json
import uuid
from unittest.mock import Mock, patch
from unittest.mock import ANY, MagicMock, patch
from authlib.integrations.base_client import OAuthError
from django.conf import settings
from django.test import TestCase
from django.shortcuts import redirect
from django.test import override_settings, TestCase
from django.urls import reverse
from vbv_lernwelt.core.models import User
from vbv_lernwelt.course.consts import COURSE_VERSICHERUNGSVERMITTLERIN_ID
from vbv_lernwelt.course.creators.test_utils import (
add_course_session_user,
create_course,
create_course_session,
)
from vbv_lernwelt.course.models import CourseSession, CourseSessionUser
from vbv_lernwelt.importer.services import create_or_update_user
def decoded_token(email, oid=None, given_name="Bobby", family_name="Table"):
return {
"emails": [email],
"email": email,
"oid": oid or uuid.uuid4(),
"given_name": given_name,
"family_name": family_name,
}
class TestSSOFlow(TestCase):
class TestSignInAuthorizeSSO(TestCase):
def setUp(self):
CourseSession.objects.all().delete()
User.objects.all().delete()
@override_settings(OAUTH={"client_name": "mock"})
@patch("vbv_lernwelt.sso.views.oauth")
@patch("vbv_lernwelt.sso.views.decode_jwt")
def test_authorize_redirects_on_success(self, mock_decode_jwt, _):
def test_authorize_first_time_uk(self, mock_decode_jwt, mock_oauth):
# GIVEN
email = "bobby@drop.table"
mock_decode_jwt.return_value = decoded_token(email)
email = "user@example.com"
token = decoded_token(email)
mock_decode_jwt.return_value = token
mock_oauth.signin.authorize_access_token.return_value = {
"id_token": "<some-token-irrelevant-for-this-test>"
}
state = base64.urlsafe_b64encode(
json.dumps({"course": "uk", "next": "/shall-be-ignored"}).encode()
).decode()
# WHEN
response = self.client.get(reverse("sso:authorize"))
response = self.client.get(reverse("sso:authorize"), {"state": state})
# THEN
self.assertTrue(User.objects.filter(email=email).exists())
self.assertEqual(response.status_code, 302)
self.assertEqual(response.url, "/") # noqa
self.assertEqual(302, response.status_code)
self.assertEqual("/onboarding/uk/account/create", response.url) # noqa
user = User.objects.get(email=email) # noqa
self.assertIsNotNone(user)
self.assertEqual(user.email, email)
self.assertEqual(user.first_name, token["given_name"])
self.assertEqual(user.last_name, token["family_name"])
self.assertEqual(user.sso_id, token["oid"])
@override_settings(OAUTH={"client_name": "mock"})
@patch("vbv_lernwelt.sso.views.oauth")
@patch("vbv_lernwelt.sso.views.decode_jwt")
def test_authorize_first_time_vv(self, mock_decode_jwt, mock_oauth):
# GIVEN
email = "user@example.com"
token = decoded_token(email)
mock_decode_jwt.return_value = token
mock_oauth.signin.authorize_access_token.return_value = {
"id_token": "<some-token-irrelevant-for-this-test>"
}
state = base64.urlsafe_b64encode(
json.dumps({"course": "vv", "next": "/shall-be-ignored"}).encode()
).decode()
# WHEN
response = self.client.get(reverse("sso:authorize"), {"state": state})
# THEN
self.assertEqual(302, response.status_code)
self.assertEqual("/onboarding/vv/account/create", response.url) # noqa
user = User.objects.get(email=email) # noqa
self.assertIsNotNone(user)
self.assertEqual(user.email, email)
self.assertEqual(user.first_name, token["given_name"])
self.assertEqual(user.last_name, token["family_name"])
self.assertEqual(user.sso_id, token["oid"])
@patch("vbv_lernwelt.sso.views.oauth")
def test_authorize_on_tampered_token(self, mock_oauth_service):
# GIVEN
client_name = settings.OAUTH["client_name"]
client_mock = Mock()
client_mock.authorize_access_token.side_effect = OAuthError()
setattr(mock_oauth_service, client_name, client_mock)
mock_oauth_service.signin.authorize_access_token.side_effect = OAuthError()
# WHEN
response = self.client.get(reverse("sso:authorize"))
# THEN
# sanity check that the mock was called (-> setup is correct)
self.assertEqual(client_mock.authorize_access_token.call_count, 1)
self.assertEqual(302, response.status_code)
self.assertEqual("/", response.url) # noqa
self.assertEqual(response.status_code, 302)
self.assertEqual(response.url, "/login-error?state=someerror") # noqa
@override_settings(OAUTH={"client_name": "mock"})
@patch("vbv_lernwelt.sso.views.oauth")
@patch("vbv_lernwelt.sso.views.decode_jwt")
def test_authorize_onboarded_uk(self, mock_decode_jwt, mock_oauth):
# GIVEN
email = "some@email.com"
token = decoded_token(email)
mock_decode_jwt.return_value = token
mock_oauth.signin.authorize_access_token.return_value = {
"id_token": "<some-token-irrelevant-for-this-test>"
}
# create a user that is already onboarded for UK
user = create_or_update_user(
email=email,
sso_id=str(token["oid"]),
first_name=token["given_name"],
last_name=token["family_name"],
)
course, _ = create_course("uk")
course_session = create_course_session(course, "UK", "2023")
add_course_session_user(course_session, user, CourseSessionUser.Role.MEMBER)
self.assertIsNotNone(User.objects.get(email=email))
# WHEN
state = base64.urlsafe_b64encode(json.dumps({"course": "uk"}).encode()).decode()
response = self.client.get(reverse("sso:authorize"), {"state": state})
# THEN
self.assertEqual(302, response.status_code)
self.assertEqual("/", response.url) # noqa
@override_settings(OAUTH={"client_name": "mock"})
@patch("vbv_lernwelt.sso.views.oauth")
@patch("vbv_lernwelt.sso.views.decode_jwt")
def test_authorize_onboarded_vv(self, mock_decode_jwt, mock_oauth):
# GIVEN
email = "some@email.com"
token = decoded_token(email)
mock_decode_jwt.return_value = token
mock_oauth.signin.authorize_access_token.return_value = {
"id_token": "<some-token-irrelevant-for-this-test>"
}
# create a user that is already onboarded for UK
user = create_or_update_user(
email=email,
sso_id=str(token["oid"]),
first_name=token["given_name"],
last_name=token["family_name"],
)
course, _ = create_course(_id=COURSE_VERSICHERUNGSVERMITTLERIN_ID, title="VV")
course_session = create_course_session(course, "VV", "VV")
add_course_session_user(course_session, user, CourseSessionUser.Role.MEMBER)
self.assertIsNotNone(User.objects.get(email=email))
# WHEN
state = base64.urlsafe_b64encode(json.dumps({"course": "vv"}).encode()).decode()
response = self.client.get(reverse("sso:authorize"), {"state": state})
# THEN
self.assertEqual(302, response.status_code)
self.assertEqual("/", response.url) # noqa
@override_settings(OAUTH={"client_name": "mock"})
@patch("vbv_lernwelt.sso.views.oauth")
@patch("vbv_lernwelt.sso.views.decode_jwt")
def test_authorize_next_url(self, mock_decode_jwt, mock_oauth):
# GIVEN
next_url = "/some/next/url"
mock_decode_jwt.return_value = decoded_token("whatever@example.com")
mock_oauth.signin.authorize_access_token.return_value = {
"id_token": "<some-token-irrelevant-for-this-test>"
}
# WHEN
state = base64.urlsafe_b64encode(
json.dumps({"next": next_url}).encode()
).decode()
response = self.client.get(reverse("sso:authorize"), {"state": state})
# THEN
self.assertEqual(302, response.status_code)
self.assertEqual(next_url, response.url) # noqa
class TestSignIn(TestCase):
@override_settings(OAUTH_SIGNIN_REDIRECT_URI="/sso/callback")
@patch("vbv_lernwelt.sso.views.oauth")
def test_signin_with_course_param(self, mock_oauth):
# GIVEN
course_param = "vv"
expected_state = {"course": course_param, "next": None}
expected_state_encoded = base64.urlsafe_b64encode(
json.dumps(expected_state).encode()
).decode()
mock_oauth.signin.authorize_redirect = MagicMock()
mock_oauth.signin.authorize_redirect.return_value = redirect(
"/just/here/to/return/a/redirect/object"
)
# WHEN
self.client.get(
reverse("sso:login"),
{"course": course_param},
)
# THEN
mock_oauth.signin.authorize_redirect.assert_called_once_with(
ANY,
"/sso/callback",
state=expected_state_encoded,
lang="de",
)
@override_settings(OAUTH_SIGNIN_REDIRECT_URI="/sso/callback")
@patch("vbv_lernwelt.sso.views.oauth")
def test_signin_with_state_param(self, mock_oauth):
# GIVEN
state_param = "vv"
expected_state = {"course": state_param, "next": None}
expected_state_encoded = base64.urlsafe_b64encode(
json.dumps(expected_state).encode()
).decode()
mock_oauth.signin.authorize_redirect = MagicMock()
mock_oauth.signin.authorize_redirect.return_value = redirect(
"/just/here/to/return/a/redirect/object"
)
# WHEN
self.client.get(
reverse("sso:login"),
{"state": state_param},
)
# THEN
mock_oauth.signin.authorize_redirect.assert_called_once_with(
ANY,
"/sso/callback",
state=expected_state_encoded,
lang="de",
)
@override_settings(OAUTH_SIGNIN_REDIRECT_URI="/sso/callback")
@patch("vbv_lernwelt.sso.views.oauth")
def test_signin_next_url(self, mock_oauth):
# GIVEN
next_url = "/some/next/url"
expected_state = {"course": None, "next": next_url}
expected_state_encoded = base64.urlsafe_b64encode(
json.dumps(expected_state).encode()
).decode()
mock_oauth.signin.authorize_redirect = MagicMock()
mock_oauth.signin.authorize_redirect.return_value = redirect(
"/just/here/to/return/a/redirect/object"
)
# WHEN
self.client.get(
reverse("sso:login"),
{"next": next_url},
)
# THEN
mock_oauth.signin.authorize_redirect.assert_called_once_with(
ANY,
"/sso/callback",
state=expected_state_encoded,
lang=ANY,
)
@override_settings(OAUTH_SIGNIN_REDIRECT_URI="/sso/callback")
@patch("vbv_lernwelt.sso.views.oauth")
def test_signin_language(self, mock_oauth):
# GIVEN
language = "fr"
mock_oauth.signin.authorize_redirect = MagicMock()
mock_oauth.signin.authorize_redirect.return_value = redirect(
"/just/here/to/return/a/redirect/object"
)
# WHEN
self.client.get(
reverse("sso:login"),
{"lang": language},
)
# THEN
mock_oauth.signin.authorize_redirect.assert_called_once_with(
ANY,
"/sso/callback",
state=ANY,
lang=language,
)
class TestSignUp(TestCase):
@override_settings(OAUTH_SIGNUP_REDIRECT_URI="/sso/login")
@patch("vbv_lernwelt.sso.views.oauth")
def test_signup_with_course_param(self, mock_oauth):
# GIVEN
course_param = "vv"
language = "fr"
expected_state = {"course": course_param, "next": None}
expected_state_encoded = base64.urlsafe_b64encode(
json.dumps(expected_state).encode()
).decode()
mock_oauth.signup.authorize_redirect = MagicMock()
mock_oauth.signup.authorize_redirect.return_value = redirect(
"/just/here/to/return/a/redirect/object"
)
# WHEN
self.client.get(
reverse("sso:signup"),
{"course": course_param, "lang": language},
)
# THEN
mock_oauth.signup.authorize_redirect.assert_called_once_with(
ANY,
"/sso/login",
state=course_param,
lang=language,
)

View File

@ -1,3 +1,6 @@
import base64
import json
import structlog as structlog
from authlib.integrations.base_client import OAuthError
from django.conf import settings
@ -5,6 +8,7 @@ from django.contrib.auth import login as dj_login
from django.shortcuts import redirect
from sentry_sdk import capture_exception
from vbv_lernwelt.core.models import User
from vbv_lernwelt.course.models import CourseSession
from vbv_lernwelt.course_session.utils import has_course_session_user_vv
from vbv_lernwelt.importer.services import create_or_update_user
@ -13,24 +17,49 @@ from vbv_lernwelt.sso.jwt import decode_jwt
logger = structlog.get_logger(__name__)
OAUTH_FAIL_REDIRECT = "login-error"
def signup(request):
course = request.GET.get("course")
redirect_uri = settings.OAUTH_SIGNUP_REDIRECT_URI
logger.debug(f"SSO Signup (course={course})", sso_signup_redirect_uri=redirect_uri)
return oauth.signup.authorize_redirect(request, redirect_uri, state=course)
return oauth.signup.authorize_redirect(
request, redirect_uri, state=course, lang=request.GET.get("lang", "de")
)
def signin(request):
# course query OR state when coming from signup (oauth)
course = request.GET.get("course", request.GET.get("state"))
course_param = request.GET.get("course", request.GET.get("state"))
next_param = request.GET.get("next")
state_json = json.dumps({"course": course_param, "next": next_param})
state_encoded = base64.urlsafe_b64encode(state_json.encode()).decode()
redirect_uri = settings.OAUTH_SIGNIN_REDIRECT_URI
logger.info(f"SSO Login (course={course})", sso_login_redirect_uri=redirect_uri)
return oauth.signin.authorize_redirect(
request, redirect_uri, state=course, lang=request.GET.get("lang", "de")
logger.info(
f"SSO Login (course={course_param}, next={next_param})",
sso_login_redirect_uri=redirect_uri,
)
return oauth.signin.authorize_redirect(
request, redirect_uri, state=state_encoded, lang=request.GET.get("lang", "de")
)
def get_redirect_uri(user: User, course: str, next_url: str):
if course == "vv" and not has_course_session_user_vv(user):
return redirect("/onboarding/vv/account/create")
elif (
course == "uk"
and not CourseSession.objects.filter(coursesessionuser__user=user).exists()
):
return redirect("/onboarding/uk/account/create")
elif next_url:
return redirect(next_url)
return redirect("/")
def authorize_signin(request):
@ -40,13 +69,19 @@ def authorize_signin(request):
logger.error(e, exc_info=True, label="sso")
if not settings.DEBUG:
capture_exception(e)
return redirect(f"/{OAUTH_FAIL_REDIRECT}?state=oautherror")
return redirect("/")
id_token = decode_jwt(jwt_token["id_token"])
course = request.GET.get("state")
state = json.loads(
base64.urlsafe_b64decode(request.GET.get("state").encode()).decode()
)
course = state.get("course")
next_url = state.get("next")
logger.debug(
f"SSO Authorize (course={course})",
f"SSO Authorize (course={course}, next={next_url}",
sso_authorize_id_token=id_token,
)
@ -59,13 +94,4 @@ def authorize_signin(request):
dj_login(request, user)
# figure out where to redirect to (onboarding or home)
if course == "vv" and not has_course_session_user_vv(user):
return redirect("/onboarding/vv/account/create")
elif (
course == "uk"
and not CourseSession.objects.filter(coursesessionuser__user=user).exists()
):
return redirect("/onboarding/uk/account/create")
else:
return redirect("/")
return get_redirect_uri(user=user, course=course, next_url=next_url)