diff --git a/server/oauth/hep_client.py b/server/oauth/hep_client.py index cf401e58..6cc073c9 100644 --- a/server/oauth/hep_client.py +++ b/server/oauth/hep_client.py @@ -1,9 +1,12 @@ +import requests +from django.conf import settings from django.utils.dateparse import parse_datetime from datetime import timedelta import logging -from oauth.oauth_client import oauth +from oauth.models import OAuth2Token +from oauth.oauth_client import oauth, fetch_token from users.licenses import MYSKILLBOX_LICENSES, is_myskillbox_product, TEACHER_KEY from users.models import License @@ -23,20 +26,18 @@ class HepClientNoTokenException(Exception): class HepClient: - def _call(self, url, request, token, method='get', data=None): - - token_parameters = { - 'token': token, - 'request': request - } + def _call(self, url, token_dict, method='get', data=None): if method == 'post': - response = oauth.hep.post(url, json=data, **token_parameters) + response = oauth.hep.post(url, json=data, token=token_dict) elif method == 'get': - response = oauth.hep.get(url, params=data, **token_parameters) + response = oauth.hep.get(url, params=data, token=token_dict) elif method == 'put': - response = oauth.hep.put(url, data=data, **token_parameters) + response = oauth.hep.put(url, data=data, token=token_dict) + return self._handle_response(response) + + def _handle_response(self, response): if response.status_code == 401: raise HepClientUnauthorizedException(response.status_code, response.json()) elif response.status_code != 200: @@ -44,29 +45,60 @@ class HepClient: return response + def _get_valid_token(self, request, token_dict): + if request is None and token_dict is None: + raise HepClientNoTokenException + + if not token_dict: + token_dict = token_dict('', request) + if not token_dict: + raise HepClientNoTokenException + + if not OAuth2Token.is_dict_valid(token_dict): + token, refresh_success = self._refresh_token(token_dict) + if not refresh_success: + raise HepClientUnauthorizedException + + token_dict = token.to_dict() + + return token_dict + + def _refresh_token(self, token_dict): + + data = { + 'grant_type': 'refresh_token', + 'refresh_token': token_dict.refresh_token, + 'client_id': settings.AUTHLIB_OAUTH_CLIENTS['hep']['client_id'], + 'client_secret': settings.AUTHLIB_OAUTH_CLIENTS['hep']['client_secret'], + 'scope': '' + } + + response = requests.post(f'{settings.OAUTH_API_BASE_URL}/oauth/token', json=data) + return self._handle_response(response).json() + def is_email_verified(self, user_data): return user_data['email_verified_at'] is not None - def user_details(self, request=None, token=None): - self._has_credentials(request, token) - response = self._call('api/auth/user', request, token) + def user_details(self, request=None, token_dict=None): + token_dict = self._get_valid_token(request, token_dict) + response = self._call('api/auth/user', token_dict) return response.json()['data'] - def logout(self, request=None, token=None): - self._has_credentials(request, token) - self._call('api/auth/logout', request, token, method='post') + def logout(self, request=None, token_dict=None): + token_dict = self._get_valid_token(request, token_dict) + self._call('api/auth/logout', token_dict, method='post') return True - def fetch_eorders(self, request=None, token=None): - self._has_credentials(request, token) + def fetch_eorders(self, request=None, token_dict=None): + token_dict = self._get_valid_token(request, token_dict) data = { 'filters[product_type]': 'eLehrmittel', } - response = self._call('api/partners/users/orders/search', request, token, data=data) + response = self._call('api/partners/users/orders/search', token_dict, data=data) return response.json()['data'] - def active_myskillbox_product_for_customer(self, request=None, token=None): - eorders = self.fetch_eorders(request=request, token=token) + def active_myskillbox_product_for_customer(self, request=None, token_dict=None): + eorders = self.fetch_eorders(request=request, token_dict=token_dict) myskillbox_products = self._extract_myskillbox_products(eorders) if len(myskillbox_products) == 0: @@ -74,14 +106,10 @@ class HepClient: else: return self._get_active_product(myskillbox_products) - def _has_credentials(self, request, token): - if request is None and token is None: - raise HepClientNoTokenException - - def redeem_coupon(self, coupon_code, customer_id, request=None, token=None): - self._has_credentials(request, token) + def redeem_coupon(self, coupon_code, customer_id, request=None, token_dict=None): + token_dict = self._get_valid_token(request, token_dict) try: - response = self._call(f'api/partners/users/{customer_id}/coupons/redeem', request, token, method='post', + response = self._call(f'api/partners/users/{customer_id}/coupons/redeem', token_dict, method='post', data={'code': coupon_code}) except HepClientException: return None @@ -152,4 +180,3 @@ class HepClient: # select a student product, as they are all valid it does not matter which one return active_products[0] - diff --git a/server/oauth/models.py b/server/oauth/models.py index c0063613..fbcfeed6 100644 --- a/server/oauth/models.py +++ b/server/oauth/models.py @@ -1,6 +1,10 @@ # https://docs.authlib.org/en/latest/client/frameworks.html#frameworks-clients +import base64 +import json +from time import mktime from django.contrib.auth import get_user_model from django.db import models +from django.utils import timezone from oauth.managers import OAuth2TokenManager @@ -14,6 +18,28 @@ class OAuth2Token(models.Model): objects = OAuth2TokenManager() + @classmethod + def is_dict_valid(cls, token_dict): + try: + token = cls.objects.get(access_token=token_dict['access_token']) + except cls.DoesNotExist: + return False + + """ + if the refresh frequency is increased it might be better to also return the token to avoid any additional + token fetch db calls + """ + return token.is_valid() + + @classmethod + def update_dict_with_refresh_data(cls, data): + try: + token = cls.objects.get(access_token=data['access_token']) + except cls.DoesNotExist: + return False + + return token.update_with_refresh_data(data) + def to_token(self): return dict( access_token=self.access_token, @@ -21,3 +47,32 @@ class OAuth2Token(models.Model): refresh_token=self.refresh_token, expires_at=self.expires_at, ) + + def is_valid(self): + now = timezone.now() + now_unix_timestamp = int(mktime(now.timetuple())) + + return self.expires_at > now_unix_timestamp + + def update_with_refresh_data(self, data): + self.token_type = data['token_type'] + self.access_token = data['access_token'] + self.refresh_token = data['refresh_token'] + + payload = self._jwt_payload(data['access_token']) + if not payload: + return None, False + + self.expires_at = int(payload['exp']) + self.save() + + return self, True + + def _jwt_payload(self, jwt): + jwt_parts = jwt.split('.') + payload_bytes = base64.b64decode(jwt_parts[1]) + try: + payload = json.loads(payload_bytes.decode("UTF-8")) + except: + return None + return payload diff --git a/server/oauth/views.py b/server/oauth/views.py index d098f28f..ab67ba85 100644 --- a/server/oauth/views.py +++ b/server/oauth/views.py @@ -24,6 +24,9 @@ def authorize(request): user, status_msg = handle_user_and_verify_products(user_data, token) user.sync_with_hep_data(user_data) except OAuthError as e: + # logger + # sentry event + # rename return redirect(f'/login-success?state={UNKNOWN_ERROR}') if user and status_msg != EMAIL_NOT_VERIFIED: