Repository URL to install this package:
|
Version:
0.2.20 ▾
|
py-aws-util
/
OktaGroupAuthorizer.py
|
|---|
import base64
import logging
import struct
import time
from os import environ
from typing import List

import jwt
import requests
from chalice import AuthResponse
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicNumbers
from jwt.exceptions import PyJWTError
# Resolve the log level from the LOG_LEVEL environment variable. Names are
# matched case-insensitively ("info" == "INFO"). An unset or unrecognised value
# falls back to DEBUG instead of making logging.basicConfig raise ValueError
# at import time (getLevelName returns a "Level <name>" string for unknown names).
_log_level_name = environ.get('LOG_LEVEL', '').strip().upper()
_resolved_level = logging.getLevelName(_log_level_name) if _log_level_name else logging.DEBUG
log_level = _resolved_level if isinstance(_resolved_level, int) else logging.DEBUG
# NOTE(review): basicConfig configures the *root* logger at import time, so it
# affects the whole host process; it is a no-op if the root logger already has
# handlers.
logging.basicConfig(format='%(asctime)s,%(msecs)d %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s',
                    datefmt='%d-%m-%Y:%H:%M:%S',
                    level=log_level)
logger = logging.getLogger('OktaGroupAuthorizer')
if not isinstance(_resolved_level, int):
    logger.warning("Unrecognised LOG_LEVEL %r; defaulting to DEBUG.", environ.get('LOG_LEVEL'))
class AuthorizationError(Exception):
    """Signals that a bearer token failed validation.

    Raised by ``OktaGroupAuthorizer.authorize`` when PyJWT rejects the token;
    the original ``PyJWTError`` is passed as the exception argument.
    """
class OktaGroupAuthorizer:
    """Chalice authorizer that admits callers whose Okta token carries a required group.

    Tokens are verified against the RSA signing keys published by the issuer at
    ``base_url`` (located through its ``/.well-known/openid-configuration``
    discovery document). A request is allowed when at least one entry of the
    token's ``groups`` claim appears in ``required``.
    """

    # Timeout (seconds) for every outbound HTTP request, so a slow or
    # unreachable Okta endpoint cannot hang the authorizer indefinitely.
    HTTP_TIMEOUT_SECONDS = 10

    # Minimum number of seconds between two JWKS downloads. Without this, every
    # token carrying an unknown ``kid`` would trigger its own outbound request.
    JWKS_MIN_REFRESH_SECONDS = 60

    @staticmethod
    def default_handler_for(okta_auth):
        """Wrap ``okta_auth.authorize`` as a Chalice authorizer handler.

        The returned callable forwards the request's ``auth_type``, ``token`` and
        ``method_arn`` to ``okta_auth.authorize``. An ``AuthorizationError`` is
        converted into an ``AuthResponse`` with an empty route list and principal
        ``'user'``; any other exception propagates.
        """
        def __auth_handler(auth_request):
            try:
                return okta_auth.authorize(auth_type=auth_request.auth_type,
                                           token=auth_request.token,
                                           method_arn=auth_request.method_arn)
            except AuthorizationError:
                return AuthResponse(routes=[], principal_id='user')
        return __auth_handler

    def __init__(self, required: List[str], base_url: str = 'https://barnhardt.okta.com', leeway: int = 300):
        """Create the authorizer and download the issuer's OpenID configuration.

        :param required: group names; membership in any one of them grants access.
        :param base_url: issuer URL, used both for discovery and as the expected
            ``iss`` claim.
        :param leeway: clock-skew tolerance in seconds for time-based claims.
        """
        self.required = required
        self.__base_url = base_url
        self.__leeway = leeway
        # kid -> PEM-encoded public key.
        self.__public_key_cache = {}
        # time.monotonic() of the last JWKS download attempt (None = never).
        self.__last_jwks_fetch = None
        # Network call: construction fails if discovery is unreachable.
        self.__openid_configuration = self.__get_openid_configuration(base_url)

    def __get_openid_configuration(self, base_url):
        """Download and return the issuer's OpenID Connect discovery document."""
        oidc_discovery_url = "{}/.well-known/openid-configuration".format(base_url)
        r = requests.get(oidc_discovery_url, timeout=self.HTTP_TIMEOUT_SECONDS)
        r.raise_for_status()
        return r.json()

    def __fetch_public_keys(self):
        """Download the JWKS and cache each RSA key as PEM, keyed by ``kid``.

        Downloads are throttled to one per ``JWKS_MIN_REFRESH_SECONDS``; a call
        inside that window returns without touching the network.
        """
        now = time.monotonic()
        last = self.__last_jwks_fetch
        if last is not None and now - last < self.JWKS_MIN_REFRESH_SECONDS:
            logger.warning("Skipping JWKS refresh; last fetch was %.0f seconds ago.", now - last)
            return
        # Record the attempt before the request so failures are throttled too.
        self.__last_jwks_fetch = now
        jwks_uri = self.__openid_configuration['jwks_uri']
        r = requests.get(jwks_uri, timeout=self.HTTP_TIMEOUT_SECONDS)
        r.raise_for_status()
        for key in r.json()['keys']:
            # Only RSA keys can be converted below (they carry 'n' and 'e').
            if key.get('kty') != 'RSA':
                continue
            self.__public_key_cache[key['kid']] = self.__convert_to_pem(key)

    def __fetch_jwk_for(self, id_token=None):
        """Return the PEM public key matching the ``kid`` in ``id_token``'s header.

        The JWKS is re-downloaded (subject to the refresh throttle) only when the
        key id is not cached yet.

        :raises Exception: ``'Unauthorized'`` when ``id_token`` is None.
        :raises ValueError: when the token header has no ``kid``.
        :raises RuntimeError: when no published RSA key matches the ``kid``.
        """
        if id_token is None:
            logger.warning('id_token is required')
            raise Exception('Unauthorized')
        unverified_header = jwt.get_unverified_header(id_token)
        if 'kid' not in unverified_header:
            raise ValueError('The id_token header must contain a "kid"')
        key_id = unverified_header['kid']
        if key_id not in self.__public_key_cache:
            # Only hit the network when the key is unknown (e.g. after rotation).
            self.__fetch_public_keys()
        if key_id in self.__public_key_cache:
            return self.__public_key_cache[key_id]
        logger.error("Unable to fetch public key from jwks_uri")
        raise RuntimeError("Unable to fetch public key from jwks_uri")

    def __base64_to_long(self, data):
        """Decode an unpadded base64url value (a JWK integer field) to an ``int``."""
        if isinstance(data, str):
            data = data.encode("ascii")
        data = bytes(data)
        # JWK values omit '=' padding; restore exactly as much as the decoder needs.
        padded = data + b'=' * (-len(data) % 4)
        return int.from_bytes(base64.urlsafe_b64decode(padded), 'big')

    def __convert_to_pem(self, key):
        """Build a PEM (SubjectPublicKeyInfo) public key from an RSA JWK's ``e``/``n``."""
        exponent = self.__base64_to_long(key['e'])
        modulus = self.__base64_to_long(key['n'])
        public_key = RSAPublicNumbers(exponent, modulus).public_key(backend=default_backend())
        return public_key.public_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PublicFormat.SubjectPublicKeyInfo
        )

    # TODO: - add the ability to pass the client_id from the front end - if that makes any sense.
    def validate(self, token):
        """Verify ``token``'s signature and standard claims and return its claims.

        ``jwt.decode`` checks the signature and the time-based and ``iss``
        claims, allowing ``leeway`` seconds of clock skew; audience verification
        is disabled. Failures raise a ``PyJWTError`` subclass:
        https://pyjwt.readthedocs.io/en/latest/usage.html
        """
        # Parse the header first so a missing or malformed token fails with a
        # PyJWTError before any key lookup or network traffic happens.
        jwt.get_unverified_header(token)
        algorithms = self.__openid_configuration['id_token_signing_alg_values_supported']
        pem_key = self.__fetch_jwk_for(token)
        return jwt.decode(
            jwt=token,
            key=pem_key,
            algorithms=algorithms,
            issuer=self.__base_url,
            # leeway is a keyword argument of jwt.decode; inside ``options`` it is ignored.
            leeway=self.__leeway,
            options={'verify_aud': False},
        )

    def __get_user_identifier(self, claims):
        """Return the ``email`` claim when present, otherwise the ``sub`` claim."""
        if 'email' in claims:
            return claims['email']
        return claims['sub']

    def __generate_auth_response(self, principle_id, method_arn, allow=True, groups=None, name=None):
        """Build an API Gateway authorizer response dict.

        A policy document (Allow or Deny on ``execute-api:Invoke``) is included
        only when ``method_arn`` is truthy. ``groups`` is accepted but not
        emitted, because the response context may only hold scalar values.
        """
        auth_response = {'principalId': principle_id}
        if method_arn:
            # The policy uses a wildcard resource rather than [method_arn].
            auth_response['policyDocument'] = {
                'Version': '2012-10-17',
                'Statement': [
                    {
                        'Sid': 'APIAuthStatement',
                        'Action': 'execute-api:Invoke',
                        'Effect': 'Allow' if allow else 'Deny',
                        'Resource': '*'
                    }
                ]
            }
        auth_response['context'] = {}
        if name:
            auth_response['context']['name'] = name
        return auth_response

    def __deny(self, method_arn):
        """Return a Deny response with no principal."""
        return self.__generate_auth_response(principle_id=None, method_arn=method_arn, allow=False)

    @staticmethod
    def _strip_bearer(token):
        """Return the credential from an ``Authorization`` header value.

        ``"Bearer <token>"`` (scheme matched case-insensitively) yields
        ``<token>`` with surrounding whitespace removed. Any other value,
        including non-strings, is returned unchanged.
        """
        if not isinstance(token, str):
            return token
        scheme, _, credential = token.partition(' ')
        if scheme.lower() != 'bearer':
            return token
        return credential.strip()

    @staticmethod
    def _normalize_groups(groups):
        """Return the ``groups`` claim as a list; a lone string becomes one entry."""
        if not groups:
            return []
        if isinstance(groups, str):
            return [groups]
        return list(groups)

    def authorize(self, auth_type, token, method_arn):
        """Validate ``token`` and return an Allow/Deny authorizer response.

        Access is allowed when the token's ``groups`` claim shares at least one
        entry with ``self.required``; otherwise a Deny response is returned.

        :raises AuthorizationError: when PyJWT rejects the token.
        :raises Exception: ``'Unauthorized'`` when the validated token has no claims.
        """
        token = self._strip_bearer(token)
        try:
            claims = self.validate(token)
        except PyJWTError as jwte:
            logger.error(f"Unable to validate token. {str(jwte)}.")
            raise AuthorizationError(jwte) from jwte

        if not claims:
            # Never log the bearer token itself: it is a live credential.
            logger.error(f"No claims found for auth_type: [{auth_type}], method_arn: [{method_arn}].")
            # This is not a denial, just that the authentication token is bad.
            raise Exception('Unauthorized')

        if not self.required:
            logger.warning(f"Access Denied. No required groups are defined on authorizer for [{method_arn}].")
            return self.__deny(method_arn)

        principle_id = self.__get_user_identifier(claims)
        groups = self._normalize_groups(claims.get('groups'))
        if not groups:
            logger.warning(f"Access Denied. Principle:[{principle_id}] has no groups associated with their user.")
            return self.__deny(method_arn)

        if not any(group in self.required for group in groups):
            logger.warning(f"Access Denied. Principle:[{principle_id}] is not in the required groups: [{self.required}] to access: [{method_arn}].")
            return self.__deny(method_arn)

        logger.warning(
            f"Access Granted. Principle:[{principle_id}] has the required groups: [{self.required}] to access: [{method_arn}].")
        return self.__generate_auth_response(
            principle_id=principle_id,
            method_arn=method_arn,
            allow=True,
            groups=groups,
            # 'name' is optional (e.g. absent from access tokens); don't KeyError on it.
            name=claims.get('name')
        )