Add refresh tests

This commit is contained in:
Christian Cueni 2021-06-17 16:53:53 +02:00
parent 198a0fa33c
commit 9b7c74e7f7
8 changed files with 89 additions and 35 deletions

View File

@ -400,6 +400,7 @@ ALLOW_BETA_LOGIN = True
HEP_URL = os.environ.get("HEP_URL") HEP_URL = os.environ.get("HEP_URL")
# HEP Oauth # HEP Oauth
OAUTH_API_BASE_URL = os.environ.get("OAUTH_API_BASE_URL")
AUTHLIB_OAUTH_CLIENTS = { AUTHLIB_OAUTH_CLIENTS = {
'hep': { 'hep': {
'client_id': os.environ.get("OAUTH_CLIENT_ID"), 'client_id': os.environ.get("OAUTH_CLIENT_ID"),
@ -410,7 +411,7 @@ AUTHLIB_OAUTH_CLIENTS = {
'access_token_params': None, 'access_token_params': None,
'refresh_token_url': None, 'refresh_token_url': None,
'authorize_url': os.environ.get("OAUTH_AUTHORIZE_URL"), 'authorize_url': os.environ.get("OAUTH_AUTHORIZE_URL"),
'api_base_url': os.environ.get("OAUTH_API_BASE_URL"), 'api_base_url': OAUTH_API_BASE_URL,
'client_kwargs': { 'client_kwargs': {
'scope': 'orders', 'scope': 'orders',
'token_endpoint_auth_method': 'client_secret_post', 'token_endpoint_auth_method': 'client_secret_post',

View File

@ -7,7 +7,7 @@ from django.utils import timezone
from oauth.models import OAuth2Token from oauth.models import OAuth2Token
IN_A_HOUR = timezone.now() + timedelta(hours=1) IN_A_HOUR = timezone.now() + timedelta(hours=1)
IN_A_HOUR_UNIX = time.mktime(IN_A_HOUR.timetuple()) IN_A_HOUR_UNIX = int(time.mktime(IN_A_HOUR.timetuple()))
class Oauth2TokenFactory(factory.django.DjangoModelFactory): class Oauth2TokenFactory(factory.django.DjangoModelFactory):

View File

@ -54,12 +54,13 @@ class HepClient:
if not token_dict: if not token_dict:
raise HepClientNoTokenException raise HepClientNoTokenException
if not OAuth2Token.is_dict_valid(token_dict): if OAuth2Token.is_dict_expired(token_dict):
token, refresh_success = self._refresh_token(token_dict) refresh_data = self._refresh_token(token_dict)
token, refresh_success = OAuth2Token.update_dict_with_refresh_data(refresh_data, token_dict['access_token'])
if not refresh_success: if not refresh_success:
raise HepClientUnauthorizedException raise HepClientUnauthorizedException
token_dict = token.to_dict() token_dict = token.to_token()
return token_dict return token_dict
@ -67,7 +68,7 @@ class HepClient:
data = { data = {
'grant_type': 'refresh_token', 'grant_type': 'refresh_token',
'refresh_token': token_dict.refresh_token, 'refresh_token': token_dict['refresh_token'],
'client_id': settings.AUTHLIB_OAUTH_CLIENTS['hep']['client_id'], 'client_id': settings.AUTHLIB_OAUTH_CLIENTS['hep']['client_id'],
'client_secret': settings.AUTHLIB_OAUTH_CLIENTS['hep']['client_secret'], 'client_secret': settings.AUTHLIB_OAUTH_CLIENTS['hep']['client_secret'],
'scope': '' 'scope': ''

View File

@ -19,22 +19,20 @@ class OAuth2Token(models.Model):
objects = OAuth2TokenManager() objects = OAuth2TokenManager()
@classmethod @classmethod
def is_dict_valid(cls, token_dict): def is_dict_expired(cls, token_dict):
try: return cls.has_timestamp_expired(token_dict['expires_at'])
token = cls.objects.get(access_token=token_dict['access_token'])
except cls.DoesNotExist:
return False
"""
if the refresh frequency is increased it might be better to also return the token to avoid any additional
token fetch db calls
"""
return token.is_valid()
@classmethod @classmethod
def update_dict_with_refresh_data(cls, data): def has_timestamp_expired(cls, expires_at):
now = timezone.now()
now_unix_timestamp = int(mktime(now.timetuple()))
return expires_at < now_unix_timestamp
@classmethod
def update_dict_with_refresh_data(cls, data, old_access_token):
try: try:
token = cls.objects.get(access_token=data['access_token']) token = cls.objects.get(access_token=old_access_token)
except cls.DoesNotExist: except cls.DoesNotExist:
return False return False
@ -48,11 +46,8 @@ class OAuth2Token(models.Model):
expires_at=self.expires_at, expires_at=self.expires_at,
) )
def is_valid(self): def has_expired(self):
now = timezone.now() return OAuth2Token.has_timestamp_expired(self.expires_at)
now_unix_timestamp = int(mktime(now.timetuple()))
return self.expires_at > now_unix_timestamp
def update_with_refresh_data(self, data): def update_with_refresh_data(self, data):
payload = self._jwt_payload(data['access_token']) payload = self._jwt_payload(data['access_token'])

View File

@ -1,8 +1,6 @@
from unittest.mock import patch from unittest.mock import patch
import requests
from authlib.integrations.base_client import BaseApp from authlib.integrations.base_client import BaseApp
from django.contrib.auth.models import AnonymousUser
from django.contrib.sessions.middleware import SessionMiddleware from django.contrib.sessions.middleware import SessionMiddleware
from django.test import TestCase, RequestFactory from django.test import TestCase, RequestFactory
from graphene.test import Client from graphene.test import Client
@ -10,7 +8,7 @@ from graphene.test import Client
from api.schema import schema from api.schema import schema
from core.factories import UserFactory from core.factories import UserFactory
from oauth.factories import Oauth2TokenFactory from oauth.factories import Oauth2TokenFactory
from users.tests.mock_hep_data_factory import MockResponse, VALID_TEACHERS_ORDERS from users.tests.mock_hep_data_factory import MockResponse
from users.models import License, Role, SchoolClass from users.models import License, Role, SchoolClass
REDEEM_MYSKILLBOX_SUCCESS_RESPONSE = { REDEEM_MYSKILLBOX_SUCCESS_RESPONSE = {

View File

@ -1,10 +1,17 @@
import json
from datetime import datetime, timedelta from datetime import datetime, timedelta
from unittest.mock import patch
import requests
from django.test import TestCase from django.test import TestCase
from oauth.hep_client import HepClient
from core.factories import UserFactory
from oauth.factories import Oauth2TokenFactory
from oauth.hep_client import HepClient, HepClientUnauthorizedException, HepClientNoTokenException, HepClientException
from oauth.models import OAuth2Token
from oauth.tests.test_oauth2token import REFRESH_DATA
from users.licenses import MYSKILLBOX_LICENSES from users.licenses import MYSKILLBOX_LICENSES
from users.models import License from users.models import License
from users.tests.mock_hep_data_factory import MockResponse
ISBNS = list(MYSKILLBOX_LICENSES.keys()) ISBNS = list(MYSKILLBOX_LICENSES.keys())
@ -20,6 +27,12 @@ class HepClientTestCases(TestCase):
self.hep_client = HepClient() self.hep_client = HepClient()
self.now = datetime.now() self.now = datetime.now()
def _create_token(self):
user = UserFactory(username="bert")
token = Oauth2TokenFactory(user=user)
self.token_dict = token.to_token()
return self.token_dict
def test_has_no_valid_product(self): def test_has_no_valid_product(self):
products = [ products = [
{ {
@ -168,3 +181,49 @@ class HepClientTestCases(TestCase):
is_active = License.is_product_active(expiry_date, TEACHER_ISBN) is_active = License.is_product_active(expiry_date, TEACHER_ISBN)
self.assertFalse(is_active) self.assertFalse(is_active)
def test_token_is_not_valid_when_token_and_request_empty(self):
try:
self.hep_client._get_valid_token(None, None)
except HepClientNoTokenException:
return
self.fail("HepClientTestCases.test_token_is_not_valid_when_token_and_request_empty: Should throw HepClientUnauthorizedException")
@patch.object(OAuth2Token, 'is_dict_expired', return_value=True)
@patch.object(requests, 'post', return_value=MockResponse(400, data={}))
def test_token_is_expired_and_cannot_be_refreshed_from_api(self, mock_fn1, mock_fn2):
user = UserFactory(username='housi')
token = Oauth2TokenFactory(user=user).to_token()
try:
self.hep_client._get_valid_token(None, token)
except HepClientException:
return
self.fail("HepClientTestCases.test_token_is_expired_and_cannot_be_refreshed_from_api: Should throw HepClientUnauthorizedException")
@patch.object(OAuth2Token, 'is_dict_expired', return_value=True)
@patch.object(requests, 'post', return_value=MockResponse(200, data={}))
@patch.object(OAuth2Token, 'update_dict_with_refresh_data', return_value=(None, False))
def test_token_is_expired_and_cannot_be_refreshed(self, mock_fn1, mock_fn2, mock_fn3):
user = UserFactory(username='housi')
token = Oauth2TokenFactory(user=user).to_token()
try:
self.hep_client._get_valid_token(None, token)
except HepClientUnauthorizedException:
return
self.fail("HepClientTestCases.test_token_is_expired_and_cannot_be_refreshed: Should throw HepClientUnauthorizedException")
@patch.object(OAuth2Token, 'is_dict_expired', return_value=True)
@patch.object(requests, 'post', return_value=MockResponse(200, data=REFRESH_DATA))
def test_can_refresh_token(self, mock_fn1, mock_fn2):
user = UserFactory(username='housi')
token = Oauth2TokenFactory(user=user).to_token()
token_dict = self.hep_client._get_valid_token(None, token)
self.assertEqual(token_dict['access_token'], REFRESH_DATA['access_token'])
self.assertEqual(token_dict['refresh_token'], REFRESH_DATA['refresh_token'])

View File

@ -1,4 +1,3 @@
from datetime import timedelta
from django.test import TestCase from django.test import TestCase
from django.utils import timezone from django.utils import timezone
@ -22,13 +21,13 @@ class OAuth2TokenTestCases(TestCase):
self.now = timezone.now() self.now = timezone.now()
def test_is_valid(self): def test_is_valid(self):
self.assertTrue(self.token.is_valid()) self.assertFalse(self.token.has_expired())
def test_has_expired(self): def test_has_expired(self):
one_hourish_delta_in_ms = 60*60 one_hourish_delta_in_ms = 60*60
self.token.expires_at -= one_hourish_delta_in_ms self.token.expires_at -= one_hourish_delta_in_ms
self.token.save() self.token.save()
self.assertFalse(self.token.is_valid()) self.assertTrue(self.token.has_expired())
def test_can_update_refresh_data(self): def test_can_update_refresh_data(self):
token, success = self.token.update_with_refresh_data(REFRESH_DATA) token, success = self.token.update_with_refresh_data(REFRESH_DATA)

View File

@ -8,6 +8,7 @@ from oauth.models import OAuth2Token
from oauth.user_signup_login_handler import handle_user_and_verify_products, EMAIL_NOT_VERIFIED, UNKNOWN_ERROR from oauth.user_signup_login_handler import handle_user_and_verify_products, EMAIL_NOT_VERIFIED, UNKNOWN_ERROR
from django.contrib.auth import login as dj_login from django.contrib.auth import login as dj_login
OAUTH_REDIRECT = 'oauth-redirect'
def login(request): def login(request):
hep_oauth_client = oauth.create_client('hep') hep_oauth_client = oauth.create_client('hep')
@ -27,7 +28,7 @@ def authorize(request):
# logger # logger
# sentry event # sentry event
# rename # rename
return redirect(f'/login-success?state={UNKNOWN_ERROR}') return redirect(f'/{OAUTH_REDIRECT}?state={UNKNOWN_ERROR}')
if user and status_msg != EMAIL_NOT_VERIFIED: if user and status_msg != EMAIL_NOT_VERIFIED:
dj_login(request, user) dj_login(request, user)
@ -35,7 +36,7 @@ def authorize(request):
OAuth2Token.objects.update_or_create_token(token, user) OAuth2Token.objects.update_or_create_token(token, user)
if status_msg: if status_msg:
return redirect(f'/login-success?state={status_msg}') return redirect(f'/{OAUTH_REDIRECT}?state={status_msg}')
return redirect(f'/login-success?state=success') return redirect(f'/{OAUTH_REDIRECT}?state=success')