79 lines
2.3 KiB
Python
79 lines
2.3 KiB
Python
# https://docs.authlib.org/en/latest/client/frameworks.html#frameworks-clients
|
|
import base64
|
|
import json
|
|
from time import mktime
|
|
from django.contrib.auth import get_user_model
|
|
from django.db import models
|
|
from django.utils import timezone
|
|
|
|
from oauth.managers import OAuth2TokenManager
|
|
|
|
|
|
class OAuth2Token(models.Model):
    """OAuth2 token set persisted for a single user.

    Stores the fields that make up a client token dict (see ``to_token``)
    together with a one-to-one link to the owning user.
    """

    token_type = models.CharField(max_length=40)
    access_token = models.TextField()
    refresh_token = models.TextField()
    # Expiry as a Unix timestamp in seconds; is_valid() compares it to "now".
    expires_at = models.PositiveIntegerField()
    # One token row per user; deleting the user cascades to the token.
    # NOTE(review): Django's docs suggest settings.AUTH_USER_MODEL for relation
    # fields instead of calling get_user_model() at import time -- confirm.
    user = models.OneToOneField(get_user_model(), on_delete=models.CASCADE)

    objects = OAuth2TokenManager()

    @classmethod
    def is_dict_valid(cls, token_dict):
        """Return whether the stored token matching ``token_dict`` is unexpired.

        Args:
            token_dict: mapping with an ``'access_token'`` key used for lookup.

        Returns:
            False when no stored token has that access token, otherwise the
            result of ``is_valid()`` on the stored token.

        Raises:
            KeyError: if ``token_dict`` has no ``'access_token'`` key.
        """
        try:
            stored = cls.objects.get(access_token=token_dict['access_token'])
        except cls.DoesNotExist:
            return False

        # If the refresh frequency is increased it might be better to also
        # return the token to avoid any additional token fetch db calls.
        return stored.is_valid()

    @classmethod
    def update_dict_with_refresh_data(cls, data):
        """Find the token matching ``data['access_token']`` and refresh it.

        Args:
            data: mapping with the keys required by ``update_with_refresh_data``.

        Returns:
            ``False`` when no stored token matches; otherwise the
            ``(token, success)`` tuple from ``update_with_refresh_data``.
            The mixed return shape is kept for existing callers.
        """
        # NOTE(review): the lookup uses the access token carried in ``data``;
        # presumably callers pass a dict whose access token is already stored
        # -- verify against the caller.
        try:
            stored = cls.objects.get(access_token=data['access_token'])
        except cls.DoesNotExist:
            return False

        return stored.update_with_refresh_data(data)

    def to_token(self):
        """Return the stored fields as a plain token dict."""
        return {
            'access_token': self.access_token,
            'token_type': self.token_type,
            'refresh_token': self.refresh_token,
            'expires_at': self.expires_at,
        }

    def is_valid(self):
        """Return True while ``expires_at`` lies strictly in the future."""
        # datetime.timestamp() gives the real epoch value for both aware and
        # naive datetimes returned by timezone.now(). The previous
        # mktime(now.timetuple()) interpreted UTC fields as local time, which
        # skewed the comparison by the server's UTC offset.
        now_unix_timestamp = int(timezone.now().timestamp())
        return self.expires_at > now_unix_timestamp

    def update_with_refresh_data(self, data):
        """Overwrite this token with refreshed values and save it.

        The new expiry is read from the ``exp`` claim of the access token's
        JWT payload.

        Args:
            data: mapping with ``'token_type'``, ``'access_token'`` and
                ``'refresh_token'`` keys.

        Returns:
            ``(self, True)`` after a successful save, or ``(None, False)``
            when the access token's payload cannot be decoded or carries no
            usable ``exp`` claim. Nothing is modified in the failure case.

        Raises:
            KeyError: if ``data`` lacks one of the required keys.
        """
        payload = self._jwt_payload(data['access_token'])
        if not payload:
            return None, False

        # Resolve the expiry before touching any field so that a bad payload
        # cannot leave the instance half-updated.
        try:
            new_expiry = int(payload['exp'])
        except (KeyError, TypeError, ValueError):
            return None, False

        self.token_type = data['token_type']
        self.access_token = data['access_token']
        self.refresh_token = data['refresh_token']
        self.expires_at = new_expiry
        self.save()

        return self, True

    def _jwt_payload(self, jwt):
        """Decode the payload (second) segment of a JWT.

        The signature is not verified; this only parses the claims.

        Args:
            jwt: the compact-serialized token string.

        Returns:
            The parsed JSON payload, or None when the token has no second
            segment or that segment is not valid base64url-encoded UTF-8 JSON.
        """
        try:
            segment = jwt.split('.')[1]
            # JWT segments are base64url encoded with the trailing '=' padding
            # stripped; restore it so the decoder accepts every length.
            padded = segment + '=' * (-len(segment) % 4)
            payload_bytes = base64.urlsafe_b64decode(padded)
            return json.loads(payload_bytes.decode("UTF-8"))
        except (IndexError, ValueError):
            # ValueError covers binascii.Error, UnicodeDecodeError and
            # json.JSONDecodeError.
            return None
|