# This file is dual licensed under the terms of the Apache License, Version
# 2.0, and the BSD License. See the LICENSE file in the root of this repository
# for complete details.
from __future__ import absolute_import, division, print_function
import abc
import datetime
import hashlib
import ipaddress
from enum import Enum
from asn1crypto.keys import PublicKeyInfo
import six
from cryptography import utils
from cryptography.hazmat.primitives import constant_time, serialization
from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurvePublicKey
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicKey
from cryptography.x509.certificate_transparency import (
SignedCertificateTimestamp
)
from cryptography.x509.general_name import GeneralName, IPAddress, OtherName
from cryptography.x509.name import RelativeDistinguishedName
from cryptography.x509.oid import (
CRLEntryExtensionOID, ExtensionOID, ObjectIdentifier
)
def _key_identifier_from_public_key(public_key):
    """
    Computes a key identifier for ``public_key`` as a SHA-1 digest.

    The bytes that get hashed depend on the type of key:

    * RSA public keys: the DER encoding of the key in PKCS1 format.
    * Elliptic curve public keys: the result of ``encode_point()`` on the
      key's public numbers.
    * Anything else: the ``public_key`` field of the key's DER encoded
      SubjectPublicKeyInfo structure.

    :param public_key: The public key to derive the identifier from.
    :return: The raw SHA-1 digest of the selected bytes.
    """
    if isinstance(public_key, RSAPublicKey):
        key_bytes = public_key.public_bytes(
            serialization.Encoding.DER,
            serialization.PublicFormat.PKCS1,
        )
    elif isinstance(public_key, EllipticCurvePublicKey):
        numbers = public_key.public_numbers()
        key_bytes = numbers.encode_point()
    else:
        # This is a very slow way to do this: the key is serialized to a
        # SubjectPublicKeyInfo and parsed again only to extract one field.
        spki_der = public_key.public_bytes(
            serialization.Encoding.DER,
            serialization.PublicFormat.SubjectPublicKeyInfo
        )
        spki = PublicKeyInfo.load(spki_der)
        key_bytes = six.binary_type(spki['public_key'])

    hasher = hashlib.sha1()
    hasher.update(key_bytes)
    return hasher.digest()
class DuplicateExtension(Exception):
    """
    Exception that carries, next to its message, the oid it relates to.

    :param msg: The message handed to ``Exception``.
    :param oid: Stored unchanged on the instance as ``oid``.
    """

    def __init__(self, msg, oid):
        # Setting the attribute and initializing the base class are
        # independent of each other.
        self.oid = oid
        super(DuplicateExtension, self).__init__(msg)
class ExtensionNotFound(Exception):
    """
    Exception that carries, next to its message, the oid that was looked up.

    :param msg: The message handed to ``Exception``.
    :param oid: Stored unchanged on the instance as ``oid``.
    """

    def __init__(self, msg, oid):
        # Setting the attribute and initializing the base class are
        # independent of each other.
        self.oid = oid
        super(ExtensionNotFound, self).__init__(msg)
@six.add_metaclass(abc.ABCMeta)
class ExtensionType(object):
    """
    Abstract interface for extension value classes.

    The only member it requires is ``oid``. Classes in this module are
    attached to the interface with ``utils.register_interface``.
    """

    @abc.abstractproperty
    def oid(self):
        """
        Returns the oid associated with the given extension type.
        """
class Extensions(object):
    """
    Sequence-like wrapper around the extensions it is constructed with.

    Iteration, ``len()`` and indexing are delegated to the wrapped object.
    The two lookup helpers return the first matching extension and raise
    ``ExtensionNotFound`` when nothing matches.
    """

    def __init__(self, extensions):
        # The object is stored as given; no copy is made.
        self._extensions = extensions

    def __len__(self):
        return len(self._extensions)

    def __iter__(self):
        return iter(self._extensions)

    def __getitem__(self, idx):
        return self._extensions[idx]

    def __repr__(self):
        return "<Extensions({0})>".format(self._extensions)

    def get_extension_for_oid(self, oid):
        """
        Returns the first extension whose ``oid`` equals ``oid``.

        :raises ExtensionNotFound: If no extension has that oid.
        """
        match = next((ext for ext in self if ext.oid == oid), None)
        if match is None:
            raise ExtensionNotFound(
                "No {0} extension was found".format(oid), oid
            )
        return match

    def get_extension_for_class(self, extclass):
        """
        Returns the first extension whose ``value`` is an instance of
        ``extclass``.

        :raises TypeError: If ``extclass`` is ``UnrecognizedExtension``.
        :raises ExtensionNotFound: If no extension value is an instance of
            ``extclass``; the exception carries ``extclass.oid``.
        """
        if extclass is UnrecognizedExtension:
            raise TypeError(
                "UnrecognizedExtension can't be used with "
                "get_extension_for_class because more than one instance "
                "of the class may be present."
            )

        match = next(
            (ext for ext in self if isinstance(ext.value, extclass)), None
        )
        if match is None:
            raise ExtensionNotFound(
                "No {0} extension was found".format(extclass), extclass.oid
            )
        return match
@utils.register_interface(ExtensionType)
class CRLNumber(object):
    """
    Extension value holding a single integer, exposed as ``crl_number``.

    :param crl_number: The number to store.
    :raises TypeError: If ``crl_number`` is not an integer.
    """
    oid = ExtensionOID.CRL_NUMBER

    def __init__(self, crl_number):
        if isinstance(crl_number, six.integer_types):
            self._crl_number = crl_number
        else:
            raise TypeError("crl_number must be an integer")

    crl_number = utils.read_only_property("_crl_number")

    def __repr__(self):
        return "<CRLNumber({0})>".format(self.crl_number)

    def __hash__(self):
        return hash(self.crl_number)

    def __eq__(self, other):
        # Only other CRLNumber instances are compared; anything else is
        # left to the other operand.
        if isinstance(other, CRLNumber):
            return self.crl_number == other.crl_number
        return NotImplemented

    def __ne__(self, other):
        return not (self == other)
@utils.register_interface(ExtensionType)
class AuthorityKeyIdentifier(object):
    """
    Extension value made of a key identifier plus an optional
    issuer/serial number pair.

    :param key_identifier: Stored as given; it is not validated here.
    :param authority_cert_issuer: ``None`` or an iterable of
        ``GeneralName`` objects. It is copied into a list.
    :param authority_cert_serial_number: ``None`` or an integer.
    :raises ValueError: If exactly one of ``authority_cert_issuer`` and
        ``authority_cert_serial_number`` is ``None``.
    :raises TypeError: If an issuer entry is not a ``GeneralName`` or the
        serial number is not an integer.
    """
    oid = ExtensionOID.AUTHORITY_KEY_IDENTIFIER

    key_identifier = utils.read_only_property("_key_identifier")
    authority_cert_issuer = utils.read_only_property("_authority_cert_issuer")
    authority_cert_serial_number = utils.read_only_property(
        "_authority_cert_serial_number"
    )

    def __init__(self, key_identifier, authority_cert_issuer,
                 authority_cert_serial_number):
        issuer_absent = authority_cert_issuer is None
        serial_absent = authority_cert_serial_number is None

        # The issuer and the serial number only make sense as a pair.
        if issuer_absent != serial_absent:
            raise ValueError(
                "authority_cert_issuer and authority_cert_serial_number "
                "must both be present or both None"
            )

        if not issuer_absent:
            authority_cert_issuer = list(authority_cert_issuer)
            for name in authority_cert_issuer:
                if not isinstance(name, GeneralName):
                    raise TypeError(
                        "authority_cert_issuer must be a list of GeneralName "
                        "objects"
                    )

        if not serial_absent:
            if not isinstance(authority_cert_serial_number, six.integer_types):
                raise TypeError(
                    "authority_cert_serial_number must be an integer"
                )

        self._key_identifier = key_identifier
        self._authority_cert_issuer = authority_cert_issuer
        self._authority_cert_serial_number = authority_cert_serial_number

    @classmethod
    def from_issuer_public_key(cls, public_key):
        """
        Builds an instance whose key identifier is derived from
        ``public_key``; issuer and serial number are left as ``None``.
        """
        return cls(
            key_identifier=_key_identifier_from_public_key(public_key),
            authority_cert_issuer=None,
            authority_cert_serial_number=None
        )

    @classmethod
    def from_issuer_subject_key_identifier(cls, ski):
        """
        Builds an instance whose key identifier is ``ski.value.digest``;
        issuer and serial number are left as ``None``.
        """
        identifier = ski.value.digest
        return cls(
            key_identifier=identifier,
            authority_cert_issuer=None,
            authority_cert_serial_number=None
        )

    def __repr__(self):
        template = (
            "<AuthorityKeyIdentifier(key_identifier={0.key_identifier!r}, "
            "authority_cert_issuer={0.authority_cert_issuer}, "
            "authority_cert_serial_number={0.authority_cert_serial_number}"
            ")>"
        )
        return template.format(self)

    def __hash__(self):
        # The issuer is kept as a list, which is unhashable, so it is
        # turned into a tuple first.
        issuer = self.authority_cert_issuer
        if issuer is not None:
            issuer = tuple(issuer)
        return hash((
            self.key_identifier, issuer, self.authority_cert_serial_number
        ))

    def __eq__(self, other):
        if isinstance(other, AuthorityKeyIdentifier):
            return (
                self.key_identifier == other.key_identifier and
                self.authority_cert_issuer == other.authority_cert_issuer and
                self.authority_cert_serial_number ==
                other.authority_cert_serial_number
            )
        return NotImplemented

    def __ne__(self, other):
        return not (self == other)
@utils.register_interface(ExtensionType)
class SubjectKeyIdentifier(object):
    """
    Extension value wrapping a key identifier, exposed as ``digest``.

    :param digest: Stored as given; it is not validated here.
    """
    oid = ExtensionOID.SUBJECT_KEY_IDENTIFIER

    digest = utils.read_only_property("_digest")

    def __init__(self, digest):
        self._digest = digest

    @classmethod
    def from_public_key(cls, public_key):
        """
        Builds an instance whose digest is derived from ``public_key``.
        """
        identifier = _key_identifier_from_public_key(public_key)
        return cls(identifier)

    def __repr__(self):
        return "<SubjectKeyIdentifier(digest={0!r})>".format(self.digest)

    def __hash__(self):
        return hash(self.digest)

    def __eq__(self, other):
        # Digests are compared with constant_time.bytes_eq rather than ==.
        if isinstance(other, SubjectKeyIdentifier):
            return constant_time.bytes_eq(self.digest, other.digest)
        return NotImplemented

    def __ne__(self, other):
        return not (self == other)
@utils.register_interface(ExtensionType)
class AuthorityInformationAccess(object):
    """
    Extension value holding a list of ``AccessDescription`` objects.

    Iteration, ``len()`` and indexing operate on that list.

    :param descriptions: Iterable of ``AccessDescription`` objects. It is
        copied into a list.
    :raises TypeError: If any item is not an ``AccessDescription``.
    """
    oid = ExtensionOID.AUTHORITY_INFORMATION_ACCESS

    def __init__(self, descriptions):
        items = list(descriptions)
        for item in items:
            if not isinstance(item, AccessDescription):
                raise TypeError(
                    "Every item in the descriptions list must be an "
                    "AccessDescription"
                )
        self._descriptions = items

    def __len__(self):
        return len(self._descriptions)

    def __iter__(self):
        return iter(self._descriptions)

    def __getitem__(self, idx):
        return self._descriptions[idx]

    def __repr__(self):
        return "<AuthorityInformationAccess({0})>".format(self._descriptions)

    def __hash__(self):
        # Lists are unhashable, so the descriptions are hashed as a tuple.
        return hash(tuple(self._descriptions))

    def __eq__(self, other):
        if isinstance(other, AuthorityInformationAccess):
            return self._descriptions == other._descriptions
        return NotImplemented

    def __ne__(self, other):
        return not (self == other)
class AccessDescription(object):
    """
    Pairs an access method with an access location.

    :param access_method: An ``ObjectIdentifier``.
    :param access_location: A ``GeneralName``.
    :raises TypeError: If either argument has the wrong type.
    """

    access_method = utils.read_only_property("_access_method")
    access_location = utils.read_only_property("_access_location")

    def __init__(self, access_method, access_location):
        method_ok = isinstance(access_method, ObjectIdentifier)
        if not method_ok:
            raise TypeError("access_method must be an ObjectIdentifier")

        location_ok = isinstance(access_location, GeneralName)
        if not location_ok:
            raise TypeError("access_location must be a GeneralName")

        self._access_method = access_method
        self._access_location = access_location

    def __repr__(self):
        template = (
            "<AccessDescription(access_method={0.access_method}, "
            "access_location={0.access_location})>"
        )
        return template.format(self)

    def __hash__(self):
        return hash((self.access_method, self.access_location))

    def __eq__(self, other):
        if isinstance(other, AccessDescription):
            return (
                self.access_method == other.access_method and
                self.access_location == other.access_location
            )
        return NotImplemented

    def __ne__(self, other):
        return not (self == other)
@utils.register_interface(ExtensionType)
class BasicConstraints(object):
oid = ExtensionOID.BASIC_CONSTRAINTS
def __init__(self, ca, path_length):
if not isinstance(ca, bool):
raise TypeError("ca must be a boolean value")
if path_length is not None and not ca:
raise ValueError("path_length must be None when ca is False")
Loading ...