Add refresh tests

This commit is contained in:
Christian Cueni 2021-06-17 16:53:53 +02:00
parent 198a0fa33c
commit 9b7c74e7f7
8 changed files with 89 additions and 35 deletions

View File

@ -400,6 +400,7 @@ ALLOW_BETA_LOGIN = True
HEP_URL = os.environ.get("HEP_URL")
# HEP Oauth
OAUTH_API_BASE_URL = os.environ.get("OAUTH_API_BASE_URL")
AUTHLIB_OAUTH_CLIENTS = {
'hep': {
'client_id': os.environ.get("OAUTH_CLIENT_ID"),
@ -410,7 +411,7 @@ AUTHLIB_OAUTH_CLIENTS = {
'access_token_params': None,
'refresh_token_url': None,
'authorize_url': os.environ.get("OAUTH_AUTHORIZE_URL"),
'api_base_url': os.environ.get("OAUTH_API_BASE_URL"),
'api_base_url': OAUTH_API_BASE_URL,
'client_kwargs': {
'scope': 'orders',
'token_endpoint_auth_method': 'client_secret_post',

View File

@ -7,7 +7,7 @@ from django.utils import timezone
from oauth.models import OAuth2Token
IN_A_HOUR = timezone.now() + timedelta(hours=1)
IN_A_HOUR_UNIX = time.mktime(IN_A_HOUR.timetuple())
IN_A_HOUR_UNIX = int(time.mktime(IN_A_HOUR.timetuple()))
class Oauth2TokenFactory(factory.django.DjangoModelFactory):

View File

@ -54,12 +54,13 @@ class HepClient:
if not token_dict:
raise HepClientNoTokenException
if not OAuth2Token.is_dict_valid(token_dict):
token, refresh_success = self._refresh_token(token_dict)
if OAuth2Token.is_dict_expired(token_dict):
refresh_data = self._refresh_token(token_dict)
token, refresh_success = OAuth2Token.update_dict_with_refresh_data(refresh_data, token_dict['access_token'])
if not refresh_success:
raise HepClientUnauthorizedException
token_dict = token.to_dict()
token_dict = token.to_token()
return token_dict
@ -67,7 +68,7 @@ class HepClient:
data = {
'grant_type': 'refresh_token',
'refresh_token': token_dict.refresh_token,
'refresh_token': token_dict['refresh_token'],
'client_id': settings.AUTHLIB_OAUTH_CLIENTS['hep']['client_id'],
'client_secret': settings.AUTHLIB_OAUTH_CLIENTS['hep']['client_secret'],
'scope': ''

View File

@ -19,22 +19,20 @@ class OAuth2Token(models.Model):
objects = OAuth2TokenManager()
@classmethod
def is_dict_valid(cls, token_dict):
try:
token = cls.objects.get(access_token=token_dict['access_token'])
except cls.DoesNotExist:
return False
"""
if the refresh frequency is increased it might be better to also return the token to avoid any additional
token fetch db calls
"""
return token.is_valid()
def is_dict_expired(cls, token_dict):
return cls.has_timestamp_expired(token_dict['expires_at'])
@classmethod
def update_dict_with_refresh_data(cls, data):
def has_timestamp_expired(cls, expires_at):
now = timezone.now()
now_unix_timestamp = int(mktime(now.timetuple()))
return expires_at < now_unix_timestamp
@classmethod
def update_dict_with_refresh_data(cls, data, old_access_token):
try:
token = cls.objects.get(access_token=data['access_token'])
token = cls.objects.get(access_token=old_access_token)
except cls.DoesNotExist:
return False
@ -48,11 +46,8 @@ class OAuth2Token(models.Model):
expires_at=self.expires_at,
)
def is_valid(self):
now = timezone.now()
now_unix_timestamp = int(mktime(now.timetuple()))
return self.expires_at > now_unix_timestamp
def has_expired(self):
return OAuth2Token.has_timestamp_expired(self.expires_at)
def update_with_refresh_data(self, data):
payload = self._jwt_payload(data['access_token'])

View File

@ -1,8 +1,6 @@
from unittest.mock import patch
import requests
from authlib.integrations.base_client import BaseApp
from django.contrib.auth.models import AnonymousUser
from django.contrib.sessions.middleware import SessionMiddleware
from django.test import TestCase, RequestFactory
from graphene.test import Client
@ -10,7 +8,7 @@ from graphene.test import Client
from api.schema import schema
from core.factories import UserFactory
from oauth.factories import Oauth2TokenFactory
from users.tests.mock_hep_data_factory import MockResponse, VALID_TEACHERS_ORDERS
from users.tests.mock_hep_data_factory import MockResponse
from users.models import License, Role, SchoolClass
REDEEM_MYSKILLBOX_SUCCESS_RESPONSE = {

View File

@ -1,10 +1,17 @@
import json
from datetime import datetime, timedelta
from unittest.mock import patch
import requests
from django.test import TestCase
from oauth.hep_client import HepClient
from core.factories import UserFactory
from oauth.factories import Oauth2TokenFactory
from oauth.hep_client import HepClient, HepClientUnauthorizedException, HepClientNoTokenException, HepClientException
from oauth.models import OAuth2Token
from oauth.tests.test_oauth2token import REFRESH_DATA
from users.licenses import MYSKILLBOX_LICENSES
from users.models import License
from users.tests.mock_hep_data_factory import MockResponse
ISBNS = list(MYSKILLBOX_LICENSES.keys())
@ -20,6 +27,12 @@ class HepClientTestCases(TestCase):
self.hep_client = HepClient()
self.now = datetime.now()
def _create_token(self):
user = UserFactory(username="bert")
token = Oauth2TokenFactory(user=user)
self.token_dict = token.to_token()
return self.token_dict
def test_has_no_valid_product(self):
products = [
{
@ -168,3 +181,49 @@ class HepClientTestCases(TestCase):
is_active = License.is_product_active(expiry_date, TEACHER_ISBN)
self.assertFalse(is_active)
def test_token_is_not_valid_when_token_and_request_empty(self):
try:
self.hep_client._get_valid_token(None, None)
except HepClientNoTokenException:
return
self.fail("HepClientTestCases.test_token_is_not_valid_when_token_and_request_empty: Should throw HepClientUnauthorizedException")
@patch.object(OAuth2Token, 'is_dict_expired', return_value=True)
@patch.object(requests, 'post', return_value=MockResponse(400, data={}))
def test_token_is_expired_and_cannot_be_refreshed_from_api(self, mock_fn1, mock_fn2):
user = UserFactory(username='housi')
token = Oauth2TokenFactory(user=user).to_token()
try:
self.hep_client._get_valid_token(None, token)
except HepClientException:
return
self.fail("HepClientTestCases.test_token_is_expired_and_cannot_be_refreshed_from_api: Should throw HepClientUnauthorizedException")
@patch.object(OAuth2Token, 'is_dict_expired', return_value=True)
@patch.object(requests, 'post', return_value=MockResponse(200, data={}))
@patch.object(OAuth2Token, 'update_dict_with_refresh_data', return_value=(None, False))
def test_token_is_expired_and_cannot_be_refreshed(self, mock_fn1, mock_fn2, mock_fn3):
user = UserFactory(username='housi')
token = Oauth2TokenFactory(user=user).to_token()
try:
self.hep_client._get_valid_token(None, token)
except HepClientUnauthorizedException:
return
self.fail("HepClientTestCases.test_token_is_expired_and_cannot_be_refreshed: Should throw HepClientUnauthorizedException")
@patch.object(OAuth2Token, 'is_dict_expired', return_value=True)
@patch.object(requests, 'post', return_value=MockResponse(200, data=REFRESH_DATA))
def test_can_refresh_token(self, mock_fn1, mock_fn2):
user = UserFactory(username='housi')
token = Oauth2TokenFactory(user=user).to_token()
token_dict = self.hep_client._get_valid_token(None, token)
self.assertEqual(token_dict['access_token'], REFRESH_DATA['access_token'])
self.assertEqual(token_dict['refresh_token'], REFRESH_DATA['refresh_token'])

View File

@ -1,4 +1,3 @@
from datetime import timedelta
from django.test import TestCase
from django.utils import timezone
@ -22,13 +21,13 @@ class OAuth2TokenTestCases(TestCase):
self.now = timezone.now()
def test_is_valid(self):
self.assertTrue(self.token.is_valid())
self.assertFalse(self.token.has_expired())
def test_has_expired(self):
one_hourish_delta_in_ms = 60*60
self.token.expires_at -= one_hourish_delta_in_ms
self.token.save()
self.assertFalse(self.token.is_valid())
self.assertTrue(self.token.has_expired())
def test_can_update_refresh_data(self):
token, success = self.token.update_with_refresh_data(REFRESH_DATA)

View File

@ -8,6 +8,7 @@ from oauth.models import OAuth2Token
from oauth.user_signup_login_handler import handle_user_and_verify_products, EMAIL_NOT_VERIFIED, UNKNOWN_ERROR
from django.contrib.auth import login as dj_login
OAUTH_REDIRECT = 'oauth-redirect'
def login(request):
hep_oauth_client = oauth.create_client('hep')
@ -27,7 +28,7 @@ def authorize(request):
# logger
# sentry event
# rename
return redirect(f'/login-success?state={UNKNOWN_ERROR}')
return redirect(f'/{OAUTH_REDIRECT}?state={UNKNOWN_ERROR}')
if user and status_msg != EMAIL_NOT_VERIFIED:
dj_login(request, user)
@ -35,7 +36,7 @@ def authorize(request):
OAuth2Token.objects.update_or_create_token(token, user)
if status_msg:
return redirect(f'/login-success?state={status_msg}')
return redirect(f'/{OAUTH_REDIRECT}?state={status_msg}')
return redirect(f'/login-success?state=success')
return redirect(f'/{OAUTH_REDIRECT}?state=success')