chore: adds test for sso flows
This commit is contained in:
parent
6f0f4551bc
commit
9fdc2faecd
|
|
@ -57,8 +57,8 @@ from vbv_lernwelt.learnpath.tests.learning_path_factories import (
|
|||
)
|
||||
|
||||
|
||||
def create_course(title: str) -> Tuple[Course, CoursePage]:
|
||||
course = Course.objects.create(title=title, category_name="Handlungsfeld")
|
||||
def create_course(title: str, _id=None) -> Tuple[Course, CoursePage]:
|
||||
course = Course.objects.create(id=_id, title=title, category_name="Handlungsfeld")
|
||||
|
||||
course_page = CoursePageFactory(
|
||||
title="Test Lehrgang",
|
||||
|
|
|
|||
|
|
@ -1,53 +1,351 @@
|
|||
import base64
|
||||
import json
|
||||
import uuid
|
||||
from unittest.mock import Mock, patch
|
||||
from unittest.mock import ANY, MagicMock, patch
|
||||
|
||||
from authlib.integrations.base_client import OAuthError
|
||||
from django.conf import settings
|
||||
from django.test import TestCase
|
||||
from django.shortcuts import redirect
|
||||
from django.test import override_settings, TestCase
|
||||
from django.urls import reverse
|
||||
|
||||
from vbv_lernwelt.core.models import User
|
||||
from vbv_lernwelt.course.consts import COURSE_VERSICHERUNGSVERMITTLERIN_ID
|
||||
from vbv_lernwelt.course.creators.test_utils import (
|
||||
add_course_session_user,
|
||||
create_course,
|
||||
create_course_session,
|
||||
)
|
||||
from vbv_lernwelt.course.models import CourseSession, CourseSessionUser
|
||||
from vbv_lernwelt.importer.services import create_or_update_user
|
||||
|
||||
|
||||
def decoded_token(email, oid=None, given_name="Bobby", family_name="Table"):
|
||||
return {
|
||||
"emails": [email],
|
||||
"email": email,
|
||||
"oid": oid or uuid.uuid4(),
|
||||
"given_name": given_name,
|
||||
"family_name": family_name,
|
||||
}
|
||||
|
||||
|
||||
class TestSSOFlow(TestCase):
|
||||
class TestSignInAuthorizeSSO(TestCase):
|
||||
def setUp(self):
|
||||
CourseSession.objects.all().delete()
|
||||
User.objects.all().delete()
|
||||
|
||||
@override_settings(OAUTH={"client_name": "mock"})
|
||||
@patch("vbv_lernwelt.sso.views.oauth")
|
||||
@patch("vbv_lernwelt.sso.views.decode_jwt")
|
||||
def test_authorize_redirects_on_success(self, mock_decode_jwt, _):
|
||||
def test_authorize_first_time_uk(self, mock_decode_jwt, mock_oauth):
|
||||
# GIVEN
|
||||
email = "bobby@drop.table"
|
||||
mock_decode_jwt.return_value = decoded_token(email)
|
||||
email = "user@example.com"
|
||||
token = decoded_token(email)
|
||||
|
||||
mock_decode_jwt.return_value = token
|
||||
mock_oauth.signin.authorize_access_token.return_value = {
|
||||
"id_token": "<some-token-irrelevant-for-this-test>"
|
||||
}
|
||||
|
||||
state = base64.urlsafe_b64encode(
|
||||
json.dumps({"course": "uk", "next": "/shall-be-ignored"}).encode()
|
||||
).decode()
|
||||
|
||||
# WHEN
|
||||
response = self.client.get(reverse("sso:authorize"))
|
||||
response = self.client.get(reverse("sso:authorize"), {"state": state})
|
||||
|
||||
# THEN
|
||||
self.assertTrue(User.objects.filter(email=email).exists())
|
||||
self.assertEqual(response.status_code, 302)
|
||||
self.assertEqual(response.url, "/") # noqa
|
||||
self.assertEqual(302, response.status_code)
|
||||
self.assertEqual("/onboarding/uk/account/create", response.url) # noqa
|
||||
|
||||
user = User.objects.get(email=email) # noqa
|
||||
self.assertIsNotNone(user)
|
||||
|
||||
self.assertEqual(user.email, email)
|
||||
self.assertEqual(user.first_name, token["given_name"])
|
||||
self.assertEqual(user.last_name, token["family_name"])
|
||||
self.assertEqual(user.sso_id, token["oid"])
|
||||
|
||||
@override_settings(OAUTH={"client_name": "mock"})
|
||||
@patch("vbv_lernwelt.sso.views.oauth")
|
||||
@patch("vbv_lernwelt.sso.views.decode_jwt")
|
||||
def test_authorize_first_time_vv(self, mock_decode_jwt, mock_oauth):
|
||||
# GIVEN
|
||||
email = "user@example.com"
|
||||
token = decoded_token(email)
|
||||
|
||||
mock_decode_jwt.return_value = token
|
||||
mock_oauth.signin.authorize_access_token.return_value = {
|
||||
"id_token": "<some-token-irrelevant-for-this-test>"
|
||||
}
|
||||
|
||||
state = base64.urlsafe_b64encode(
|
||||
json.dumps({"course": "vv", "next": "/shall-be-ignored"}).encode()
|
||||
).decode()
|
||||
|
||||
# WHEN
|
||||
response = self.client.get(reverse("sso:authorize"), {"state": state})
|
||||
|
||||
# THEN
|
||||
self.assertEqual(302, response.status_code)
|
||||
self.assertEqual("/onboarding/vv/account/create", response.url) # noqa
|
||||
|
||||
user = User.objects.get(email=email) # noqa
|
||||
self.assertIsNotNone(user)
|
||||
|
||||
self.assertEqual(user.email, email)
|
||||
self.assertEqual(user.first_name, token["given_name"])
|
||||
self.assertEqual(user.last_name, token["family_name"])
|
||||
self.assertEqual(user.sso_id, token["oid"])
|
||||
|
||||
@patch("vbv_lernwelt.sso.views.oauth")
|
||||
def test_authorize_on_tampered_token(self, mock_oauth_service):
|
||||
# GIVEN
|
||||
client_name = settings.OAUTH["client_name"]
|
||||
client_mock = Mock()
|
||||
client_mock.authorize_access_token.side_effect = OAuthError()
|
||||
setattr(mock_oauth_service, client_name, client_mock)
|
||||
mock_oauth_service.signin.authorize_access_token.side_effect = OAuthError()
|
||||
|
||||
# WHEN
|
||||
response = self.client.get(reverse("sso:authorize"))
|
||||
|
||||
# THEN
|
||||
# sanity check that the mock was called (-> setup is correct)
|
||||
self.assertEqual(client_mock.authorize_access_token.call_count, 1)
|
||||
self.assertEqual(302, response.status_code)
|
||||
self.assertEqual("/", response.url) # noqa
|
||||
|
||||
self.assertEqual(response.status_code, 302)
|
||||
self.assertEqual(response.url, "/login-error?state=someerror") # noqa
|
||||
@override_settings(OAUTH={"client_name": "mock"})
|
||||
@patch("vbv_lernwelt.sso.views.oauth")
|
||||
@patch("vbv_lernwelt.sso.views.decode_jwt")
|
||||
def test_authorize_onboarded_uk(self, mock_decode_jwt, mock_oauth):
|
||||
# GIVEN
|
||||
email = "some@email.com"
|
||||
token = decoded_token(email)
|
||||
mock_decode_jwt.return_value = token
|
||||
mock_oauth.signin.authorize_access_token.return_value = {
|
||||
"id_token": "<some-token-irrelevant-for-this-test>"
|
||||
}
|
||||
|
||||
# create a user that is already onboarded for UK
|
||||
user = create_or_update_user(
|
||||
email=email,
|
||||
sso_id=str(token["oid"]),
|
||||
first_name=token["given_name"],
|
||||
last_name=token["family_name"],
|
||||
)
|
||||
|
||||
course, _ = create_course("uk")
|
||||
course_session = create_course_session(course, "UK", "2023")
|
||||
add_course_session_user(course_session, user, CourseSessionUser.Role.MEMBER)
|
||||
|
||||
self.assertIsNotNone(User.objects.get(email=email))
|
||||
|
||||
# WHEN
|
||||
state = base64.urlsafe_b64encode(json.dumps({"course": "uk"}).encode()).decode()
|
||||
response = self.client.get(reverse("sso:authorize"), {"state": state})
|
||||
|
||||
# THEN
|
||||
self.assertEqual(302, response.status_code)
|
||||
self.assertEqual("/", response.url) # noqa
|
||||
|
||||
@override_settings(OAUTH={"client_name": "mock"})
|
||||
@patch("vbv_lernwelt.sso.views.oauth")
|
||||
@patch("vbv_lernwelt.sso.views.decode_jwt")
|
||||
def test_authorize_onboarded_vv(self, mock_decode_jwt, mock_oauth):
|
||||
# GIVEN
|
||||
email = "some@email.com"
|
||||
token = decoded_token(email)
|
||||
mock_decode_jwt.return_value = token
|
||||
mock_oauth.signin.authorize_access_token.return_value = {
|
||||
"id_token": "<some-token-irrelevant-for-this-test>"
|
||||
}
|
||||
|
||||
# create a user that is already onboarded for UK
|
||||
user = create_or_update_user(
|
||||
email=email,
|
||||
sso_id=str(token["oid"]),
|
||||
first_name=token["given_name"],
|
||||
last_name=token["family_name"],
|
||||
)
|
||||
|
||||
course, _ = create_course(_id=COURSE_VERSICHERUNGSVERMITTLERIN_ID, title="VV")
|
||||
course_session = create_course_session(course, "VV", "VV")
|
||||
add_course_session_user(course_session, user, CourseSessionUser.Role.MEMBER)
|
||||
|
||||
self.assertIsNotNone(User.objects.get(email=email))
|
||||
|
||||
# WHEN
|
||||
state = base64.urlsafe_b64encode(json.dumps({"course": "vv"}).encode()).decode()
|
||||
response = self.client.get(reverse("sso:authorize"), {"state": state})
|
||||
|
||||
# THEN
|
||||
self.assertEqual(302, response.status_code)
|
||||
self.assertEqual("/", response.url) # noqa
|
||||
|
||||
@override_settings(OAUTH={"client_name": "mock"})
|
||||
@patch("vbv_lernwelt.sso.views.oauth")
|
||||
@patch("vbv_lernwelt.sso.views.decode_jwt")
|
||||
def test_authorize_next_url(self, mock_decode_jwt, mock_oauth):
|
||||
# GIVEN
|
||||
next_url = "/some/next/url"
|
||||
|
||||
mock_decode_jwt.return_value = decoded_token("whatever@example.com")
|
||||
mock_oauth.signin.authorize_access_token.return_value = {
|
||||
"id_token": "<some-token-irrelevant-for-this-test>"
|
||||
}
|
||||
|
||||
# WHEN
|
||||
state = base64.urlsafe_b64encode(
|
||||
json.dumps({"next": next_url}).encode()
|
||||
).decode()
|
||||
|
||||
response = self.client.get(reverse("sso:authorize"), {"state": state})
|
||||
|
||||
# THEN
|
||||
self.assertEqual(302, response.status_code)
|
||||
self.assertEqual(next_url, response.url) # noqa
|
||||
|
||||
|
||||
class TestSignIn(TestCase):
|
||||
@override_settings(OAUTH_SIGNIN_REDIRECT_URI="/sso/callback")
|
||||
@patch("vbv_lernwelt.sso.views.oauth")
|
||||
def test_signin_with_course_param(self, mock_oauth):
|
||||
# GIVEN
|
||||
course_param = "vv"
|
||||
|
||||
expected_state = {"course": course_param, "next": None}
|
||||
expected_state_encoded = base64.urlsafe_b64encode(
|
||||
json.dumps(expected_state).encode()
|
||||
).decode()
|
||||
|
||||
mock_oauth.signin.authorize_redirect = MagicMock()
|
||||
mock_oauth.signin.authorize_redirect.return_value = redirect(
|
||||
"/just/here/to/return/a/redirect/object"
|
||||
)
|
||||
|
||||
# WHEN
|
||||
self.client.get(
|
||||
reverse("sso:login"),
|
||||
{"course": course_param},
|
||||
)
|
||||
|
||||
# THEN
|
||||
mock_oauth.signin.authorize_redirect.assert_called_once_with(
|
||||
ANY,
|
||||
"/sso/callback",
|
||||
state=expected_state_encoded,
|
||||
lang="de",
|
||||
)
|
||||
|
||||
@override_settings(OAUTH_SIGNIN_REDIRECT_URI="/sso/callback")
|
||||
@patch("vbv_lernwelt.sso.views.oauth")
|
||||
def test_signin_with_state_param(self, mock_oauth):
|
||||
# GIVEN
|
||||
state_param = "vv"
|
||||
|
||||
expected_state = {"course": state_param, "next": None}
|
||||
expected_state_encoded = base64.urlsafe_b64encode(
|
||||
json.dumps(expected_state).encode()
|
||||
).decode()
|
||||
|
||||
mock_oauth.signin.authorize_redirect = MagicMock()
|
||||
mock_oauth.signin.authorize_redirect.return_value = redirect(
|
||||
"/just/here/to/return/a/redirect/object"
|
||||
)
|
||||
|
||||
# WHEN
|
||||
self.client.get(
|
||||
reverse("sso:login"),
|
||||
{"state": state_param},
|
||||
)
|
||||
|
||||
# THEN
|
||||
mock_oauth.signin.authorize_redirect.assert_called_once_with(
|
||||
ANY,
|
||||
"/sso/callback",
|
||||
state=expected_state_encoded,
|
||||
lang="de",
|
||||
)
|
||||
|
||||
@override_settings(OAUTH_SIGNIN_REDIRECT_URI="/sso/callback")
|
||||
@patch("vbv_lernwelt.sso.views.oauth")
|
||||
def test_signin_next_url(self, mock_oauth):
|
||||
# GIVEN
|
||||
next_url = "/some/next/url"
|
||||
|
||||
expected_state = {"course": None, "next": next_url}
|
||||
expected_state_encoded = base64.urlsafe_b64encode(
|
||||
json.dumps(expected_state).encode()
|
||||
).decode()
|
||||
|
||||
mock_oauth.signin.authorize_redirect = MagicMock()
|
||||
mock_oauth.signin.authorize_redirect.return_value = redirect(
|
||||
"/just/here/to/return/a/redirect/object"
|
||||
)
|
||||
|
||||
# WHEN
|
||||
self.client.get(
|
||||
reverse("sso:login"),
|
||||
{"next": next_url},
|
||||
)
|
||||
|
||||
# THEN
|
||||
mock_oauth.signin.authorize_redirect.assert_called_once_with(
|
||||
ANY,
|
||||
"/sso/callback",
|
||||
state=expected_state_encoded,
|
||||
lang=ANY,
|
||||
)
|
||||
|
||||
@override_settings(OAUTH_SIGNIN_REDIRECT_URI="/sso/callback")
|
||||
@patch("vbv_lernwelt.sso.views.oauth")
|
||||
def test_signin_language(self, mock_oauth):
|
||||
# GIVEN
|
||||
language = "fr"
|
||||
|
||||
mock_oauth.signin.authorize_redirect = MagicMock()
|
||||
mock_oauth.signin.authorize_redirect.return_value = redirect(
|
||||
"/just/here/to/return/a/redirect/object"
|
||||
)
|
||||
|
||||
# WHEN
|
||||
self.client.get(
|
||||
reverse("sso:login"),
|
||||
{"lang": language},
|
||||
)
|
||||
|
||||
# THEN
|
||||
mock_oauth.signin.authorize_redirect.assert_called_once_with(
|
||||
ANY,
|
||||
"/sso/callback",
|
||||
state=ANY,
|
||||
lang=language,
|
||||
)
|
||||
|
||||
|
||||
class TestSignUp(TestCase):
|
||||
@override_settings(OAUTH_SIGNUP_REDIRECT_URI="/sso/login")
|
||||
@patch("vbv_lernwelt.sso.views.oauth")
|
||||
def test_signup_with_course_param(self, mock_oauth):
|
||||
# GIVEN
|
||||
course_param = "vv"
|
||||
language = "fr"
|
||||
|
||||
expected_state = {"course": course_param, "next": None}
|
||||
expected_state_encoded = base64.urlsafe_b64encode(
|
||||
json.dumps(expected_state).encode()
|
||||
).decode()
|
||||
|
||||
mock_oauth.signup.authorize_redirect = MagicMock()
|
||||
mock_oauth.signup.authorize_redirect.return_value = redirect(
|
||||
"/just/here/to/return/a/redirect/object"
|
||||
)
|
||||
|
||||
# WHEN
|
||||
self.client.get(
|
||||
reverse("sso:signup"),
|
||||
{"course": course_param, "lang": language},
|
||||
)
|
||||
|
||||
# THEN
|
||||
mock_oauth.signup.authorize_redirect.assert_called_once_with(
|
||||
ANY,
|
||||
"/sso/login",
|
||||
state=course_param,
|
||||
lang=language,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,3 +1,6 @@
|
|||
import base64
|
||||
import json
|
||||
|
||||
import structlog as structlog
|
||||
from authlib.integrations.base_client import OAuthError
|
||||
from django.conf import settings
|
||||
|
|
@ -5,6 +8,7 @@ from django.contrib.auth import login as dj_login
|
|||
from django.shortcuts import redirect
|
||||
from sentry_sdk import capture_exception
|
||||
|
||||
from vbv_lernwelt.core.models import User
|
||||
from vbv_lernwelt.course.models import CourseSession
|
||||
from vbv_lernwelt.course_session.utils import has_course_session_user_vv
|
||||
from vbv_lernwelt.importer.services import create_or_update_user
|
||||
|
|
@ -13,24 +17,49 @@ from vbv_lernwelt.sso.jwt import decode_jwt
|
|||
|
||||
logger = structlog.get_logger(__name__)
|
||||
|
||||
OAUTH_FAIL_REDIRECT = "login-error"
|
||||
|
||||
|
||||
def signup(request):
|
||||
course = request.GET.get("course")
|
||||
|
||||
redirect_uri = settings.OAUTH_SIGNUP_REDIRECT_URI
|
||||
logger.debug(f"SSO Signup (course={course})", sso_signup_redirect_uri=redirect_uri)
|
||||
return oauth.signup.authorize_redirect(request, redirect_uri, state=course)
|
||||
|
||||
return oauth.signup.authorize_redirect(
|
||||
request, redirect_uri, state=course, lang=request.GET.get("lang", "de")
|
||||
)
|
||||
|
||||
|
||||
def signin(request):
|
||||
# course query OR state when coming from signup (oauth)
|
||||
course = request.GET.get("course", request.GET.get("state"))
|
||||
course_param = request.GET.get("course", request.GET.get("state"))
|
||||
next_param = request.GET.get("next")
|
||||
|
||||
state_json = json.dumps({"course": course_param, "next": next_param})
|
||||
state_encoded = base64.urlsafe_b64encode(state_json.encode()).decode()
|
||||
|
||||
redirect_uri = settings.OAUTH_SIGNIN_REDIRECT_URI
|
||||
logger.info(f"SSO Login (course={course})", sso_login_redirect_uri=redirect_uri)
|
||||
return oauth.signin.authorize_redirect(
|
||||
request, redirect_uri, state=course, lang=request.GET.get("lang", "de")
|
||||
|
||||
logger.info(
|
||||
f"SSO Login (course={course_param}, next={next_param})",
|
||||
sso_login_redirect_uri=redirect_uri,
|
||||
)
|
||||
return oauth.signin.authorize_redirect(
|
||||
request, redirect_uri, state=state_encoded, lang=request.GET.get("lang", "de")
|
||||
)
|
||||
|
||||
|
||||
def get_redirect_uri(user: User, course: str, next_url: str):
|
||||
if course == "vv" and not has_course_session_user_vv(user):
|
||||
return redirect("/onboarding/vv/account/create")
|
||||
elif (
|
||||
course == "uk"
|
||||
and not CourseSession.objects.filter(coursesessionuser__user=user).exists()
|
||||
):
|
||||
return redirect("/onboarding/uk/account/create")
|
||||
elif next_url:
|
||||
return redirect(next_url)
|
||||
|
||||
return redirect("/")
|
||||
|
||||
|
||||
def authorize_signin(request):
|
||||
|
|
@ -40,13 +69,19 @@ def authorize_signin(request):
|
|||
logger.error(e, exc_info=True, label="sso")
|
||||
if not settings.DEBUG:
|
||||
capture_exception(e)
|
||||
return redirect(f"/{OAUTH_FAIL_REDIRECT}?state=oautherror")
|
||||
return redirect("/")
|
||||
|
||||
id_token = decode_jwt(jwt_token["id_token"])
|
||||
course = request.GET.get("state")
|
||||
|
||||
state = json.loads(
|
||||
base64.urlsafe_b64decode(request.GET.get("state").encode()).decode()
|
||||
)
|
||||
|
||||
course = state.get("course")
|
||||
next_url = state.get("next")
|
||||
|
||||
logger.debug(
|
||||
f"SSO Authorize (course={course})",
|
||||
f"SSO Authorize (course={course}, next={next_url}",
|
||||
sso_authorize_id_token=id_token,
|
||||
)
|
||||
|
||||
|
|
@ -59,13 +94,4 @@ def authorize_signin(request):
|
|||
|
||||
dj_login(request, user)
|
||||
|
||||
# figure out where to redirect to (onboarding or home)
|
||||
if course == "vv" and not has_course_session_user_vv(user):
|
||||
return redirect("/onboarding/vv/account/create")
|
||||
elif (
|
||||
course == "uk"
|
||||
and not CourseSession.objects.filter(coursesessionuser__user=user).exists()
|
||||
):
|
||||
return redirect("/onboarding/uk/account/create")
|
||||
else:
|
||||
return redirect("/")
|
||||
return get_redirect_uri(user=user, course=course, next_url=next_url)
|
||||
|
|
|
|||
Loading…
Reference in New Issue