Why Gemfury? Push, build, and install RubyGems, npm packages, Python packages, Maven artifacts, PHP packages, Go Modules, Bower components, Debian packages, RPM packages, and NuGet packages.

arrow-nightlies / pyarrow   python

Repository URL to install this package:

Version: 19.0.0.dev259 

/ include / arrow / adapters / tensorflow / convert.h

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

#pragma once

#include <memory>

#include "tensorflow/core/framework/op.h"

#include "arrow/type.h"

// These utilities are supposed to be included in TensorFlow operators
// that need to be compiled separately from Arrow because of ABI issues.
// They therefore need to be header-only.

namespace arrow {

namespace adapters {

namespace tensorflow {

/// \brief Map a TensorFlow dtype to the equivalent Arrow data type.
///
/// Declared `inline` because this header is meant to be header-only and is
/// included by separately compiled TensorFlow operators: a non-inline function
/// definition in a header causes duplicate-symbol link errors as soon as two
/// translation units include it.
///
/// \param[in] dtype the TensorFlow data type to convert
/// \param[out] out receives the matching Arrow type on success; it is left
///   untouched when an error is returned
/// \return Status::OK() on success, or Status::TypeError if `dtype` is not one
///   of the boolean / integer / floating-point types handled below
inline Status GetArrowType(::tensorflow::DataType dtype, std::shared_ptr<DataType>* out) {
  switch (dtype) {
    case ::tensorflow::DT_BOOL:
      *out = arrow::boolean();
      break;
    case ::tensorflow::DT_FLOAT:
      *out = arrow::float32();
      break;
    case ::tensorflow::DT_DOUBLE:
      *out = arrow::float64();
      break;
    case ::tensorflow::DT_HALF:
      *out = arrow::float16();
      break;
    case ::tensorflow::DT_INT8:
      *out = arrow::int8();
      break;
    case ::tensorflow::DT_INT16:
      *out = arrow::int16();
      break;
    case ::tensorflow::DT_INT32:
      *out = arrow::int32();
      break;
    case ::tensorflow::DT_INT64:
      *out = arrow::int64();
      break;
    case ::tensorflow::DT_UINT8:
      *out = arrow::uint8();
      break;
    case ::tensorflow::DT_UINT16:
      *out = arrow::uint16();
      break;
    case ::tensorflow::DT_UINT32:
      *out = arrow::uint32();
      break;
    case ::tensorflow::DT_UINT64:
      *out = arrow::uint64();
      break;
    default:
      return Status::TypeError("TensorFlow data type is not supported");
  }
  return Status::OK();
}

/// \brief Map an Arrow data type to the equivalent TensorFlow dtype.
///
/// This is the inverse of GetArrowType for the types handled there. It is
/// declared `inline` for the same header-only / one-definition-rule reason.
///
/// \param[in] dtype the Arrow data type to convert. It is taken by const
///   reference because the function only inspects it, so no refcount copy is
///   needed.
/// \param[out] out receives the matching TensorFlow dtype on success; it is
///   left untouched when an error is returned
/// \return Status::OK() on success, Status::Invalid if `dtype` is null, or
///   Status::TypeError if the Arrow type id has no supported TensorFlow
///   equivalent
inline Status GetTensorFlowType(const std::shared_ptr<DataType>& dtype,
                                ::tensorflow::DataType* out) {
  // Dereferencing a null shared_ptr below would be undefined behavior; report
  // it as an error instead.
  if (dtype == nullptr) {
    return Status::Invalid("Arrow data type must not be null");
  }
  switch (dtype->id()) {
    case Type::BOOL:
      *out = ::tensorflow::DT_BOOL;
      break;
    case Type::UINT8:
      *out = ::tensorflow::DT_UINT8;
      break;
    case Type::INT8:
      *out = ::tensorflow::DT_INT8;
      break;
    case Type::UINT16:
      *out = ::tensorflow::DT_UINT16;
      break;
    case Type::INT16:
      *out = ::tensorflow::DT_INT16;
      break;
    case Type::UINT32:
      *out = ::tensorflow::DT_UINT32;
      break;
    case Type::INT32:
      *out = ::tensorflow::DT_INT32;
      break;
    case Type::UINT64:
      *out = ::tensorflow::DT_UINT64;
      break;
    case Type::INT64:
      *out = ::tensorflow::DT_INT64;
      break;
    case Type::HALF_FLOAT:
      *out = ::tensorflow::DT_HALF;
      break;
    case Type::FLOAT:
      *out = ::tensorflow::DT_FLOAT;
      break;
    case Type::DOUBLE:
      *out = ::tensorflow::DT_DOUBLE;
      break;
    default:
      return Status::TypeError("Arrow data type is not supported");
  }
  return Status::OK();
}

}  // namespace tensorflow

}  // namespace adapters

}  // namespace arrow