78 lines
2.3 KiB
Python
78 lines
2.3 KiB
Python
# https://docs.authlib.org/en/latest/client/frameworks.html#frameworks-clients
|
|
import base64
|
|
import json
|
|
from time import mktime
|
|
from django.contrib.auth import get_user_model
|
|
from django.db import models
|
|
from django.utils import timezone
|
|
|
|
from oauth.managers import OAuth2TokenManager
|
|
|
|
from core.logger import get_logger
|
|
|
|
# Module-level logger; OAuth2Token uses it to warn when a JWT cannot be decoded.
logger = get_logger(__name__)
|
|
|
|
class OAuth2Token(models.Model):
    """Persisted OAuth2 token, at most one per user (``OneToOneField``).

    ``to_token`` exposes the stored fields as a plain dict, and the
    ``update_*`` helpers overwrite a stored token with data from a token
    refresh, deriving the new expiry from the access token's JWT ``exp``
    claim.
    """

    # Token type as reported in the token response.
    token_type = models.CharField(max_length=40)
    access_token = models.TextField()
    refresh_token = models.TextField()
    # Absolute expiry as a Unix timestamp (seconds); compared against the
    # current Unix time in has_timestamp_expired.
    expires_at = models.PositiveIntegerField()
    user = models.OneToOneField(get_user_model(), on_delete=models.CASCADE)

    objects = OAuth2TokenManager()

    @classmethod
    def is_dict_expired(cls, token_dict):
        """Return True if ``token_dict['expires_at']`` lies in the past.

        :param token_dict: mapping with an ``expires_at`` Unix timestamp,
            such as the dict returned by :meth:`to_token`.
        :raises KeyError: if ``expires_at`` is missing.
        """
        return cls.has_timestamp_expired(token_dict['expires_at'])

    @classmethod
    def has_timestamp_expired(cls, expires_at):
        """Return True if the Unix timestamp ``expires_at`` is strictly before now."""
        # Do not use mktime(now.timetuple()): mktime interprets the fields as
        # local time, but timezone.now() is an aware UTC datetime when USE_TZ
        # is enabled, which skews the result by the server's UTC offset.
        # datetime.timestamp() is correct for both aware and naive values.
        now_unix_timestamp = int(timezone.now().timestamp())
        return expires_at < now_unix_timestamp

    @classmethod
    def update_dict_with_refresh_data(cls, data, old_access_token):
        """Apply refresh data to the stored token whose access token is ``old_access_token``.

        :param data: refresh response mapping (see :meth:`update_with_refresh_data`).
        :param old_access_token: access token currently stored on the row.
        :returns: ``(token, True)`` on success; ``(None, False)`` if no row holds
            ``old_access_token`` or the new access token is unusable.
        """
        try:
            token = cls.objects.get(access_token=old_access_token)
        except cls.DoesNotExist:
            # Same (token, updated) shape as update_with_refresh_data so
            # callers can always unpack the result.
            return None, False

        return token.update_with_refresh_data(data)

    def to_token(self):
        """Return the stored token as a plain dict.

        Keys: ``access_token``, ``token_type``, ``refresh_token``, ``expires_at``.
        """
        return dict(
            access_token=self.access_token,
            token_type=self.token_type,
            refresh_token=self.refresh_token,
            expires_at=self.expires_at,
        )

    def has_expired(self):
        """Return True if this token's ``expires_at`` lies in the past."""
        return self.has_timestamp_expired(self.expires_at)

    def update_with_refresh_data(self, data):
        """Overwrite this token with a refresh response and save it.

        The new expiry is read from the ``exp`` claim of the new access token's
        JWT payload (signature not verified), not from ``data``.

        :param data: mapping with ``access_token`` and ``token_type``;
            ``refresh_token`` is optional and, when absent or empty, the
            current refresh token is kept.
        :returns: ``(self, True)`` after saving, or ``(None, False)`` if the
            access token is not a decodable JWT or carries no ``exp`` claim
            (the row is left unchanged in that case).
        :raises KeyError: if ``access_token`` or ``token_type`` is missing.
        """
        access_token = data['access_token']
        payload = self._jwt_payload(access_token)
        if not payload:
            return None, False

        exp = payload.get('exp')
        if exp is None:
            logger.warning('OAuthToken error: jwt payload has no exp claim')
            return None, False

        self.token_type = data['token_type']
        self.access_token = access_token
        # RFC 6749 section 6: the server MAY omit a new refresh token on
        # refresh; keep the current one rather than failing with KeyError.
        self.refresh_token = data.get('refresh_token') or self.refresh_token
        self.expires_at = int(exp)
        self.save()

        return self, True

    def _jwt_payload(self, jwt):
        """Decode the payload segment of a compact JWT without verifying it.

        The signature is NOT checked, so the result is only suitable for
        reading non-security metadata such as ``exp``.

        :param jwt: ``header.payload.signature`` string.
        :returns: the payload as a dict, or None if it cannot be decoded into
            a JSON object (a warning is logged).
        """
        try:
            segment = jwt.split('.')[1]
            # JWT segments are base64url with padding stripped (RFC 7515).
            # Restore the padding and use the URL-safe alphabet; plain
            # b64decode rejects unpadded input and silently drops '-' / '_'.
            padded = segment + '=' * (-len(segment) % 4)
            payload_bytes = base64.urlsafe_b64decode(padded)
            payload = json.loads(payload_bytes.decode("UTF-8"))
        except (AttributeError, IndexError, TypeError, ValueError) as e:
            # ValueError covers binascii.Error, UnicodeDecodeError and
            # json.JSONDecodeError; AttributeError/TypeError cover a non-str jwt.
            logger.warning(f'OAuthToken error: Could not decode jwt: {e}')
            return None

        if not isinstance(payload, dict):
            logger.warning('OAuthToken error: jwt payload is not a JSON object')
            return None

        return payload
|