Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

arrow-adbc-nightlies / adbc-driver-manager   python

Repository URL to install this package:

Version: 1.3.0 

/ adbc_driver_manager.cc

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

#include "arrow-adbc/adbc_driver_manager.h"
#include "arrow-adbc/adbc.h"

#include <algorithm>
#include <array>
#include <cctype>
#include <cerrno>
#include <cstring>
#include <string>
#include <unordered_map>
#include <utility>

#if defined(_WIN32)
#include <windows.h>  // Must come first

#include <libloaderapi.h>
#include <strsafe.h>
#else
#include <dlfcn.h>
#endif  // defined(_WIN32)

namespace {

// Platform-specific helpers

#if defined(_WIN32)
/// Append a description of the most recent Windows error (GetLastError()) to
/// the buffer, formatted as "(<code>) <system message>".
///
/// If the system cannot produce a message for the code, only "(<code>) " is
/// appended, so the caller still gets the numeric error code.
void GetWinError(std::string* buffer) {
  // Capture the code immediately; later API calls may overwrite it.
  const DWORD rc = GetLastError();
  // Must start as nullptr: FormatMessageA leaves it untouched on failure.
  LPSTR message = nullptr;

  // Use the ANSI variant explicitly. The text is appended to a narrow
  // std::string, whereas the FormatMessage macro would resolve to
  // FormatMessageW when UNICODE is defined and produce a wide string.
  const DWORD length = FormatMessageA(
      FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM |
          FORMAT_MESSAGE_IGNORE_INSERTS,
      /*lpSource=*/nullptr, rc, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT),
      reinterpret_cast<LPSTR>(&message), /*nSize=*/0, /*Arguments=*/nullptr);

  (*buffer) += '(';
  (*buffer) += std::to_string(rc);
  (*buffer) += ") ";
  if (length != 0 && message != nullptr) {
    (*buffer) += message;
  }
  // The buffer was allocated by FormatMessageA (FORMAT_MESSAGE_ALLOCATE_BUFFER)
  // and must be released with LocalFree; only free it if it was allocated.
  if (message != nullptr) {
    LocalFree(message);
  }
}

#endif  // defined(_WIN32)

// Error handling

void ReleaseError(struct AdbcError* error) {
  if (error) {
    if (error->message) delete[] error->message;
    error->message = nullptr;
    error->release = nullptr;
  }
}

void SetError(struct AdbcError* error, const std::string& message) {
  static const std::string kPrefix = "[Driver Manager] ";

  if (!error) return;
  if (error->release) error->release(error);

  // Prepend a string to identify driver manager errors
  error->message = new char[kPrefix.size() + message.size() + 1];
  kPrefix.copy(error->message, kPrefix.size());
  message.copy(error->message + kPrefix.size(), message.size());
  error->message[kPrefix.size() + message.size()] = '\0';
  error->release = ReleaseError;
}

// Copies src_error into error and releases src_error
void SetError(struct AdbcError* error, struct AdbcError* src_error) {
  if (!error) return;
  if (error->release) error->release(error);

  if (src_error->message) {
    size_t message_size = strlen(src_error->message);
    error->message = new char[message_size];
    std::memcpy(error->message, src_error->message, message_size);
    error->message[message_size] = '\0';
  } else {
    error->message = nullptr;
  }

  error->release = ReleaseError;
  if (src_error->release) {
    src_error->release(src_error);
  }
}

/// Holds an AdbcError and releases it, if populated, when going out of scope.
///
/// NOTE(review): the implicitly generated copy operations would duplicate
/// ownership of the error, so instances should not be copied.
struct OwnedError {
  struct AdbcError error = ADBC_ERROR_INIT;

  ~OwnedError() {
    if (error.release == nullptr) return;
    error.release(&error);
  }
};

// Driver state

/// A driver DLL.
///
/// Wraps the platform handle of a dynamically loaded driver library and the
/// lookup of symbols from it. Move-only: copying would duplicate the handle.
struct ManagedLibrary {
  ManagedLibrary() : handle(nullptr) {}
  ManagedLibrary(ManagedLibrary&& other) noexcept : handle(other.handle) {
    other.handle = nullptr;
  }
  ManagedLibrary(const ManagedLibrary&) = delete;
  ManagedLibrary& operator=(const ManagedLibrary&) = delete;
  ManagedLibrary& operator=(ManagedLibrary&& other) noexcept {
    // Guard against self-move, which would otherwise null out (and lose) the
    // handle.
    if (this != &other) {
      // Release what we currently hold before taking over `other`'s handle,
      // the same as the destructor would.
      Release();
      this->handle = other.handle;
      other.handle = nullptr;
    }
    return *this;
  }

  ~ManagedLibrary() { Release(); }

  void Release() {
    // TODO(apache/arrow-adbc#204): causes tests to segfault
    // Need to refcount the driver DLL; also, errors may retain a reference to
    // release() from the DLL - how to handle this?
  }

  /// Load `library` as a shared library and store its handle.
  ///
  /// The name is first tried as given. If that fails, a platform-specific
  /// file name is derived and tried:
  /// - on Windows, ".dll" is appended;
  /// - elsewhere, a "lib" prefix and a ".so" (".dylib" on Apple) suffix are
  ///   added unless the name already has them.
  ///
  /// @return ADBC_STATUS_OK on success; otherwise ADBC_STATUS_INTERNAL, with
  ///   `error` describing the failed attempts.
  AdbcStatusCode Load(const char* library, struct AdbcError* error) {
    std::string error_message;
#if defined(_WIN32)
    HMODULE handle = LoadLibraryExA(library, NULL, 0);
    if (!handle) {
      error_message += library;
      error_message += ": LoadLibraryExA() failed: ";
      GetWinError(&error_message);

      std::string full_driver_name = library;
      full_driver_name += ".dll";
      handle = LoadLibraryExA(full_driver_name.c_str(), NULL, 0);
      if (!handle) {
        error_message += '\n';
        error_message += full_driver_name;
        error_message += ": LoadLibraryExA() failed: ";
        GetWinError(&error_message);
      }
    }
    if (!handle) {
      SetError(error, error_message);
      return ADBC_STATUS_INTERNAL;
    } else {
      this->handle = handle;
    }
#else
    static const std::string kPlatformLibraryPrefix = "lib";
#if defined(__APPLE__)
    static const std::string kPlatformLibrarySuffix = ".dylib";
#else
    static const std::string kPlatformLibrarySuffix = ".so";
#endif  // defined(__APPLE__)

    void* handle = dlopen(library, RTLD_NOW | RTLD_LOCAL);
    if (!handle) {
      error_message = "[DriverManager] dlopen() failed: ";
      error_message += dlerror();

      // If applicable, append the shared library prefix/extension and
      // try again (this way you don't have to hardcode driver names by
      // platform in the application)
      const std::string driver_str = library;

      std::string full_driver_name;
      if (driver_str.size() < kPlatformLibraryPrefix.size() ||
          driver_str.compare(0, kPlatformLibraryPrefix.size(), kPlatformLibraryPrefix) !=
              0) {
        full_driver_name += kPlatformLibraryPrefix;
      }
      full_driver_name += driver_str;
      // Test the suffix on the name as given. The offset must come from
      // driver_str itself; full_driver_name may already carry the "lib"
      // prefix, and using its length would check the wrong region (turning
      // "foo.so" into "libfoo.so.so").
      if (driver_str.size() < kPlatformLibrarySuffix.size() ||
          driver_str.compare(driver_str.size() - kPlatformLibrarySuffix.size(),
                             kPlatformLibrarySuffix.size(),
                             kPlatformLibrarySuffix) != 0) {
        full_driver_name += kPlatformLibrarySuffix;
      }
      handle = dlopen(full_driver_name.c_str(), RTLD_NOW | RTLD_LOCAL);
      if (!handle) {
        error_message += "\ndlopen() failed: ";
        error_message += dlerror();
      }
    }
    if (handle) {
      this->handle = handle;
    } else {
      SetError(error, error_message);
      return ADBC_STATUS_INTERNAL;
    }
#endif  // defined(_WIN32)
    return ADBC_STATUS_OK;
  }

  /// Resolve the exported symbol `name` from the loaded library into `*func`.
  ///
  /// @return ADBC_STATUS_OK on success; otherwise ADBC_STATUS_INTERNAL, with
  ///   `error` set and `*func` left unchanged.
  AdbcStatusCode Lookup(const char* name, void** func, struct AdbcError* error) {
#if defined(_WIN32)
    void* load_handle = reinterpret_cast<void*>(GetProcAddress(handle, name));
    if (!load_handle) {
      std::string message = "GetProcAddress(";
      message += name;
      message += ") failed: ";
      GetWinError(&message);
      SetError(error, message);
      return ADBC_STATUS_INTERNAL;
    }
#else
    void* load_handle = dlsym(handle, name);
    if (!load_handle) {
      std::string message = "dlsym(";
      message += name;
      message += ") failed: ";
      message += dlerror();
      SetError(error, message);
      return ADBC_STATUS_INTERNAL;
    }
#endif  // defined(_WIN32)
    *func = load_handle;
    return ADBC_STATUS_OK;
  }

#if defined(_WIN32)
  // The loaded DLL
  HMODULE handle;
#else
  void* handle;
#endif  // defined(_WIN32)
};

/// Hold the driver DLL and the driver release callback in the driver struct.
struct ManagerDriverState {
  // The original release callback
  AdbcStatusCode (*driver_release)(struct AdbcDriver* driver, struct AdbcError* error);

  ManagedLibrary handle;
};

/// Release callback the driver manager installs on an AdbcDriver.
///
/// Runs the driver's own release callback (if any), releases the library
/// handle, and frees the manager's private state. Returns the status from the
/// driver's callback, or ADBC_STATUS_OK when there is no manager state or no
/// driver callback.
static AdbcStatusCode ReleaseDriver(struct AdbcDriver* driver, struct AdbcError* error) {
  auto* state = reinterpret_cast<ManagerDriverState*>(driver->private_manager);
  if (state == nullptr) {
    return ADBC_STATUS_OK;
  }

  const AdbcStatusCode status = (state->driver_release != nullptr)
                                    ? state->driver_release(driver, error)
                                    : ADBC_STATUS_OK;
  state->handle.Release();

  driver->private_manager = nullptr;
  delete state;
  return status;
}

// ArrowArrayStream wrapper to support AdbcErrorFromArrayStream

/// Private data of an ArrowArrayStream wrapped by ErrorArrayStreamInit.
struct ErrorArrayStream {
  // The driver's original stream. The wrapper's callbacks forward to it, and
  // ErrorArrayStreamRelease releases it.
  struct ArrowArrayStream stream;
  // The driver that produced the stream. Not read by the wrapper callbacks in
  // this section; presumably consulted by AdbcErrorFromArrayStream — confirm.
  struct AdbcDriver* private_driver;
};

void ErrorArrayStreamRelease(struct ArrowArrayStream* stream) {
  if (stream->release != ErrorArrayStreamRelease || !stream->private_data) return;

  auto* private_data = reinterpret_cast<struct ErrorArrayStream*>(stream->private_data);
  private_data->stream.release(&private_data->stream);
  delete private_data;
  std::memset(stream, 0, sizeof(*stream));
}

const char* ErrorArrayStreamGetLastError(struct ArrowArrayStream* stream) {
  if (stream->release != ErrorArrayStreamRelease || !stream->private_data) return nullptr;
  auto* private_data = reinterpret_cast<struct ErrorArrayStream*>(stream->private_data);
  return private_data->stream.get_last_error(&private_data->stream);
}

int ErrorArrayStreamGetNext(struct ArrowArrayStream* stream, struct ArrowArray* array) {
  if (stream->release != ErrorArrayStreamRelease || !stream->private_data) return EINVAL;
  auto* private_data = reinterpret_cast<struct ErrorArrayStream*>(stream->private_data);
  return private_data->stream.get_next(&private_data->stream, array);
}

int ErrorArrayStreamGetSchema(struct ArrowArrayStream* stream,
                              struct ArrowSchema* schema) {
  if (stream->release != ErrorArrayStreamRelease || !stream->private_data) return EINVAL;
  auto* private_data = reinterpret_cast<struct ErrorArrayStream*>(stream->private_data);
  return private_data->stream.get_schema(&private_data->stream, schema);
}

// Default stubs

/// Default stub: always reports that the error carries no details.
int ErrorGetDetailCount(const struct AdbcError* /*error*/) {
  constexpr int kNoDetails = 0;
  return kNoDetails;
}

struct AdbcErrorDetail ErrorGetDetail(const struct AdbcError* error, int index) {
  return {nullptr, nullptr, 0};
}

/// Default stub: never recovers an error from a stream. It returns nullptr
/// and leaves `status` untouched.
///
/// ErrorArrayStreamInit compares a driver's callback against this function's
/// address to decide whether the driver supports error recovery from streams.
const struct AdbcError* ErrorFromArrayStream(struct ArrowArrayStream* /*stream*/,
                                             AdbcStatusCode* /*status*/) {
  const struct AdbcError* no_error = nullptr;
  return no_error;
}

void ErrorArrayStreamInit(struct ArrowArrayStream* out,
                          struct AdbcDriver* private_driver) {
  if (!out || !out->release ||
      // Don't bother wrapping if driver didn't claim support
      private_driver->ErrorFromArrayStream == ErrorFromArrayStream) {
    return;
  }
  struct ErrorArrayStream* private_data = new ErrorArrayStream;
  private_data->stream = *out;
  private_data->private_driver = private_driver;
  out->get_last_error = ErrorArrayStreamGetLastError;
  out->get_next = ErrorArrayStreamGetNext;
  out->get_schema = ErrorArrayStreamGetSchema;
  out->release = ErrorArrayStreamRelease;
  out->private_data = private_data;
}

/// Default stub: reports that AdbcDatabaseGetOption is not implemented.
/// It sets `error` and returns ADBC_STATUS_NOT_FOUND; `value` and `length`
/// are not touched.
AdbcStatusCode DatabaseGetOption(struct AdbcDatabase* /*database*/,
                                 const char* /*key*/, char* /*value*/,
                                 size_t* /*length*/, struct AdbcError* error) {
  const AdbcStatusCode status = ADBC_STATUS_NOT_FOUND;
  SetError(error, "AdbcDatabaseGetOption not implemented");
  return status;
}

AdbcStatusCode DatabaseGetOptionBytes(struct AdbcDatabase* database, const char* key,
                                      uint8_t* value, size_t* length,
Loading ...