Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Debian packages RPM packages NuGet packages

Repository URL to install this package:

Details    
py-aws-util / OktaGroupAuthorizer.py
Size: Mime:
import base64
import logging
import struct
import time
from os import environ
from typing import List

import jwt
import requests
from chalice import AuthResponse
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicNumbers
from jwt.exceptions import PyJWTError


def resolve_log_level(name, default=logging.DEBUG):
    """Translate a LOG_LEVEL setting into a numeric logging level.

    Accepts standard level names case-insensitively ("info", "WARNING") or a
    bare integer string ("15"). Missing, blank or unrecognised values fall back
    to ``default`` instead of reaching ``logging.basicConfig`` as the string
    "Level <name>", which would raise ValueError at import time.

    :param name: raw setting value, or ``None`` when unset.
    :param default: level returned when ``name`` cannot be resolved.
    :returns: an ``int`` logging level.
    """
    if name is None or not name.strip():
        return default
    name = name.strip()
    if name.isdigit():
        return int(name)
    level = logging.getLevelName(name.upper())
    # getLevelName returns an int for known names and "Level <name>" otherwise.
    return level if isinstance(level, int) else default


log_level = resolve_log_level(environ.get('LOG_LEVEL'))
logging.basicConfig(format='%(asctime)s,%(msecs)d %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s',
                    datefmt='%d-%m-%Y:%H:%M:%S',
                    level=log_level)
logger = logging.getLogger('OktaGroupAuthorizer')


class AuthorizationError(Exception):
    """Signals that a bearer token failed validation.

    ``OktaGroupAuthorizer.authorize`` raises this, wrapping the original PyJWT
    error, and the handler built by ``default_handler_for`` catches it and
    answers with an ``AuthResponse`` that grants no routes.
    """


class OktaGroupAuthorizer:
    """Validate Okta-issued JWTs and authorize callers by group membership.

    The OpenID Connect discovery document for ``base_url`` is fetched once at
    construction. Signing keys are downloaded lazily from its ``jwks_uri`` and
    cached as PEM bytes keyed by ``kid``. ``authorize`` returns a policy dict
    that allows access only when the token's ``groups`` claim shares at least
    one entry with ``required``.
    """

    # Seconds to wait for the identity provider before an HTTP call fails.
    _HTTP_TIMEOUT_SECONDS = 10
    # Minimum seconds between JWKS downloads. A token with an unknown "kid"
    # triggers a refresh, so without this bound every forged token would cause
    # an outbound request to the identity provider.
    _JWKS_MIN_REFRESH_SECONDS = 60

    @staticmethod
    def default_handler_for(okta_auth):
        """Return a Chalice authorizer function that delegates to ``okta_auth``.

        :param okta_auth: an ``OktaGroupAuthorizer`` instance.
        :returns: a callable taking a Chalice auth request. A token that fails
            validation (``AuthorizationError``) yields an ``AuthResponse`` with
            no routes instead of propagating the exception.
        """
        def __auth_handler(auth_request):
            try:
                return okta_auth.authorize(auth_type=auth_request.auth_type,
                                           token=auth_request.token,
                                           method_arn=auth_request.method_arn)
            except AuthorizationError:
                return AuthResponse(routes=[], principal_id='user')

        return __auth_handler

    def __init__(self, required: List[str], base_url: str = 'https://barnhardt.okta.com', leeway: int = 300):
        """Create an authorizer and fetch the issuer's discovery document.

        :param required: group names; membership in any one of them grants access.
        :param base_url: issuer URL. Used for OIDC discovery and as the expected
            ``iss`` claim.
        :param leeway: allowed clock skew, in seconds, when checking time-based claims.
        :raises requests.HTTPError: if the discovery endpoint returns an error status.
        """
        self.required = required
        self.__base_url = base_url
        self.__leeway = leeway
        self.__openid_configuration = self.__get_openid_configuration(base_url)
        # kid -> PEM-encoded public key bytes.
        self.__public_key_cache = {}
        # time.monotonic() of the last JWKS download; None until the first one.
        self.__jwks_last_fetched = None

    def __get_openid_configuration(self, base_url):
        """Fetch and return the OIDC discovery document for ``base_url``."""
        oidc_discovery_url = "{}/.well-known/openid-configuration".format(base_url)
        response = requests.get(oidc_discovery_url, timeout=self._HTTP_TIMEOUT_SECONDS)
        response.raise_for_status()
        return response.json()

    def __fetch_public_keys(self):
        """Download the JWKS and cache every usable RSA key as PEM, keyed by ``kid``.

        Downloads are throttled to one per ``_JWKS_MIN_REFRESH_SECONDS``.

        :returns: ``True`` if a download happened, ``False`` if it was skipped
            because the previous attempt was too recent.
        """
        now = time.monotonic()
        last = self.__jwks_last_fetched
        if last is not None and now - last < self._JWKS_MIN_REFRESH_SECONDS:
            logger.warning("Skipping JWKS refresh; last attempt was %.1fs ago.", now - last)
            return False
        # Record the attempt before the request so failures are throttled too.
        self.__jwks_last_fetched = now

        jwks_uri = self.__openid_configuration['jwks_uri']
        response = requests.get(jwks_uri, timeout=self._HTTP_TIMEOUT_SECONDS)
        response.raise_for_status()
        for key in response.json().get('keys', []):
            key_id = key.get('kid')
            # Only RSA keys carry the 'n'/'e' members __convert_to_pem needs.
            if key_id is None or key.get('kty') != 'RSA':
                continue
            self.__public_key_cache[key_id] = self.__convert_to_pem(key)
        return True

    def __public_key_for(self, header):
        """Return the cached PEM key for the token header's ``kid``.

        On a cache miss the JWKS is refreshed once (subject to throttling),
        since the issuer may have rotated its keys.

        :param header: the token's unverified JOSE header dict.
        :raises ValueError: if the header has no ``kid``.
        :raises RuntimeError: if no key with that ``kid`` is available.
        """
        if 'kid' not in header:
            raise ValueError('The id_token header must contain a "kid"')
        key_id = header['kid']
        if key_id not in self.__public_key_cache:
            self.__fetch_public_keys()
        if key_id in self.__public_key_cache:
            return self.__public_key_cache[key_id]
        logger.error("Unable to fetch public key from jwks_uri")
        raise RuntimeError("Unable to fetch public key from jwks_uri")

    def __base64_to_long(self, data):
        """Decode an unpadded base64url string (JWK ``n``/``e``) into a big-endian int."""
        if isinstance(data, str):
            data = data.encode("ascii")
        data = bytes(data)
        # JWK values omit '=' padding; restore exactly what is needed.
        padded = data + b'=' * (-len(data) % 4)
        return int.from_bytes(base64.urlsafe_b64decode(padded), 'big')

    def __convert_to_pem(self, key):
        """Build a PEM (SubjectPublicKeyInfo) public key from an RSA JWK dict."""
        exponent = self.__base64_to_long(key['e'])
        modulus = self.__base64_to_long(key['n'])
        numbers = RSAPublicNumbers(exponent, modulus)
        public_key = numbers.public_key(backend=default_backend())
        return public_key.public_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PublicFormat.SubjectPublicKeyInfo
        )

    def validate(self, token):
        """Verify ``token``'s signature and standard claims and return its payload.

        Checks the signature against the issuer's JWKS, the algorithm against the
        discovery document's ``id_token_signing_alg_values_supported``, ``iss``
        against ``base_url``, and time-based claims with ``leeway`` seconds of
        skew. Audience verification is disabled.

        :raises jwt.exceptions.PyJWTError: for malformed, expired, or otherwise
            invalid tokens (including ``None``).
        :raises ValueError: if the token header has no ``kid``.
        :raises RuntimeError: if no matching signing key can be obtained.
        """
        # TODO: add the ability to pass the client_id from the front end so the
        # audience can be verified.
        # Parse the header first so None/malformed tokens fail with a PyJWTError
        # before any key lookup or network traffic.
        header = jwt.get_unverified_header(token)
        algorithms = self.__openid_configuration['id_token_signing_alg_values_supported']
        pem_key = self.__public_key_for(header)

        # leeway must be a keyword argument: PyJWT does not read it from options.
        return jwt.decode(
            jwt=token,
            key=pem_key,
            algorithms=algorithms,
            issuer=self.__base_url,
            leeway=self.__leeway,
            options={'verify_aud': False},
        )

    def __get_user_identifier(self, claims):
        """Return the ``email`` claim when present, otherwise ``sub``."""
        if 'email' in claims:
            return claims['email']
        return claims['sub']

    def __generate_auth_response(self, principal_id, method_arn, allow=True, groups=None, name=None):
        """Build an authorizer response dict.

        :param principal_id: value for ``principalId`` (``None`` on denial).
        :param method_arn: a policy document is attached only when this is truthy.
            The statement uses ``Resource: '*'`` rather than ``method_arn``.
            NOTE(review): the original comment only says it was changed to a
            wildcard; presumably so one decision covers every route — confirm.
        :param allow: ``True`` for an Allow statement, ``False`` for Deny.
        :param groups: accepted but not copied into the context; the context
            appears unable to hold arrays or nested objects.
        :param name: copied to ``context['name']`` when truthy.
        """
        auth_response = {'principalId': principal_id}
        if method_arn:
            auth_response['policyDocument'] = {
                'Version': '2012-10-17',
                'Statement': [
                    {
                        'Sid': 'APIAuthStatement',
                        'Action': 'execute-api:Invoke',
                        'Effect': 'Allow' if allow else 'Deny',
                        'Resource': '*'
                    }
                ]
            }
        auth_response['context'] = {}
        if name:
            auth_response['context']['name'] = name
        return auth_response

    def __deny(self, method_arn):
        """Return a Deny response with no principal."""
        return self.__generate_auth_response(principal_id=None, method_arn=method_arn, allow=False)

    def authorize(self, auth_type, token, method_arn):
        """Validate ``token`` and build an allow/deny policy for ``method_arn``.

        :param auth_type: authorizer type reported by the caller; used only in logs.
        :param token: a raw JWT or a header value of the form ``"Bearer <jwt>"``.
        :param method_arn: ARN of the invoked method; the policy document is only
            attached when this is truthy.
        :returns: an Allow response when the token's ``groups`` claim contains at
            least one of ``self.required``; otherwise a Deny with no principal.
        :raises AuthorizationError: when PyJWT rejects the token.
        """
        if token and isinstance(token, str) and token.startswith('Bearer'):
            # Split on whitespace so "Bearer" alone or "Bearer<jwt>" no longer
            # raises IndexError. Anything that is not exactly two parts is left
            # unchanged and will fail JWT decoding instead.
            parts = token.split()
            if len(parts) == 2 and parts[0] == 'Bearer':
                token = parts[1]

        try:
            claims = self.validate(token)
        except PyJWTError as jwte:
            logger.error(f"Unable to validate token. {str(jwte)}.")
            raise AuthorizationError(jwte) from jwte

        if not self.required:
            logger.warning(f"Access Denied. No required groups are defined on authorizer for [{method_arn}].")
            return self.__deny(method_arn)
        if not claims:
            # Never log the token itself: it is a bearer credential.
            logger.error(f"No claims found for auth_type: [{auth_type}], method_arn: [{method_arn}].")
            return self.__deny(method_arn)

        principal_id = self.__get_user_identifier(claims)
        groups = claims.get('groups')
        if isinstance(groups, str):
            # A single group serialized as a string must not be matched per character.
            groups = [groups]
        if not groups:
            logger.warning(f"Access Denied. Principle:[{principal_id}] has no groups associated with their user.")
            return self.__deny(method_arn)
        if not any(group in self.required for group in groups):
            logger.warning(f"Access Denied. Principle:[{principal_id}] is not in the required groups: [{self.required}] to access: [{method_arn}].")
            return self.__deny(method_arn)

        logger.warning(
            f"Access Granted. Principle:[{principal_id}] has the required groups: [{self.required}] to access: [{method_arn}].")
        return self.__generate_auth_response(
            principal_id=principal_id,
            method_arn=method_arn,
            allow=True,
            groups=groups,
            # 'name' is optional in tokens; a missing claim must not become a 500.
            name=claims.get('name'),
        )