#pragma once
namespace at { namespace native {
// (Const)StridedRandomAccessor is a
// (const) random access iterator defined over
// a strided array.
// The traits below are to introduce __restrict__
// modifier on different platforms.
template <typename T>
struct DefaultPtrTraits {
using PtrType = T*;
};
#if (defined(_WIN32) || defined(_WIN64))
#define RESTRICT __restrict
#else
#define RESTRICT __restrict__
#endif
template <typename T>
struct RestrictPtrTraits {
using PtrType = T* RESTRICT;
};
template <
typename T,
typename index_t = int64_t,
template <typename U> class PtrTraits = DefaultPtrTraits
>
class ConstStridedRandomAccessor {
public:
using difference_type = index_t;
using value_type = const T;
using pointer = const typename PtrTraits<T>::PtrType;
using reference = const value_type&;
using iterator_category = std::random_access_iterator_tag;
using PtrType = typename PtrTraits<T>::PtrType;
using index_type = index_t;
// Constructors {
C10_HOST_DEVICE
ConstStridedRandomAccessor(PtrType ptr, index_t stride)
: ptr{ptr}, stride{stride}
{}
C10_HOST_DEVICE
explicit ConstStridedRandomAccessor(PtrType ptr)
: ptr{ptr}, stride{static_cast<index_t>(1)}
{}
C10_HOST_DEVICE
ConstStridedRandomAccessor()
: ptr{nullptr}, stride{static_cast<index_t>(1)}
{}
// }
// Pointer-like operations {
C10_HOST_DEVICE
reference operator*() const {
return *ptr;
}
C10_HOST_DEVICE
const value_type* operator->() const {
return reinterpret_cast<const value_type*>(ptr);
}
C10_HOST_DEVICE
reference operator[](index_t idx) const {
return ptr[idx * stride];
}
// }
// Prefix/postfix increment/decrement {
C10_HOST_DEVICE
ConstStridedRandomAccessor& operator++() {
ptr += stride;
return *this;
}
C10_HOST_DEVICE
ConstStridedRandomAccessor operator++(int) {
ConstStridedRandomAccessor copy(*this);
++*this;
return copy;
}
C10_HOST_DEVICE
ConstStridedRandomAccessor& operator--() {
ptr -= stride;
return *this;
}
C10_HOST_DEVICE
ConstStridedRandomAccessor operator--(int) {
ConstStridedRandomAccessor copy(*this);
--*this;
return copy;
}
// }
// Arithmetic operations {
C10_HOST_DEVICE
ConstStridedRandomAccessor& operator+=(index_t offset) {
ptr += offset * stride;
return *this;
}
C10_HOST_DEVICE
ConstStridedRandomAccessor operator+(index_t offset) const {
return ConstStridedRandomAccessor(ptr + offset * stride, stride);
}
C10_HOST_DEVICE
friend ConstStridedRandomAccessor operator+(
index_t offset,
const ConstStridedRandomAccessor& accessor
) {
return accessor + offset;
}
C10_HOST_DEVICE
ConstStridedRandomAccessor& operator-=(index_t offset) {
ptr -= offset * stride;
return *this;
}
C10_HOST_DEVICE
ConstStridedRandomAccessor operator-(index_t offset) const {
return ConstStridedRandomAccessor(ptr - offset * stride, stride);
}
// Note that this operator is well-defined when `this` and `other`
// represent the same sequences, i.e. when
// 1. this.stride == other.stride,
// 2. |other - this| / this.stride is an Integer.
C10_HOST_DEVICE
difference_type operator-(const ConstStridedRandomAccessor& other) const {
return (ptr - other.ptr) / stride;
}
// }
// Comparison operators {
C10_HOST_DEVICE
bool operator==(const ConstStridedRandomAccessor& other) const {
return (ptr == other.ptr) && (stride == other.stride);
}
C10_HOST_DEVICE
bool operator!=(const ConstStridedRandomAccessor& other) const {
return !(*this == other);
}
C10_HOST_DEVICE
bool operator<(const ConstStridedRandomAccessor& other) const {
return ptr < other.ptr;
}
C10_HOST_DEVICE
bool operator<=(const ConstStridedRandomAccessor& other) const {
return (*this < other) || (*this == other);
}
C10_HOST_DEVICE
bool operator>(const ConstStridedRandomAccessor& other) const {
return !(*this <= other);
}
C10_HOST_DEVICE
bool operator>=(const ConstStridedRandomAccessor& other) const {
return !(*this < other);
}
// }
protected:
PtrType ptr;
index_t stride;
};
template <
typename T,
typename index_t = int64_t,
template <typename U> class PtrTraits = DefaultPtrTraits
>
class StridedRandomAccessor
: public ConstStridedRandomAccessor<T, index_t, PtrTraits> {
public:
using difference_type = index_t;
using value_type = T;
using pointer = typename PtrTraits<T>::PtrType;
using reference = value_type&;
using BaseType = ConstStridedRandomAccessor<T, index_t, PtrTraits>;
using PtrType = typename PtrTraits<T>::PtrType;
// Constructors {
C10_HOST_DEVICE
StridedRandomAccessor(PtrType ptr, index_t stride)
: BaseType(ptr, stride)
{}
C10_HOST_DEVICE
explicit StridedRandomAccessor(PtrType ptr)
: BaseType(ptr)
{}
C10_HOST_DEVICE
StridedRandomAccessor()
: BaseType()
{}
// }
// Pointer-like operations {
C10_HOST_DEVICE
reference operator*() const {
return *this->ptr;
}
C10_HOST_DEVICE
value_type* operator->() const {
return reinterpret_cast<value_type*>(this->ptr);
}
C10_HOST_DEVICE
reference operator[](index_t idx) const {
return this->ptr[idx * this->stride];
}
// }
// Prefix/postfix increment/decrement {
C10_HOST_DEVICE
StridedRandomAccessor& operator++() {
this->ptr += this->stride;
return *this;
}
C10_HOST_DEVICE
StridedRandomAccessor operator++(int) {
StridedRandomAccessor copy(*this);
++*this;
return copy;
}
C10_HOST_DEVICE
StridedRandomAccessor& operator--() {
this->ptr -= this->stride;
return *this;
}
C10_HOST_DEVICE
StridedRandomAccessor operator--(int) {
StridedRandomAccessor copy(*this);
--*this;
return copy;
}
// }
// Arithmetic operations {
C10_HOST_DEVICE
StridedRandomAccessor& operator+=(index_t offset) {
this->ptr += offset * this->stride;
return *this;
}
C10_HOST_DEVICE
StridedRandomAccessor operator+(index_t offset) const {
return StridedRandomAccessor(this->ptr + offset * this->stride, this->stride);
}
C10_HOST_DEVICE
friend StridedRandomAccessor operator+(
index_t offset,
const StridedRandomAccessor& accessor
) {
return accessor + offset;
}
C10_HOST_DEVICE
StridedRandomAccessor& operator-=(index_t offset) {
this->ptr -= offset * this->stride;
return *this;
}
C10_HOST_DEVICE
StridedRandomAccessor operator-(index_t offset) const {
return StridedRandomAccessor(this->ptr - offset * this->stride, this->stride);
}
// Note that here we call BaseType::operator- version
C10_HOST_DEVICE
difference_type operator-(const BaseType& other) const {
return (static_cast<const BaseType&>(*this) - other);
}
// }
};
}} // namespace at::native