Why Gemfury? Push, build, and install RubyGems, npm packages, Python packages, Maven artifacts, PHP packages, Go modules, Debian packages, RPM packages, and NuGet packages.

Repository URL to install this package:

Details    
Authlib / flask / oauth2 / sqla.py
Size: Mime:
import time
import json
from sqlalchemy import Column, String, Boolean, Text, Integer
from sqlalchemy.ext.hybrid import hybrid_property
from authlib.oauth2.rfc6749 import (
    ClientMixin,
    TokenMixin,
    AuthorizationCodeMixin,
)
from authlib.oauth2.rfc6749.util import scope_to_list, list_to_scope
from authlib.oidc.core import (
    AuthorizationCodeMixin as OIDCCodeMixin
)
from authlib.deprecate import deprecate

# Import-time side effect: report that this module is deprecated.
# NOTE(review): '1.0' presumably names the version in which this module is
# removed, and 'Jeclj' / 'sq' look like identifiers used to build a link to
# migration notes -- confirm against authlib.deprecate.deprecate's signature.
deprecate('Deprecate "authlib.flask.oauth2.sqla"', '1.0', 'Jeclj', 'sq')


class OAuth2ClientMixin(ClientMixin):
    """SQLAlchemy declarative mixin that persists an OAuth 2.0 client.

    List-valued metadata (redirect URIs, grant types, response types) is
    stored newline-separated in ``Text`` columns. ``contacts``, ``jwks`` and
    internationalized metadata are stored as JSON text. The
    ``hybrid_property`` accessors below convert between the stored text and
    Python values.
    """

    client_id = Column(String(48), index=True)
    client_secret = Column(String(120))
    issued_at = Column(
        Integer, nullable=False,
        default=lambda: int(time.time())
    )
    expires_at = Column(Integer, nullable=False, default=0)

    # Newline-separated list, exposed as ``redirect_uris``.
    redirect_uri = Column(Text)
    token_endpoint_auth_method = Column(
        String(48), default='client_secret_basic')
    # Newline-separated lists, exposed as ``grant_types``/``response_types``.
    grant_type = Column(Text, nullable=False, default='')
    response_type = Column(Text, nullable=False, default='')
    # Whitespace-separated scope values (see ``get_allowed_scope``).
    scope = Column(Text, nullable=False, default='')

    client_name = Column(String(100))
    client_uri = Column(Text)
    logo_uri = Column(Text)
    # JSON-encoded list, exposed as ``contacts``.
    contact = Column(Text)
    tos_uri = Column(Text)
    policy_uri = Column(Text)
    jwks_uri = Column(Text)
    # JSON-encoded key set, exposed as ``jwks``.
    jwks_text = Column(Text)
    # JSON object holding metadata keys that contain '#' (e.g. language tags).
    i18n_metadata = Column(Text)

    software_id = Column(String(36))
    software_version = Column(String(48))

    def __repr__(self):
        return '<Client: {}>'.format(self.client_id)

    @hybrid_property
    def redirect_uris(self):
        """Registered redirect URIs as a list (empty when none stored)."""
        if self.redirect_uri:
            return self.redirect_uri.splitlines()
        return []

    @redirect_uris.setter
    def redirect_uris(self, value):
        self.redirect_uri = '\n'.join(value)

    @hybrid_property
    def grant_types(self):
        """Allowed grant types as a list (empty when none stored)."""
        if self.grant_type:
            return self.grant_type.splitlines()
        return []

    @grant_types.setter
    def grant_types(self, value):
        self.grant_type = '\n'.join(value)

    @hybrid_property
    def response_types(self):
        """Allowed response types as a list (empty when none stored)."""
        if self.response_type:
            return self.response_type.splitlines()
        return []

    @response_types.setter
    def response_types(self, value):
        self.response_type = '\n'.join(value)

    @hybrid_property
    def contacts(self):
        """Decoded ``contact`` JSON, or an empty list when unset."""
        if self.contact:
            return json.loads(self.contact)
        return []

    @contacts.setter
    def contacts(self, value):
        self.contact = json.dumps(value)

    @hybrid_property
    def jwks(self):
        """Decoded ``jwks_text`` JSON, or ``None`` when unset."""
        if self.jwks_text:
            return json.loads(self.jwks_text)
        return None

    @jwks.setter
    def jwks(self, value):
        self.jwks_text = json.dumps(value)

    @hybrid_property
    def client_metadata(self):
        """Implementation for Client Metadata in OAuth 2.0 Dynamic Client
        Registration Protocol via `Section 2`_.

        .. _`Section 2`: https://tools.ietf.org/html/rfc7591#section-2
        """
        keys = [
            'redirect_uris', 'token_endpoint_auth_method', 'grant_types',
            'response_types', 'client_name', 'client_uri', 'logo_uri',
            'scope', 'contacts', 'tos_uri', 'policy_uri', 'jwks_uri', 'jwks',
        ]
        metadata = {k: getattr(self, k) for k in keys}
        if self.i18n_metadata:
            metadata.update(json.loads(self.i18n_metadata))
        return metadata

    @client_metadata.setter
    def client_metadata(self, value):
        # Keys naming an existing attribute are assigned directly; keys that
        # contain '#' are collected into ``i18n_metadata``; all others are
        # ignored.
        # NOTE(review): ``hasattr`` also matches credentials such as
        # ``client_id``/``client_secret``, so only pass validated metadata.
        i18n_metadata = {}
        for k in value:
            if hasattr(self, k):
                setattr(self, k, value[k])
            elif '#' in k:
                i18n_metadata[k] = value[k]

        self.i18n_metadata = json.dumps(i18n_metadata)

    @property
    def client_info(self):
        """Implementation for Client Info in OAuth 2.0 Dynamic Client
        Registration Protocol via `Section 3.2.1`_.

        .. _`Section 3.2.1`: https://tools.ietf.org/html/rfc7591#section-3.2.1
        """
        return dict(
            client_id=self.client_id,
            client_secret=self.client_secret,
            client_id_issued_at=self.issued_at,
            client_secret_expires_at=self.expires_at,
        )

    def get_client_id(self):
        return self.client_id

    def get_default_redirect_uri(self):
        """Return the first registered redirect URI, or ``None``."""
        uris = self.redirect_uris  # split the stored text only once
        if uris:
            return uris[0]
        return None

    def get_allowed_scope(self, scope):
        """Filter *scope* down to the values this client is registered for.

        The requested order is preserved. Returns ``''`` when nothing is
        requested.
        """
        if not scope:
            return ''
        # Column defaults are only applied on INSERT, so a transient
        # instance may still hold ``None`` here; treat it as "no scopes".
        allowed = set((self.scope or '').split())
        scopes = scope_to_list(scope)
        return list_to_scope([s for s in scopes if s in allowed])

    def check_redirect_uri(self, redirect_uri):
        return redirect_uri in self.redirect_uris

    def has_client_secret(self):
        return bool(self.client_secret)

    def check_client_secret(self, client_secret):
        """Return True if *client_secret* matches the stored secret.

        The comparison is constant-time, so response timing does not reveal
        how much of a guessed secret was correct. A missing secret on either
        side never matches.
        """
        import hmac

        if self.client_secret is None or client_secret is None:
            return False
        expected = self.client_secret
        if not isinstance(expected, bytes):
            expected = expected.encode('utf-8')
        given = client_secret
        if not isinstance(given, bytes):
            given = given.encode('utf-8')
        # compare_digest only accepts ASCII str, so compare UTF-8 bytes.
        return hmac.compare_digest(expected, given)

    def check_token_endpoint_auth_method(self, method):
        return self.token_endpoint_auth_method == method

    def check_response_type(self, response_type):
        # ``response_types`` is already [] when nothing is stored.
        return response_type in self.response_types

    def check_grant_type(self, grant_type):
        # ``grant_types`` is already [] when nothing is stored.
        return grant_type in self.grant_types


class OAuth2AuthorizationCodeMixin(AuthorizationCodeMixin):
    """SQLAlchemy declarative mixin that persists an authorization code.

    A code counts as expired once more than 300 seconds have elapsed since
    its ``auth_time``.
    """

    code = Column(String(120), unique=True, nullable=False)
    client_id = Column(String(48))
    redirect_uri = Column(Text, default='')
    response_type = Column(Text, default='')
    scope = Column(Text, default='')
    auth_time = Column(
        Integer,
        nullable=False,
        default=lambda: int(time.time()),
    )

    def is_expired(self):
        """Return True when the 300-second lifetime has elapsed."""
        deadline = self.auth_time + 300
        return time.time() > deadline

    def get_redirect_uri(self):
        """Return the redirect URI stored with this code."""
        return self.redirect_uri

    def get_scope(self):
        """Return the scope string stored with this code."""
        return self.scope

    def get_auth_time(self):
        """Return the stored ``auth_time`` value (Integer column)."""
        return self.auth_time


class OIDCAuthorizationCodeMixin(OAuth2AuthorizationCodeMixin, OIDCCodeMixin):
    """Authorization-code mixin for OpenID Connect flows.

    Adds a ``nonce`` column to ``OAuth2AuthorizationCodeMixin``. The other
    accessors, including ``get_auth_time``, come from that base, which
    precedes ``OIDCCodeMixin`` in the method resolution order.
    """
    # Nonce stored alongside the code; a plain Text column with no default.
    nonce = Column(Text)

    def get_nonce(self):
        """Return the stored nonce (``None`` if it was never set)."""
        return self.nonce


class OAuth2TokenMixin(TokenMixin):
    """SQLAlchemy declarative mixin that persists an issued OAuth 2.0 token.

    ``issued_at`` and ``expires_in`` are integers; their sum is reported as
    the absolute expiry by ``get_expires_at``.
    """

    client_id = Column(String(48))
    token_type = Column(String(40))
    access_token = Column(String(255), unique=True, nullable=False)
    refresh_token = Column(String(255), index=True)
    scope = Column(Text, default='')
    revoked = Column(Boolean, default=False)
    issued_at = Column(
        Integer,
        nullable=False,
        default=lambda: int(time.time()),
    )
    expires_in = Column(Integer, nullable=False, default=0)

    def get_client_id(self):
        """Return the id of the client the token was issued to."""
        return self.client_id

    def get_scope(self):
        """Return the granted scope string."""
        return self.scope

    def get_expires_in(self):
        """Return the token lifetime as stored in ``expires_in``."""
        return self.expires_in

    def get_expires_at(self):
        """Return ``issued_at + expires_in``."""
        issued, lifetime = self.issued_at, self.expires_in
        return issued + lifetime


def create_query_client_func(session, client_model):
    """Build the ``query_client`` callable an authorization server needs.

    :param session: SQLAlchemy session
    :param client_model: Client model class
    :return: function ``query_client(client_id)`` that returns the first
        ``client_model`` row whose ``client_id`` matches, or ``None``.
    """
    def query_client(client_id):
        return (
            session.query(client_model)
            .filter_by(client_id=client_id)
            .first()
        )

    return query_client


def create_save_token_func(session, token_model):
    """Build the ``save_token`` callable an authorization server needs.

    The returned function builds a ``token_model`` from the token dict plus
    the requesting client's id and the user's id (``None`` without a user),
    adds it to *session* and commits.

    :param session: SQLAlchemy session
    :param token_model: Token model class
    """
    def save_token(token, request):
        user = request.user
        user_id = user.get_user_id() if user else None
        record = token_model(
            client_id=request.client.client_id,
            user_id=user_id,
            **token
        )
        session.add(record)
        session.commit()

    return save_token


def create_query_token_func(session, token_model):
    """Create an ``query_token`` function for revocation, introspection
    token endpoints.

    The returned ``query_token(token, token_type_hint, client)`` only
    considers non-revoked tokens belonging to *client*. The column named by
    ``token_type_hint`` is searched first. If the token is not found there,
    the search continues with the other token column, as RFC 7009 §2.1 and
    RFC 7662 §2.1 require. Returns the matching row, or ``None``.

    :param session: SQLAlchemy session
    :param token_model: Token model class
    """
    def query_token(token, token_type_hint, client):
        q = session.query(token_model)
        q = q.filter_by(client_id=client.client_id, revoked=False)
        # The hint only decides the lookup order; it never limits the search.
        if token_type_hint == 'refresh_token':
            columns = ('refresh_token', 'access_token')
        else:
            columns = ('access_token', 'refresh_token')
        for column in columns:
            item = q.filter_by(**{column: token}).first()
            if item is not None:
                return item
        return None
    return query_token


def create_revocation_endpoint(session, token_model):
    """Build a ``RevocationEndpoint`` subclass bound to a SQLAlchemy session.

    Token lookup is delegated to the function returned by
    ``create_query_token_func``. Revocation sets ``revoked = True`` on the
    row and commits; it does not delete the row.

    :param session: SQLAlchemy session
    :param token_model: Token model class
    """
    from authlib.oauth2.rfc7009 import RevocationEndpoint

    lookup = create_query_token_func(session, token_model)

    class _RevocationEndpoint(RevocationEndpoint):
        def query_token(self, token, token_type_hint, client):
            return lookup(token, token_type_hint, client)

        def revoke_token(self, token):
            # Soft-revoke: flag the row, then persist the change.
            token.revoked = True
            session.add(token)
            session.commit()

    return _RevocationEndpoint


def create_bearer_token_validator(session, token_model):
    """Build a ``BearerTokenValidator`` subclass bound to a SQLAlchemy session.

    Tokens are looked up by exact ``access_token`` match. Revocation status
    is read from the row's ``revoked`` attribute. No request-level checks
    are applied.

    :param session: SQLAlchemy session
    :param token_model: Token model class
    """
    from authlib.oauth2.rfc6750 import BearerTokenValidator

    class _BearerTokenValidator(BearerTokenValidator):
        def authenticate_token(self, token_string):
            matches = session.query(token_model).filter_by(
                access_token=token_string)
            return matches.first()

        def request_invalid(self, request):
            # Every request is accepted at this level.
            return False

        def token_revoked(self, token):
            return token.revoked

    return _BearerTokenValidator