Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Debian packages RPM packages NuGet packages

Repository URL to install this package:

Details    
cryptography / src / backend / rsa.rs
Size: Mime:
// This file is dual licensed under the terms of the Apache License, Version
// 2.0, and the BSD License. See the LICENSE file in the root of this repository
// for complete details.

use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};

use crate::backend::{hashes, utils};
use crate::buf::CffiBuf;
use crate::error::{CryptographyError, CryptographyResult};
use crate::{exceptions, types};
use pyo3::types::PyAnyMethods;

// Python-visible `RSAPrivateKey` class, registered under
// `cryptography.hazmat.bindings._rust.openssl.rsa` and declared `frozen`.
// It holds a single OpenSSL `PKey` carrying private key material.
#[pyo3::pyclass(
    frozen,
    module = "cryptography.hazmat.bindings._rust.openssl.rsa",
    name = "RSAPrivateKey"
)]
pub(crate) struct RsaPrivateKey {
    // Underlying OpenSSL key handle; the methods in this file read the RSA
    // key from it via `pkey.rsa()`.
    pkey: openssl::pkey::PKey<openssl::pkey::Private>,
}

// Python-visible `RSAPublicKey` class, registered under
// `cryptography.hazmat.bindings._rust.openssl.rsa` and declared `frozen`.
// It holds a single OpenSSL `PKey` carrying public key material only.
#[pyo3::pyclass(
    frozen,
    module = "cryptography.hazmat.bindings._rust.openssl.rsa",
    name = "RSAPublicKey"
)]
pub(crate) struct RsaPublicKey {
    // Underlying OpenSSL key handle; the methods in this file read the RSA
    // key from it via `pkey.rsa()`.
    pkey: openssl::pkey::PKey<openssl::pkey::Public>,
}

/// Validates an RSA private key.
///
/// The key is accepted only when OpenSSL's `check_key` reports success (an
/// error from `check_key` counts as failure) and both prime factors `p` and
/// `q` are odd.
///
/// # Errors
/// Returns a Python `ValueError("Invalid private key")` when any of those
/// conditions does not hold.
fn check_rsa_private_key(
    rsa: &openssl::rsa::Rsa<openssl::pkey::Private>,
) -> CryptographyResult<()> {
    // `&&` short-circuits, so `p`/`q` are only inspected once `check_key`
    // has already succeeded.
    let key_is_sound = rsa.check_key().unwrap_or(false)
        && !rsa.p().unwrap().is_even()
        && !rsa.q().unwrap().is_even();

    if key_is_sound {
        Ok(())
    } else {
        Err(CryptographyError::from(
            pyo3::exceptions::PyValueError::new_err("Invalid private key"),
        ))
    }
}

/// Wraps an OpenSSL private `PKey` in an [`RsaPrivateKey`].
///
/// Unless `unsafe_skip_rsa_key_validation` is set, the RSA key is first run
/// through `check_rsa_private_key` and any validation error is propagated.
///
/// # Panics
/// Panics if validation is requested and `pkey.rsa()` fails.
pub(crate) fn private_key_from_pkey(
    pkey: &openssl::pkey::PKeyRef<openssl::pkey::Private>,
    unsafe_skip_rsa_key_validation: bool,
) -> CryptographyResult<RsaPrivateKey> {
    let must_validate = !unsafe_skip_rsa_key_validation;
    if must_validate {
        let rsa = pkey.rsa().unwrap();
        check_rsa_private_key(&rsa)?;
    }

    let pkey = pkey.to_owned();
    Ok(RsaPrivateKey { pkey })
}

/// Wraps an OpenSSL public `PKey` in an [`RsaPublicKey`], taking an owned
/// handle to the key. No validation is performed.
pub(crate) fn public_key_from_pkey(
    pkey: &openssl::pkey::PKeyRef<openssl::pkey::Public>,
) -> RsaPublicKey {
    let pkey = pkey.to_owned();
    RsaPublicKey { pkey }
}

// Generates a fresh RSA private key of `key_size` bits using
// `public_exponent` as the public exponent. Neither argument is validated
// here; any OpenSSL failure is propagated as an error.
#[pyo3::pyfunction]
fn generate_private_key(public_exponent: u32, key_size: u32) -> CryptographyResult<RsaPrivateKey> {
    let exponent = openssl::bn::BigNum::from_u32(public_exponent)?;
    let pkey = openssl::rsa::Rsa::generate_with_e(key_size, &exponent)
        .and_then(openssl::pkey::PKey::from_rsa)?;
    Ok(RsaPrivateKey { pkey })
}

/// Reports whether `md` is accepted as a digest for OAEP padding.
///
/// SHA-224, SHA-256, SHA-384 and SHA-512 are always accepted. SHA-1 is
/// accepted only while FIPS mode is not enabled.
fn oaep_hash_supported(md: &openssl::hash::MessageDigest) -> bool {
    use openssl::hash::MessageDigest;

    // The FIPS state is queried first, before any digest comparison.
    let sha1_permitted = !cryptography_openssl::fips::is_enabled();
    if sha1_permitted && *md == MessageDigest::sha1() {
        return true;
    }

    [
        MessageDigest::sha224(),
        MessageDigest::sha256(),
        MessageDigest::sha384(),
        MessageDigest::sha512(),
    ]
    .contains(md)
}

/// Configures `ctx` for an RSA encrypt/decrypt operation according to the
/// Python `padding` object.
///
/// `padding` must be an instance of `types::ASYMMETRIC_PADDING`. Instances of
/// `types::PKCS1V15` select PKCS#1 padding; instances of `types::OAEP` select
/// OAEP padding and additionally configure the MGF1 digest, the OAEP digest
/// and (when present and non-empty) the OAEP label on `ctx`.
///
/// The caller is expected to have already initialised `ctx` for the desired
/// operation (the callers in this file invoke `encrypt_init`/`decrypt_init`
/// first).
///
/// # Errors
/// - Python `TypeError` if `padding` is not an `AsymmetricPadding` instance.
/// - `UnsupportedAlgorithm` with `UNSUPPORTED_MGF` if an OAEP padding's
///   `_mgf` is not an instance of `types::MGF1`.
/// - `UnsupportedAlgorithm` with `UNSUPPORTED_PADDING` for any other padding
///   type, or when either OAEP digest is rejected by `oaep_hash_supported`.
/// - Any error raised by Python attribute access, digest lookup, or the
///   OpenSSL context setters is propagated.
fn setup_encryption_ctx(
    py: pyo3::Python<'_>,
    ctx: &mut openssl::pkey_ctx::PkeyCtx<impl openssl::pkey::HasPublic>,
    padding: &pyo3::Bound<'_, pyo3::PyAny>,
) -> CryptographyResult<()> {
    if !padding.is_instance(&types::ASYMMETRIC_PADDING.get(py)?)? {
        return Err(CryptographyError::from(
            pyo3::exceptions::PyTypeError::new_err(
                "Padding must be an instance of AsymmetricPadding.",
            ),
        ));
    }

    // Map the Python padding object onto the OpenSSL padding mode.
    let padding_enum = if padding.is_instance(&types::PKCS1V15.get(py)?)? {
        openssl::rsa::Padding::PKCS1
    } else if padding.is_instance(&types::OAEP.get(py)?)? {
        // OAEP is only supported together with MGF1.
        if !padding
            .getattr(pyo3::intern!(py, "_mgf"))?
            .is_instance(&types::MGF1.get(py)?)?
        {
            return Err(CryptographyError::from(
                exceptions::UnsupportedAlgorithm::new_err((
                    "Only MGF1 is supported.",
                    exceptions::Reasons::UNSUPPORTED_MGF,
                )),
            ));
        }

        openssl::rsa::Padding::PKCS1_OAEP
    } else {
        return Err(CryptographyError::from(
            exceptions::UnsupportedAlgorithm::new_err((
                format!(
                    "{} is not supported by this backend.",
                    padding.getattr(pyo3::intern!(py, "name"))?
                ),
                exceptions::Reasons::UNSUPPORTED_PADDING,
            )),
        ));
    };

    // The padding mode is applied before any of the OAEP-specific parameters
    // below are set on the context.
    ctx.set_rsa_padding(padding_enum)?;

    if padding_enum == openssl::rsa::Padding::PKCS1_OAEP {
        // Digest used by the mask generation function (`padding._mgf._algorithm`).
        let mgf1_md = hashes::message_digest_from_algorithm(
            py,
            &padding
                .getattr(pyo3::intern!(py, "_mgf"))?
                .getattr(pyo3::intern!(py, "_algorithm"))?,
        )?;
        // Digest used by OAEP itself (`padding._algorithm`).
        let oaep_md = hashes::message_digest_from_algorithm(
            py,
            &padding.getattr(pyo3::intern!(py, "_algorithm"))?,
        )?;

        // Both digests must be on the supported list.
        if !oaep_hash_supported(&mgf1_md) || !oaep_hash_supported(&oaep_md) {
            return Err(CryptographyError::from(
                exceptions::UnsupportedAlgorithm::new_err((
                    "This combination of padding and hash algorithm is not supported",
                    exceptions::Reasons::UNSUPPORTED_PADDING,
                )),
            ));
        }

        ctx.set_rsa_mgf1_md(openssl::md::Md::from_nid(mgf1_md.type_()).unwrap())?;
        ctx.set_rsa_oaep_md(openssl::md::Md::from_nid(oaep_md.type_()).unwrap())?;

        // A label is optional (`None` on the Python side); an empty label is
        // treated the same as no label and is not passed to OpenSSL.
        if let Some(label) = padding
            .getattr(pyo3::intern!(py, "_label"))?
            .extract::<Option<pyo3::pybacked::PyBackedBytes>>()?
        {
            if !label.is_empty() {
                ctx.set_rsa_oaep_label(&label)?;
            }
        }
    }

    Ok(())
}

/// Configures `ctx` for an RSA sign/verify/verify-recover operation according
/// to the Python `padding` and hash `algorithm` objects.
///
/// `padding` must be an instance of `types::ASYMMETRIC_PADDING`. Instances of
/// `types::PKCS1V15` select PKCS#1 padding; instances of `types::PSS` select
/// PSS padding and additionally configure the salt length and MGF1 digest.
/// When `algorithm` is not Python `None`, it is installed as the signature
/// digest on `ctx`.
///
/// `key_size` is compared against the digest size for PSS (the callers in
/// this file pass `pkey.size()`), and `is_signing` distinguishes signing from
/// verification: an "auto" PSS salt length is only accepted when verifying.
///
/// The caller is expected to have already initialised `ctx` for the desired
/// operation.
///
/// # Errors
/// - Python `TypeError` if `padding` is not an `AsymmetricPadding` instance,
///   or if PSS is used with an `algorithm` that is not a `HashAlgorithm`.
/// - Python `ValueError` if the PSS digest is too large for the key, or if an
///   auto salt length is requested while signing.
/// - `UnsupportedAlgorithm` for a non-MGF1 mask function, an unsupported
///   padding type, or a digest/padding rejected by OpenSSL.
/// - Any other error from Python attribute access, digest lookup, or the
///   OpenSSL context setters is propagated.
fn setup_signature_ctx(
    py: pyo3::Python<'_>,
    ctx: &mut openssl::pkey_ctx::PkeyCtx<impl openssl::pkey::HasPublic>,
    padding: &pyo3::Bound<'_, pyo3::PyAny>,
    algorithm: &pyo3::Bound<'_, pyo3::PyAny>,
    key_size: usize,
    is_signing: bool,
) -> CryptographyResult<()> {
    if !padding.is_instance(&types::ASYMMETRIC_PADDING.get(py)?)? {
        return Err(CryptographyError::from(
            pyo3::exceptions::PyTypeError::new_err(
                "Padding must be an instance of AsymmetricPadding.",
            ),
        ));
    }

    // Map the Python padding object onto the OpenSSL padding mode.
    let padding_enum = if padding.is_instance(&types::PKCS1V15.get(py)?)? {
        openssl::rsa::Padding::PKCS1
    } else if padding.is_instance(&types::PSS.get(py)?)? {
        if !padding
            .getattr(pyo3::intern!(py, "_mgf"))?
            .is_instance(&types::MGF1.get(py)?)?
        {
            return Err(CryptographyError::from(
                exceptions::UnsupportedAlgorithm::new_err((
                    "Only MGF1 is supported.",
                    exceptions::Reasons::UNSUPPORTED_MGF,
                )),
            ));
        }

        // PSS padding requires a hash algorithm
        if !algorithm.is_instance(&types::HASH_ALGORITHM.get(py)?)? {
            return Err(CryptographyError::from(
                pyo3::exceptions::PyTypeError::new_err(
                    "Expected instance of hashes.HashAlgorithm.",
                ),
            ));
        }

        // `digest_size` comes from a Python object, so it may be arbitrarily
        // large. Saturating keeps the addition from overflowing: an oversized
        // value is reported through the ValueError below rather than
        // panicking or wrapping around past the check.
        if algorithm
            .getattr(pyo3::intern!(py, "digest_size"))?
            .extract::<usize>()?
            .saturating_add(2)
            > key_size
        {
            return Err(CryptographyError::from(
                pyo3::exceptions::PyValueError::new_err(
                    "Digest too large for key size. Use a larger key or different digest.",
                ),
            ));
        }

        openssl::rsa::Padding::PKCS1_PSS
    } else {
        return Err(CryptographyError::from(
            exceptions::UnsupportedAlgorithm::new_err((
                format!(
                    "{} is not supported by this backend.",
                    padding.getattr(pyo3::intern!(py, "name"))?
                ),
                exceptions::Reasons::UNSUPPORTED_PADDING,
            )),
        ));
    };

    // A `None` algorithm means no signature digest is configured on the
    // context; otherwise the digest is installed before the padding mode.
    if !algorithm.is_none() {
        let md = hashes::message_digest_from_algorithm(py, algorithm)?;
        ctx.set_signature_md(openssl::md::Md::from_nid(md.type_()).unwrap())
            .or_else(|_| {
                Err(CryptographyError::from(
                    exceptions::UnsupportedAlgorithm::new_err((
                        format!(
                            "{} is not supported by this backend for RSA signing.",
                            algorithm.getattr(pyo3::intern!(py, "name"))?
                        ),
                        exceptions::Reasons::UNSUPPORTED_HASH,
                    )),
                ))
            })?;
    }
    ctx.set_rsa_padding(padding_enum).or_else(|_| {
        Err(exceptions::UnsupportedAlgorithm::new_err((
            format!(
                "{} is not supported for the RSA signature operation",
                padding.getattr(pyo3::intern!(py, "name"))?
            ),
            exceptions::Reasons::UNSUPPORTED_PADDING,
        )))
    })?;

    if padding_enum == openssl::rsa::Padding::PKCS1_PSS {
        // Translate the Python salt-length marker (or explicit integer) into
        // the OpenSSL setting.
        let salt = padding.getattr(pyo3::intern!(py, "_salt_length"))?;
        if salt.is_instance(&types::PADDING_MAX_LENGTH.get(py)?)? {
            ctx.set_rsa_pss_saltlen(openssl::sign::RsaPssSaltlen::MAXIMUM_LENGTH)?;
        } else if salt.is_instance(&types::PADDING_DIGEST_LENGTH.get(py)?)? {
            ctx.set_rsa_pss_saltlen(openssl::sign::RsaPssSaltlen::DIGEST_LENGTH)?;
        } else if salt.is_instance(&types::PADDING_AUTO.get(py)?)? {
            // Auto is rejected for signing; when verifying, no salt length is
            // set explicitly and the context's existing setting is used.
            if is_signing {
                return Err(CryptographyError::from(
                    pyo3::exceptions::PyValueError::new_err(
                        "PSS salt length can only be set to Auto when verifying",
                    ),
                ));
            }
        } else {
            ctx.set_rsa_pss_saltlen(openssl::sign::RsaPssSaltlen::custom(salt.extract::<i32>()?))?;
        };

        // Digest used by the mask generation function (`padding._mgf._algorithm`).
        let mgf1_md = hashes::message_digest_from_algorithm(
            py,
            &padding
                .getattr(pyo3::intern!(py, "_mgf"))?
                .getattr(pyo3::intern!(py, "_algorithm"))?,
        )?;
        ctx.set_rsa_mgf1_md(openssl::md::Md::from_nid(mgf1_md.type_()).unwrap())?;
    }

    Ok(())
}

// Python-facing methods of `RSAPrivateKey`.
#[pyo3::pymethods]
impl RsaPrivateKey {
    // Signs `data` with this key using the given `padding` and hash
    // `algorithm`, returning the signature as Python bytes.
    //
    // `utils::calculate_digest_and_algorithm` produces the bytes that are
    // actually signed and the algorithm object handed to
    // `setup_signature_ctx`. A failure to initialise the signing context or
    // to produce the signature is reported as a Python `ValueError`.
    fn sign<'p>(
        &self,
        py: pyo3::Python<'p>,
        data: CffiBuf<'_>,
        padding: &pyo3::Bound<'p, pyo3::PyAny>,
        algorithm: &pyo3::Bound<'p, pyo3::PyAny>,
    ) -> CryptographyResult<pyo3::Bound<'p, pyo3::types::PyAny>> {
        let (data, algorithm) =
            utils::calculate_digest_and_algorithm(py, data.as_bytes(), algorithm)?;

        let mut ctx = openssl::pkey_ctx::PkeyCtx::new(&self.pkey)?;
        ctx.sign_init().map_err(|_| {
            pyo3::exceptions::PyValueError::new_err("Unable to sign/verify with this key")
        })?;
        setup_signature_ctx(py, &mut ctx, padding, &algorithm, self.pkey.size(), true)?;

        // The first call (no output buffer) yields the length used to size
        // the Python bytes object; the second call writes the signature
        // directly into that object's buffer.
        let length = ctx.sign(data.as_bytes(), None)?;
        Ok(pyo3::types::PyBytes::new_bound_with(py, length, |b| {
            let length = ctx.sign(data.as_bytes(), Some(b)).map_err(|_| {
                pyo3::exceptions::PyValueError::new_err(
                    "Digest or salt length too long for key size. Use a larger key or shorter salt length if you are specifying a PSS salt",
                )
            })?;
            // The signature must fill the buffer exactly.
            assert_eq!(length, b.len());
            Ok(())
        })?.into_any())
    }

    // Decrypts `ciphertext` with this key using the given `padding` and
    // returns the plaintext as Python bytes.
    //
    // The ciphertext length must equal the modulus size in bytes; otherwise a
    // Python `ValueError` is raised before any decryption is attempted. A
    // failed decryption is reported as `ValueError("Decryption failed")`.
    fn decrypt<'p>(
        &self,
        py: pyo3::Python<'p>,
        ciphertext: &[u8],
        padding: &pyo3::Bound<'p, pyo3::PyAny>,
    ) -> CryptographyResult<pyo3::Bound<'p, pyo3::types::PyBytes>> {
        // Modulus bit length rounded up to whole bytes.
        let key_size_bytes =
            usize::try_from((self.pkey.rsa().unwrap().n().num_bits() + 7) / 8).unwrap();
        if key_size_bytes != ciphertext.len() {
            return Err(CryptographyError::from(
                pyo3::exceptions::PyValueError::new_err(
                    "Ciphertext length must be equal to key size.",
                ),
            ));
        }

        let mut ctx = openssl::pkey_ctx::PkeyCtx::new(&self.pkey)?;
        ctx.decrypt_init()?;

        setup_encryption_ctx(py, &mut ctx, padding)?;

        // Everything from this line onwards is written with the goal of being
        // as constant-time as is practical given the constraints of
        // rust-openssl and our API. See Bleichenbacher's '98 attack on RSA,
        // and its many many variants. As such, you should not attempt to
        // change this (particularly to "clean it up") without understanding
        // why it was written this way (see Chesterton's Fence), and without
        // measuring to verify you have not introduced observable time
        // differences.
        //
        // Once OpenSSL 3.2.0 is out, this can be simplified, as OpenSSL will
        // have its own mitigations for Bleichenbacher's attack.
        let length = ctx.decrypt(ciphertext, None).unwrap();
        let mut plaintext = vec![0; length];
        let result = ctx.decrypt(ciphertext, Some(&mut plaintext));

        // The Python bytes object is built whether or not decryption
        // succeeded (falling back to the full buffer length on error); the
        // error is only surfaced afterwards.
        let py_result =
            pyo3::types::PyBytes::new_bound(py, &plaintext[..*result.as_ref().unwrap_or(&length)]);
        if result.is_err() {
            return Err(CryptographyError::from(
                pyo3::exceptions::PyValueError::new_err("Decryption failed"),
            ));
        }
        Ok(py_result)
    }

    // Bit length of the RSA modulus `n`.
    #[getter]
    fn key_size(&self) -> i32 {
        self.pkey.rsa().unwrap().n().num_bits()
    }

    // Builds the matching `RSAPublicKey` from copies of this key's modulus
    // `n` and public exponent `e`.
    fn public_key(&self) -> CryptographyResult<RsaPublicKey> {
        let priv_rsa = self.pkey.rsa().unwrap();
        let rsa = openssl::rsa::Rsa::from_public_components(
            priv_rsa.n().to_owned()?,
            priv_rsa.e().to_owned()?,
        )
        .unwrap();
        let pkey = openssl::pkey::PKey::from_rsa(rsa)?;
        Ok(RsaPublicKey { pkey })
    }

    // Exports every key component (p, q, d, dmp1, dmq1, iqmp, e, n) as Python
    // integers wrapped in `RSAPrivateNumbers` / `RSAPublicNumbers`.
    //
    // Panics if the key lacks p, q, dmp1, dmq1 or iqmp.
    fn private_numbers(&self, py: pyo3::Python<'_>) -> CryptographyResult<RsaPrivateNumbers> {
        let rsa = self.pkey.rsa().unwrap();

        let py_p = utils::bn_to_py_int(py, rsa.p().unwrap())?;
        let py_q = utils::bn_to_py_int(py, rsa.q().unwrap())?;
        let py_d = utils::bn_to_py_int(py, rsa.d())?;
        let py_dmp1 = utils::bn_to_py_int(py, rsa.dmp1().unwrap())?;
        let py_dmq1 = utils::bn_to_py_int(py, rsa.dmq1().unwrap())?;
        let py_iqmp = utils::bn_to_py_int(py, rsa.iqmp().unwrap())?;
        let py_e = utils::bn_to_py_int(py, rsa.e())?;
        let py_n = utils::bn_to_py_int(py, rsa.n())?;

        let public_numbers = RsaPublicNumbers {
            e: py_e.extract()?,
            n: py_n.extract()?,
        };
        Ok(RsaPrivateNumbers {
            p: py_p.extract()?,
            q: py_q.extract()?,
            d: py_d.extract()?,
            dmp1: py_dmp1.extract()?,
            dmq1: py_dmq1.extract()?,
            iqmp: py_iqmp.extract()?,
            public_numbers: pyo3::Py::new(py, public_numbers)?,
        })
    }

    // Serialises the private key by delegating to `utils::pkey_private_bytes`
    // with the caller's `encoding`, `format` and `encryption_algorithm`.
    // NOTE(review): the meaning of the trailing `true, false` flags is
    // defined by that helper and is not visible here; confirm against
    // `utils::pkey_private_bytes`.
    fn private_bytes<'p>(
        slf: &pyo3::Bound<'p, Self>,
        py: pyo3::Python<'p>,
        encoding: &pyo3::Bound<'p, pyo3::PyAny>,
        format: &pyo3::Bound<'p, pyo3::PyAny>,
        encryption_algorithm: &pyo3::Bound<'p, pyo3::PyAny>,
    ) -> CryptographyResult<pyo3::Bound<'p, pyo3::types::PyBytes>> {
        utils::pkey_private_bytes(
            py,
            slf,
            &slf.borrow().pkey,
            encoding,
            format,
            encryption_algorithm,
            true,
            false,
        )
    }
}

// Python-facing methods of `RSAPublicKey`.
#[pyo3::pymethods]
impl RsaPublicKey {
    // Verifies `signature` over `data` with the given `padding` and hash
    // `algorithm`. Returns normally on success; raises `InvalidSignature`
    // when verification reports failure or errors out.
    fn verify(
        &self,
        py: pyo3::Python<'_>,
        signature: CffiBuf<'_>,
        data: CffiBuf<'_>,
        padding: &pyo3::Bound<'_, pyo3::PyAny>,
        algorithm: &pyo3::Bound<'_, pyo3::PyAny>,
    ) -> CryptographyResult<()> {
        let (digest, hash_algorithm) =
            utils::calculate_digest_and_algorithm(py, data.as_bytes(), algorithm)?;

        let mut verifier = openssl::pkey_ctx::PkeyCtx::new(&self.pkey)?;
        verifier.verify_init()?;
        setup_signature_ctx(
            py,
            &mut verifier,
            padding,
            &hash_algorithm,
            self.pkey.size(),
            false,
        )?;

        // An OpenSSL error during verification is treated exactly like a
        // signature mismatch.
        match verifier.verify(digest.as_bytes(), signature.as_bytes()) {
            Ok(true) => Ok(()),
            Ok(false) | Err(_) => Err(CryptographyError::from(
                exceptions::InvalidSignature::new_err(()),
            )),
        }
    }

    // Encrypts `plaintext` under this key with the given `padding` and
    // returns the ciphertext as Python bytes. A failure while producing the
    // ciphertext is reported as `ValueError("Encryption failed")`.
    fn encrypt<'p>(
        &self,
        py: pyo3::Python<'p>,
        plaintext: &[u8],
        padding: &pyo3::Bound<'p, pyo3::PyAny>,
    ) -> CryptographyResult<pyo3::Bound<'p, pyo3::types::PyBytes>> {
        let mut encryptor = openssl::pkey_ctx::PkeyCtx::new(&self.pkey)?;
        encryptor.encrypt_init()?;

        setup_encryption_ctx(py, &mut encryptor, padding)?;

        // The call without an output buffer yields the length used to size
        // the Python bytes object that the second call writes into.
        let ciphertext_len = encryptor.encrypt(plaintext, None)?;
        let ciphertext = pyo3::types::PyBytes::new_bound_with(py, ciphertext_len, |out| {
            let written = encryptor
                .encrypt(plaintext, Some(out))
                .map_err(|_| pyo3::exceptions::PyValueError::new_err("Encryption failed"))?;
            assert_eq!(written, out.len());
            Ok(())
        })?;
        Ok(ciphertext)
    }

    // Recovers the data embedded in `signature` using the given `padding`
    // and `algorithm`. `Prehashed` algorithms are rejected with a
    // `TypeError`; a failed recovery raises `InvalidSignature`.
    fn recover_data_from_signature<'p>(
        &self,
        py: pyo3::Python<'p>,
        signature: &[u8],
        padding: &pyo3::Bound<'_, pyo3::PyAny>,
        algorithm: &pyo3::Bound<'_, pyo3::PyAny>,
    ) -> CryptographyResult<pyo3::Bound<'p, pyo3::types::PyBytes>> {
        let is_prehashed = algorithm.is_instance(&types::PREHASHED.get(py)?)?;
        if is_prehashed {
            return Err(CryptographyError::from(
                pyo3::exceptions::PyTypeError::new_err(
                    "Prehashed is only supported in the sign and verify methods. It cannot be used with recover_data_from_signature.",
                ),
            ));
        }

        let mut recoverer = openssl::pkey_ctx::PkeyCtx::new(&self.pkey)?;
        recoverer.verify_recover_init()?;
        setup_signature_ctx(
            py,
            &mut recoverer,
            padding,
            algorithm,
            self.pkey.size(),
            false,
        )?;

        // Size the scratch buffer from the no-buffer call, then keep only
        // the bytes the second call reports as written.
        let capacity = recoverer.verify_recover(signature, None)?;
        let mut recovered = vec![0u8; capacity];
        let written = recoverer
            .verify_recover(signature, Some(&mut recovered))
            .map_err(|_| exceptions::InvalidSignature::new_err(()))?;

        Ok(pyo3::types::PyBytes::new_bound(py, &recovered[..written]))
    }

    // Bit length of the RSA modulus `n`.
    #[getter]
    fn key_size(&self) -> i32 {
        let rsa = self.pkey.rsa().unwrap();
        rsa.n().num_bits()
    }

    // Exports the public exponent `e` and modulus `n` as Python integers
    // wrapped in `RSAPublicNumbers`.
    fn public_numbers(&self, py: pyo3::Python<'_>) -> CryptographyResult<RsaPublicNumbers> {
        let rsa = self.pkey.rsa().unwrap();

        // Convert both values first, then extract them, in e-then-n order.
        let exponent = utils::bn_to_py_int(py, rsa.e())?;
        let modulus = utils::bn_to_py_int(py, rsa.n())?;

        let e = exponent.extract()?;
        let n = modulus.extract()?;
        Ok(RsaPublicNumbers { e, n })
    }

    // Serialises the public key by delegating to `utils::pkey_public_bytes`
    // with the caller's `encoding` and `format`.
    // NOTE(review): the meaning of the trailing `true, false` flags is
    // defined by that helper and is not visible here; confirm against
    // `utils::pkey_public_bytes`.
    fn public_bytes<'p>(
        slf: &pyo3::Bound<'p, Self>,
        py: pyo3::Python<'p>,
        encoding: &pyo3::Bound<'p, pyo3::PyAny>,
        format: &pyo3::Bound<'p, pyo3::PyAny>,
    ) -> CryptographyResult<pyo3::Bound<'p, pyo3::types::PyBytes>> {
        let this = slf.borrow();
        utils::pkey_public_bytes(py, slf, &this.pkey, encoding, format, true, false)
    }

    // Two public keys are equal when OpenSSL's `public_eq` says so.
    fn __eq__(&self, other: pyo3::PyRef<'_, Self>) -> bool {
        let theirs = &other.pkey;
        self.pkey.public_eq(theirs)
    }

    // Copying hands back the very same object rather than a new one.
    fn __copy__(slf: pyo3::PyRef<'_, Self>) -> pyo3::PyRef<'_, Self> {
        slf
    }
}

// Python-visible `RSAPrivateNumbers` class, registered under
// `cryptography.hazmat.primitives.asymmetric.rsa` and declared `frozen`.
// It stores the integer components of an RSA private key as Python ints;
// every field is exposed to Python as a read-only attribute via `#[pyo3(get)]`.
#[pyo3::pyclass(
    frozen,
    module = "cryptography.hazmat.primitives.asymmetric.rsa",
    name = "RSAPrivateNumbers"
)]
struct RsaPrivateNumbers {
    // First prime factor.
    #[pyo3(get)]
    p: pyo3::Py<pyo3::types::PyLong>,
    // Second prime factor.
    #[pyo3(get)]
    q: pyo3::Py<pyo3::types::PyLong>,
    // Private exponent.
    #[pyo3(get)]
    d: pyo3::Py<pyo3::types::PyLong>,
    // CRT component named `dmp1` (passed to OpenSSL under the same name).
    #[pyo3(get)]
    dmp1: pyo3::Py<pyo3::types::PyLong>,
    // CRT component named `dmq1` (passed to OpenSSL under the same name).
    #[pyo3(get)]
    dmq1: pyo3::Py<pyo3::types::PyLong>,
    // CRT component named `iqmp` (passed to OpenSSL under the same name).
    #[pyo3(get)]
    iqmp: pyo3::Py<pyo3::types::PyLong>,
    // The matching public numbers (`e` and `n`).
    #[pyo3(get)]
    public_numbers: pyo3::Py<RsaPublicNumbers>,
}

// Python-visible `RSAPublicNumbers` class, registered under
// `cryptography.hazmat.primitives.asymmetric.rsa` and declared `frozen`.
// It stores the integer components of an RSA public key as Python ints;
// both fields are exposed to Python as read-only attributes via `#[pyo3(get)]`.
#[pyo3::pyclass(
    frozen,
    module = "cryptography.hazmat.primitives.asymmetric.rsa",
    name = "RSAPublicNumbers"
)]
struct RsaPublicNumbers {
    // Public exponent.
    #[pyo3(get)]
    e: pyo3::Py<pyo3::types::PyLong>,
    // Modulus.
    #[pyo3(get)]
    n: pyo3::Py<pyo3::types::PyLong>,
}

#[allow(clippy::too_many_arguments)]
fn check_private_key_components(
    p: &pyo3::Bound<'_, pyo3::types::PyLong>,
    q: &pyo3::Bound<'_, pyo3::types::PyLong>,
    private_exponent: &pyo3::Bound<'_, pyo3::types::PyLong>,
    dmp1: &pyo3::Bound<'_, pyo3::types::PyLong>,
    dmq1: &pyo3::Bound<'_, pyo3::types::PyLong>,
    iqmp: &pyo3::Bound<'_, pyo3::types::PyLong>,
    public_exponent: &pyo3::Bound<'_, pyo3::types::PyLong>,
    modulus: &pyo3::Bound<'_, pyo3::types::PyLong>,
) -> CryptographyResult<()> {
    if modulus.lt(3)? {
        return Err(CryptographyError::from(
            pyo3::exceptions::PyValueError::new_err("modulus must be >= 3."),
        ));
    }

    if p.ge(modulus)? {
        return Err(CryptographyError::from(
            pyo3::exceptions::PyValueError::new_err("p must be < modulus."),
        ));
    }

    if q.ge(modulus)? {
        return Err(CryptographyError::from(
            pyo3::exceptions::PyValueError::new_err("q must be < modulus."),
        ));
    }

    if dmp1.ge(modulus)? {
        return Err(CryptographyError::from(
            pyo3::exceptions::PyValueError::new_err("dmp1 must be < modulus."),
        ));
    }

    if dmq1.ge(modulus)? {
        return Err(CryptographyError::from(
            pyo3::exceptions::PyValueError::new_err("dmq1 must be < modulus."),
        ));
    }

    if iqmp.ge(modulus)? {
        return Err(CryptographyError::from(
            pyo3::exceptions::PyValueError::new_err("iqmp must be < modulus."),
        ));
    }

    if private_exponent.ge(modulus)? {
        return Err(CryptographyError::from(
            pyo3::exceptions::PyValueError::new_err("private_exponent must be < modulus."),
        ));
    }

    if public_exponent.lt(3)? || public_exponent.ge(modulus)? {
        return Err(CryptographyError::from(
            pyo3::exceptions::PyValueError::new_err("public_exponent must be >= 3 and < modulus."),
        ));
    }

    if public_exponent.bitand(1)?.eq(0)? {
        return Err(CryptographyError::from(
            pyo3::exceptions::PyValueError::new_err("public_exponent must be odd."),
        ));
    }

    if dmp1.bitand(1)?.eq(0)? {
        return Err(CryptographyError::from(
            pyo3::exceptions::PyValueError::new_err("dmp1 must be odd."),
        ));
    }

    if dmq1.bitand(1)?.eq(0)? {
        return Err(CryptographyError::from(
            pyo3::exceptions::PyValueError::new_err("dmq1 must be odd."),
        ));
    }

    if p.mul(q)?.ne(modulus)? {
        return Err(CryptographyError::from(
            pyo3::exceptions::PyValueError::new_err("p*q must equal modulus."),
        ));
    }

    Ok(())
}

#[pyo3::pymethods]
impl RsaPrivateNumbers {
    #[new]
    fn new(
        p: pyo3::Py<pyo3::types::PyLong>,
        q: pyo3::Py<pyo3::types::PyLong>,
        d: pyo3::Py<pyo3::types::PyLong>,
        dmp1: pyo3::Py<pyo3::types::PyLong>,
        dmq1: pyo3::Py<pyo3::types::PyLong>,
        iqmp: pyo3::Py<pyo3::types::PyLong>,
        public_numbers: pyo3::Py<RsaPublicNumbers>,
    ) -> RsaPrivateNumbers {
        Self {
            p,
            q,
            d,
            dmp1,
            dmq1,
            iqmp,
            public_numbers,
        }
    }

    #[pyo3(signature = (backend = None, *, unsafe_skip_rsa_key_validation = false))]
    fn private_key(
        &self,
        py: pyo3::Python<'_>,
        backend: Option<&pyo3::Bound<'_, pyo3::PyAny>>,
        unsafe_skip_rsa_key_validation: bool,
    ) -> CryptographyResult<RsaPrivateKey> {
        let _ = backend;

        check_private_key_components(
            self.p.bind(py),
            self.q.bind(py),
            self.d.bind(py),
            self.dmp1.bind(py),
            self.dmq1.bind(py),
            self.iqmp.bind(py),
            self.public_numbers.get().e.bind(py),
            self.public_numbers.get().n.bind(py),
        )?;
        let public_numbers = self.public_numbers.get();
        let rsa = openssl::rsa::Rsa::from_private_components(
            utils::py_int_to_bn(py, public_numbers.n.bind(py))?,
            utils::py_int_to_bn(py, public_numbers.e.bind(py))?,
            utils::py_int_to_bn(py, self.d.bind(py))?,
            utils::py_int_to_bn(py, self.p.bind(py))?,
            utils::py_int_to_bn(py, self.q.bind(py))?,
            utils::py_int_to_bn(py, self.dmp1.bind(py))?,
            utils::py_int_to_bn(py, self.dmq1.bind(py))?,
            utils::py_int_to_bn(py, self.iqmp.bind(py))?,
        )
        .unwrap();
        if !unsafe_skip_rsa_key_validation {
            check_rsa_private_key(&rsa)?;
        }
        let pkey = openssl::pkey::PKey::from_rsa(rsa)?;
        Ok(RsaPrivateKey { pkey })
    }

    fn __eq__(
        &self,
        py: pyo3::Python<'_>,
        other: pyo3::PyRef<'_, Self>,
    ) -> CryptographyResult<bool> {
        Ok(self.p.bind(py).eq(other.p.bind(py))?
            && self.q.bind(py).eq(other.q.bind(py))?
            && self.d.bind(py).eq(other.d.bind(py))?
            && self.dmp1.bind(py).eq(other.dmp1.bind(py))?
            && self.dmq1.bind(py).eq(other.dmq1.bind(py))?
            && self.iqmp.bind(py).eq(other.iqmp.bind(py))?
            && self
                .public_numbers
                .bind(py)
                .eq(other.public_numbers.bind(py))?)
    }

    fn __hash__(&self, py: pyo3::Python<'_>) -> CryptographyResult<u64> {
        let mut hasher = DefaultHasher::new();
        self.p.bind(py).hash()?.hash(&mut hasher);
        self.q.bind(py).hash()?.hash(&mut hasher);
        self.d.bind(py).hash()?.hash(&mut hasher);
        self.dmp1.bind(py).hash()?.hash(&mut hasher);
        self.dmq1.bind(py).hash()?.hash(&mut hasher);
        self.iqmp.bind(py).hash()?.hash(&mut hasher);
        self.public_numbers.bind(py).hash()?.hash(&mut hasher);
        Ok(hasher.finish())
    }
}

/// Sanity-checks the integer components of an RSA public key.
///
/// Requires, in this order: `n >= 3`, `3 <= e < n`, and `e` odd.
///
/// # Errors
/// Returns a Python `ValueError` naming the first violated constraint, or
/// propagates any error raised by the underlying Python comparisons.
fn check_public_key_components(
    e: &pyo3::Bound<'_, pyo3::types::PyLong>,
    n: &pyo3::Bound<'_, pyo3::types::PyLong>,
) -> CryptographyResult<()> {
    // Every rejection is a ValueError that differs only in its message.
    let reject = |message: &'static str| -> CryptographyResult<()> {
        Err(CryptographyError::from(
            pyo3::exceptions::PyValueError::new_err(message),
        ))
    };

    if n.lt(3)? {
        return reject("n must be >= 3.");
    }

    let exponent_in_range = !e.lt(3)? && !e.ge(n)?;
    if !exponent_in_range {
        return reject("e must be >= 3 and < n.");
    }

    let exponent_is_even = e.bitand(1)?.eq(0)?;
    if exponent_is_even {
        return reject("e must be odd.");
    }

    Ok(())
}

// Python-facing methods of `RSAPublicNumbers`.
#[pyo3::pymethods]
impl RsaPublicNumbers {
    // Stores the supplied integers as given; nothing is validated until
    // `public_key` is called.
    #[new]
    fn new(e: pyo3::Py<pyo3::types::PyLong>, n: pyo3::Py<pyo3::types::PyLong>) -> RsaPublicNumbers {
        Self { e, n }
    }

    // Builds an `RSAPublicKey` from these numbers after checking them with
    // `check_public_key_components`. `backend` is accepted but ignored.
    #[pyo3(signature = (backend=None))]
    fn public_key(
        &self,
        py: pyo3::Python<'_>,
        backend: Option<&pyo3::Bound<'_, pyo3::PyAny>>,
    ) -> CryptographyResult<RsaPublicKey> {
        let _ = backend;

        let exponent = self.e.bind(py);
        let modulus = self.n.bind(py);
        check_public_key_components(exponent, modulus)?;

        // OpenSSL takes the modulus first, then the exponent.
        let n = utils::py_int_to_bn(py, modulus)?;
        let e = utils::py_int_to_bn(py, exponent)?;
        let rsa = openssl::rsa::Rsa::from_public_components(n, e).unwrap();

        Ok(RsaPublicKey {
            pkey: openssl::pkey::PKey::from_rsa(rsa)?,
        })
    }

    // Equality compares `e` first and only looks at `n` when `e` matches.
    fn __eq__(
        &self,
        py: pyo3::Python<'_>,
        other: pyo3::PyRef<'_, Self>,
    ) -> CryptographyResult<bool> {
        let same_exponent = self.e.bind(py).eq(other.e.bind(py))?;
        if !same_exponent {
            return Ok(false);
        }
        Ok(self.n.bind(py).eq(other.n.bind(py))?)
    }

    // Hash combines the Python hashes of `e` and `n`, in that order.
    fn __hash__(&self, py: pyo3::Python<'_>) -> CryptographyResult<u64> {
        let mut hasher = DefaultHasher::new();
        for component in [&self.e, &self.n] {
            component.bind(py).hash()?.hash(&mut hasher);
        }
        Ok(hasher.finish())
    }

    // Human-readable representation showing both numbers.
    fn __repr__(&self, py: pyo3::Python<'_>) -> pyo3::PyResult<String> {
        Ok(format!(
            "<RSAPublicNumbers(e={}, n={})>",
            self.e.bind(py),
            self.n.bind(py)
        ))
    }
}

// Python submodule `rsa`: exports the key-generation function together with
// the key and numbers classes defined in this file.
#[pyo3::pymodule]
pub(crate) mod rsa {
    #[pymodule_export]
    use super::{
        generate_private_key, RsaPrivateKey, RsaPrivateNumbers, RsaPublicKey, RsaPublicNumbers,
    };
}