Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

arrow-nightlies / pyarrow   python

Repository URL to install this package:

/ tests / test_dataset_encryption.py

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

import base64
from datetime import timedelta
import numpy as np
import pyarrow.fs as fs
import pyarrow as pa

import pytest

# Set to True below when the Parquet encryption modules cannot be imported.
# The skipif markers on the tests in this module read this flag.
encryption_unavailable = False

try:
    import pyarrow.parquet as pq
    import pyarrow.dataset as ds
except ImportError:
    # Keep this module importable when Parquet/dataset support is missing.
    pq = None
    ds = None

try:
    from pyarrow.tests.parquet.encryption import InMemoryKmsClient
    import pyarrow.parquet.encryption as pe
except ImportError:
    # Neither InMemoryKmsClient nor pe is bound in this case; only the flag is set.
    encryption_unavailable = True


# Marks all of the tests in this module with the "dataset" pytest marker
pytestmark = pytest.mark.dataset


# Key bytes and key identifiers shared by the helper functions below.
# The *_NAME values are the identifiers placed in the encryption configuration;
# the 16-byte key values are handed to the KMS connection config (decoded to
# str) under those identifiers.
FOOTER_KEY = b"0123456789112345"
FOOTER_KEY_NAME = "footer_key"
COL_KEY = b"1234567890123450"
COL_KEY_NAME = "col_key"


def create_sample_table():
    """Return the six-row sample table (year, n_legs, animal) used as test data."""
    years = [2020, 2022, 2021, 2022, 2019, 2021]
    leg_counts = [2, 2, 4, 4, 5, 100]
    animals = ["Flamingo", "Parrot", "Dog", "Horse", "Brittle stars", "Centipede"]
    return pa.table({"year": years, "n_legs": leg_counts, "animal": animals})


def create_encryption_config():
    """Return the EncryptionConfiguration used to write the encrypted dataset.

    The footer key is FOOTER_KEY_NAME with a non-plaintext footer, and the
    "n_legs" and "animal" columns are listed under COL_KEY_NAME.
    """
    keyed_columns = {COL_KEY_NAME: ["n_legs", "animal"]}
    # requires timedelta or an assertion is raised
    lifetime = timedelta(minutes=5.0)
    return pe.EncryptionConfiguration(
        footer_key=FOOTER_KEY_NAME,
        column_keys=keyed_columns,
        encryption_algorithm="AES_GCM_V1",
        plaintext_footer=False,
        cache_lifetime=lifetime,
        data_key_length_bits=256,
    )


def create_decryption_config():
    """Return the DecryptionConfiguration used to read the encrypted dataset.

    The cache lifetime is given as a ``timedelta`` (five minutes, i.e. the
    300 seconds used previously), which is the documented type for
    ``cache_lifetime``, instead of a bare number.
    """
    return pe.DecryptionConfiguration(cache_lifetime=timedelta(minutes=5.0))


def create_kms_connection_config():
    """Return a KmsConnectionConfig carrying the footer and column keys.

    The key bytes are decoded to str and stored in ``custom_kms_conf`` under
    their key identifiers.
    """
    raw_keys = {FOOTER_KEY_NAME: FOOTER_KEY, COL_KEY_NAME: COL_KEY}
    decoded_keys = {name: key.decode("UTF-8") for name, key in raw_keys.items()}
    return pe.KmsConnectionConfig(custom_kms_conf=decoded_keys)


def kms_factory(kms_connection_configuration):
    """Build an InMemoryKmsClient from the given KMS connection configuration.

    Used as the client factory callable handed to ``pe.CryptoFactory``.
    """
    client = InMemoryKmsClient(kms_connection_configuration)
    return client


@pytest.mark.skipif(
    encryption_unavailable, reason="Parquet Encryption is not currently enabled"
)
def test_dataset_encryption_decryption():
    """Write an encrypted Parquet dataset and read it back.

    The sample table is written to a mock filesystem with an encryption
    config, then the test checks that:

    * opening the dataset without any decryption settings raises IOError,
    * reading with a ``decryption_config`` scan option returns the table,
    * reading with explicit ``decryption_properties`` returns the table.
    """
    table = create_sample_table()

    encryption_config = create_encryption_config()
    decryption_config = create_decryption_config()
    kms_connection_config = create_kms_connection_config()

    crypto_factory = pe.CryptoFactory(kms_factory)
    parquet_encryption_cfg = ds.ParquetEncryptionConfig(
        crypto_factory, kms_connection_config, encryption_config
    )
    parquet_decryption_cfg = ds.ParquetDecryptionConfig(
        crypto_factory, kms_connection_config, decryption_config
    )

    # create write_options with dataset encryption config
    pformat = ds.ParquetFileFormat()
    write_options = pformat.make_write_options(encryption_config=parquet_encryption_cfg)

    mockfs = fs._MockFileSystem()
    mockfs.create_dir("/")

    ds.write_dataset(
        data=table,
        base_dir="sample_dataset",
        format=pformat,
        file_options=write_options,
        filesystem=mockfs,
    )

    # read without decryption config -> should error if dataset was properly
    # encrypted
    pformat = ds.ParquetFileFormat()
    with pytest.raises(IOError, match=r"no decryption"):
        ds.dataset("sample_dataset", format=pformat, filesystem=mockfs)

    # set decryption config for parquet fragment scan options
    pq_scan_opts = ds.ParquetFragmentScanOptions(
        decryption_config=parquet_decryption_cfg
    )
    pformat = ds.ParquetFileFormat(default_fragment_scan_options=pq_scan_opts)
    dataset = ds.dataset("sample_dataset", format=pformat, filesystem=mockfs)

    assert table.equals(dataset.to_table())

    # set decryption properties for parquet fragment scan options
    decryption_properties = crypto_factory.file_decryption_properties(
        kms_connection_config, decryption_config)
    pq_scan_opts = ds.ParquetFragmentScanOptions(
        decryption_properties=decryption_properties
    )

    pformat = ds.ParquetFileFormat(default_fragment_scan_options=pq_scan_opts)
    dataset = ds.dataset("sample_dataset", format=pformat, filesystem=mockfs)

    assert table.equals(dataset.to_table())


@pytest.mark.skipif(
    not encryption_unavailable, reason="Parquet Encryption is currently enabled"
)
def test_write_dataset_parquet_without_encryption():
    """Check that encryption_config is rejected when encryption is unavailable.

    This test only runs when the Parquet encryption modules could not be
    imported. In that case, passing ``encryption_config`` to
    ``ParquetFileFormat.make_write_options`` must raise NotImplementedError.
    """
    pformat = ds.ParquetFileFormat()

    # The value itself is irrelevant; merely supplying encryption_config
    # is expected to raise.
    with pytest.raises(NotImplementedError):
        pformat.make_write_options(encryption_config="some value")


@pytest.mark.skipif(
    encryption_unavailable, reason="Parquet Encryption is not currently enabled"
)
def test_large_row_encryption_decryption():
    """Round-trip a single float32 column of 2**15 + 1 rows through encryption.

    The encrypted dataset is read back twice, once directly with
    ``pq.ParquetFile`` and once through ``ds.dataset``, and both results
    are compared with the original table.
    """

    class NoOpKmsClient(pe.KmsClient):
        """KMS client that only base64-encodes/decodes the key bytes."""

        def wrap_key(self, key_bytes: bytes, _: str) -> bytes:
            return base64.b64encode(key_bytes)

        def unwrap_key(self, wrapped_key: bytes, _: str) -> bytes:
            return base64.b64decode(wrapped_key)

    def make_kms_client(_):
        return NoOpKmsClient()

    num_rows = 2**15 + 1
    values = pa.array(np.random.rand(num_rows), type=pa.float32())
    table = pa.Table.from_arrays([values], names=["foo"])

    kms_config = pe.KmsConnectionConfig()
    crypto_factory = pe.CryptoFactory(make_kms_client)
    encryption_config = pe.EncryptionConfiguration(
        footer_key="UNIMPORTANT_KEY",
        column_keys={"UNIMPORTANT_KEY": ["foo"]},
        double_wrapping=True,
        plaintext_footer=False,
        data_key_length_bits=128,
    )
    dataset_encryption = ds.ParquetEncryptionConfig(
        crypto_factory, kms_config, encryption_config
    )
    dataset_decryption = ds.ParquetDecryptionConfig(
        crypto_factory, kms_config, pe.DecryptionConfiguration()
    )
    file_format = ds.ParquetFileFormat(
        default_fragment_scan_options=ds.ParquetFragmentScanOptions(
            decryption_config=dataset_decryption
        )
    )
    write_options = file_format.make_write_options(
        encryption_config=dataset_encryption
    )
    file_decryption_properties = crypto_factory.file_decryption_properties(kms_config)

    mockfs = fs._MockFileSystem()
    mockfs.create_dir("/")

    dataset_dir = "large-row-test-dataset"
    ds.write_dataset(
        table,
        dataset_dir,
        format=file_format,
        file_options=write_options,
        filesystem=mockfs,
    )

    # Read the single written file directly with explicit decryption properties.
    parquet_file = pq.ParquetFile(
        f"{dataset_dir}/part-0.parquet",
        decryption_properties=file_decryption_properties,
        filesystem=mockfs,
    )
    table_from_file = parquet_file.read()
    assert table == table_from_file

    # Read the same data through the dataset API with the decryption config.
    table_from_dataset = ds.dataset(
        dataset_dir, format=file_format, filesystem=mockfs
    ).to_table()
    assert table == table_from_dataset