skillbox/server/oauth/models.py

78 lines
2.3 KiB
Python

# https://docs.authlib.org/en/latest/client/frameworks.html#frameworks-clients
import base64
import json
from time import mktime
from django.contrib.auth import get_user_model
from django.db import models
from django.utils import timezone
from oauth.managers import OAuth2TokenManager
from core.logger import get_logger
# Module-level logger, created by core.logger.get_logger with this module's __name__.
logger = get_logger(__name__)
class OAuth2Token(models.Model):
    """Persisted OAuth2 token pair for a single user.

    Each user has at most one row (``user`` is a one-to-one relation).
    ``expires_at`` stores a Unix timestamp in seconds. The helpers below
    compare it with the current time and overwrite the row from a
    token-endpoint refresh response.
    """

    token_type = models.CharField(max_length=40)
    access_token = models.TextField()
    refresh_token = models.TextField()
    # Unix timestamp (seconds); the token counts as expired once now passes it.
    expires_at = models.PositiveIntegerField()
    user = models.OneToOneField(get_user_model(), on_delete=models.CASCADE)

    objects = OAuth2TokenManager()

    @classmethod
    def is_dict_expired(cls, token_dict):
        """Return True if ``token_dict['expires_at']`` lies in the past.

        Raises:
            KeyError: if ``token_dict`` has no ``'expires_at'`` entry.
        """
        return cls.has_timestamp_expired(token_dict['expires_at'])

    @classmethod
    def has_timestamp_expired(cls, expires_at):
        """Return True if the Unix timestamp ``expires_at`` is before now.

        ``datetime.timestamp()`` is used rather than
        ``mktime(now.timetuple())``. With USE_TZ enabled,
        ``timezone.now()`` returns a UTC-aware datetime, and ``mktime``
        would reinterpret its UTC wall-clock fields as *local* time. That
        would skew the comparison by the server's UTC offset. For a naive
        datetime, ``timestamp()`` assumes local time, as the old code did.
        """
        now_unix_timestamp = int(timezone.now().timestamp())
        return expires_at < now_unix_timestamp

    @classmethod
    def update_dict_with_refresh_data(cls, data, old_access_token):
        """Apply refreshed token ``data`` to the row storing ``old_access_token``.

        Returns:
            The ``(token, updated)`` pair produced by
            ``update_with_refresh_data``, or ``(None, False)`` when no row
            stores ``old_access_token``.
        """
        try:
            token = cls.objects.get(access_token=old_access_token)
        except cls.DoesNotExist:
            # Same shape as update_with_refresh_data's failure result, so
            # callers can always unpack ``token, updated = ...``.
            return None, False
        return token.update_with_refresh_data(data)

    def to_token(self):
        """Return this token's fields as a plain dict.

        The keys are access_token, token_type, refresh_token and expires_at.
        """
        return {
            'access_token': self.access_token,
            'token_type': self.token_type,
            'refresh_token': self.refresh_token,
            'expires_at': self.expires_at,
        }

    def has_expired(self):
        """Return True if this token's ``expires_at`` lies in the past."""
        return self.has_timestamp_expired(self.expires_at)

    def update_with_refresh_data(self, data):
        """Overwrite this token with a refresh response and save it.

        ``data`` must contain ``'access_token'`` and ``'token_type'``. The
        new expiry is read from the ``exp`` claim of the access token,
        which is therefore expected to be a JWT. ``'refresh_token'`` is
        optional: RFC 6749 section 6 lets the server omit it, and in that
        case the stored refresh token is kept.

        Returns:
            ``(self, True)`` after saving, or ``(None, False)`` without
            saving if the access token cannot be decoded or carries no
            ``exp`` claim.
        """
        payload = self._jwt_payload(data['access_token'])
        if not payload:
            return None, False
        exp = payload.get('exp') if isinstance(payload, dict) else None
        if exp is None:
            logger.warning('OAuthToken error: jwt payload has no exp claim')
            return None, False
        self.token_type = data['token_type']
        self.access_token = data['access_token']
        # Keep the current refresh token when the server did not rotate it.
        self.refresh_token = data.get('refresh_token') or self.refresh_token
        self.expires_at = int(exp)
        self.save()
        return self, True

    @staticmethod
    def _jwt_payload(jwt):
        """Return the decoded payload (claims) of ``jwt``, or None on failure.

        The signature is NOT verified; only the middle segment is decoded.
        JWT segments use base64url encoding with the trailing '=' padding
        stripped (RFC 7515 section 2). Plain ``b64decode`` rejects such
        input as incorrectly padded and silently drops the '-' and '_'
        characters. So the padding is restored and the URL-safe alphabet
        is used.
        """
        try:
            segment = jwt.split('.')[1]
            padded = segment + '=' * (-len(segment) % 4)
            payload_bytes = base64.urlsafe_b64decode(padded)
            payload = json.loads(payload_bytes.decode("UTF-8"))
        except (AttributeError, IndexError, TypeError, ValueError) as e:
            # ValueError also covers binascii.Error, UnicodeDecodeError and
            # json.JSONDecodeError. AttributeError/TypeError cover a
            # non-string ``jwt``. IndexError covers a token without '.'.
            logger.warning(f'OAuthToken error: Could not decode jwt: {e}')
            return None
        return payload