#pragma once
#include <c10/macros/Macros.h>
#include <c10/util/C++17.h>
#include <c10/util/reverse_iterator.h>
#include <cstring>
#include <limits>
#include <stdexcept>
#include <string>
#include <algorithm>
namespace c10 {
/**
* Reimplementation of std::string_view for C++11.
* Implemented following the interface definition in
* https://en.cppreference.com/w/cpp/string/basic_string_view
* See there for the API documentation.
*
* Difference: We don't have a Traits template parameter because
* std::char_traits isn't constexpr (before C++17) and we'd have to
* reimplement std::char_traits if we wanted to use it with our constexpr
* basic_string_view.
*/
template <class CharT>
class basic_string_view final {
public:
  // Character and pointer/reference types.
  using value_type = CharT;
  using pointer = CharT*;
  using const_pointer = const CharT*;
  using reference = CharT&;
  using const_reference = const CharT&;

  // The view is read-only, so the "mutable" iterator types are simply the
  // const ones.
  using const_iterator = const CharT*;
  using iterator = const_iterator;
  using const_reverse_iterator = c10::reverse_iterator<const_iterator>;
  using reverse_iterator = const_reverse_iterator;

  // Size and distance types.
  using size_type = std::size_t;
  using difference_type = std::ptrdiff_t;

  // Sentinel value. Search functions return it when nothing is found, and
  // substr() uses it as the default "until the end" count.
  static constexpr size_type npos = static_cast<size_type>(-1);
// Constructs an empty view: null data pointer, zero length.
constexpr basic_string_view() noexcept : begin_{nullptr}, size_{0} {}

// Views the `count` characters starting at `str`. Only the pointer is
// stored (nothing is copied), so the characters must outlive the view.
explicit constexpr basic_string_view(const_pointer str, size_type count)
    : begin_(str), size_(count) {}

// Views a null-terminated character string; its length comes from strlen_.
/* implicit */ constexpr basic_string_view(const_pointer str)
    : begin_(str), size_(strlen_(str)) {}

// Views the buffer of a std::basic_string. The view refers to str.data(),
// so it must not outlive (or be used across modifications of) the string.
/* implicit */ basic_string_view(const ::std::basic_string<CharT>& str)
    : begin_(str.data()), size_(str.size()) {}

constexpr basic_string_view(const basic_string_view&) noexcept = default;

// Shallow assignment: rebinds this view to the characters viewed by `rhs`.
constexpr basic_string_view& operator=(
    const basic_string_view& rhs) noexcept {
  size_ = rhs.size_;
  begin_ = rhs.begin_;
  return *this;
}

// Materializes the viewed characters into an owning std::basic_string.
explicit operator ::std::basic_string<CharT>() const {
  return ::std::basic_string<CharT>(begin_, size_);
}
// Forward iteration. The view is read-only, so begin()/end() and their
// c-prefixed counterparts yield the same const iterators.
constexpr const_iterator begin() const noexcept {
  return begin_;
}
constexpr const_iterator cbegin() const noexcept {
  return begin_;
}
constexpr const_iterator end() const noexcept {
  return begin_ + size_;
}
constexpr const_iterator cend() const noexcept {
  return begin_ + size_;
}

// Reverse iteration over the same range.
constexpr const_reverse_iterator rbegin() const noexcept {
  return const_reverse_iterator(this->cend());
}
constexpr const_reverse_iterator crbegin() const noexcept {
  return const_reverse_iterator(this->cend());
}
constexpr const_reverse_iterator rend() const noexcept {
  return const_reverse_iterator(this->cbegin());
}
constexpr const_reverse_iterator crend() const noexcept {
  return const_reverse_iterator(this->cbegin());
}

// Non-member begin()/end(), found by argument-dependent lookup for
// unqualified calls such as begin(sv).
friend constexpr const_iterator begin(basic_string_view sv) noexcept {
  return sv.cbegin();
}
friend constexpr const_iterator end(basic_string_view sv) noexcept {
  return sv.cend();
}
// Element access. Unlike std::basic_string_view::operator[], this is
// bounds-checked: it forwards to at(), which throws std::out_of_range for
// pos >= size() (except in CUDA device code, where at() does no check).
constexpr const_reference operator[](size_type pos) const {
  return this->at(pos);
}
/**
 * Bounds-checked element access; also the implementation behind operator[].
 *
 * Outside CUDA device code, throws std::out_of_range if pos >= size().
 * In CUDA device code (__CUDA_ARCH__ defined) no check is performed and
 * at_(pos) is returned directly. at_ is defined elsewhere in this class and
 * is presumably the unchecked accessor.
 *
 * The throw is written as the first operand of a comma expression inside a
 * conditional, which keeps the body a single return statement. Presumably
 * this is for C++11 constexpr rules; confirm. The trailing `at_(0)` is never
 * evaluated because the throw happens first. It only gives that branch the
 * type const_reference so that the conditional is well-formed.
 */
constexpr const_reference at(size_type pos) const {
#if !defined( \
__CUDA_ARCH__) // CUDA doesn't like std::out_of_range in device code
return C10_UNLIKELY(pos >= size_)
? (throw std::out_of_range(
"string_view::operator[] or string_view::at() out of range. Index: " +
c10::guts::to_string(pos) + ", size: " + c10::guts::to_string(size())),
at_(0))
: at_(pos);
#else
return at_(pos);
#endif
}
// First character. No emptiness check is made; calling this on an empty
// view is undefined behavior.
constexpr const_reference front() const {
  return begin_[0];
}
// Last character. No emptiness check is made; calling this on an empty
// view is undefined behavior.
constexpr const_reference back() const {
  return begin_[size_ - 1];
}
// Pointer to the first viewed character. A view may end anywhere inside a
// buffer, so the data is not necessarily null-terminated.
constexpr const_pointer data() const noexcept {
  return begin_;
}
// Number of viewed characters.
constexpr size_type size() const noexcept {
  return size_;
}
// Synonym for size().
constexpr size_type length() const noexcept {
  return size_;
}
// Upper bound on the length of any view: the largest difference_type.
constexpr size_type max_size() const noexcept {
  return static_cast<size_type>(std::numeric_limits<difference_type>::max());
}
// True when the view has no characters.
C10_NODISCARD constexpr bool empty() const noexcept {
  return size_ == 0;
}
// Shrinks the view by dropping its first `n` characters.
// Throws std::out_of_range if n > size(); the view is left unchanged.
CONSTEXPR_EXCEPT_GCC5 void remove_prefix(size_type n) {
  if (size() < n) {
    throw std::out_of_range(
        "basic_string_view::remove_prefix: out of range. PrefixLength: " +
        c10::guts::to_string(n) + ", size: " + c10::guts::to_string(size()));
  }
  size_ -= n;
  begin_ += n;
}
// Shrinks the view by dropping its last `n` characters; the start pointer
// is unchanged. Throws std::out_of_range if n > size().
CONSTEXPR_EXCEPT_GCC5 void remove_suffix(size_type n) {
  if (size() < n) {
    throw std::out_of_range(
        "basic_string_view::remove_suffix: out of range. SuffixLength: " +
        c10::guts::to_string(n) + ", size: " + c10::guts::to_string(size()));
  }
  size_ = size_ - n;
}
// Exchanges which characters *this and `sv` refer to. Only the pointer and
// length are exchanged; no character data is copied.
CONSTEXPR_EXCEPT_GCC5 void swap(basic_string_view& sv) noexcept {
  const basic_string_view previous = *this;
  *this = sv;
  sv = previous;
}
/**
 * Copies up to `count` characters, starting at index `pos`, into `dest`.
 *
 * No terminating null character is written. `dest` must have room for the
 * returned number of characters.
 *
 * @param dest  destination buffer.
 * @param count maximum number of characters to copy.
 * @param pos   index of the first character to copy; pos == size() is
 *              allowed and copies nothing.
 * @return the number of characters copied, i.e. min(count, size() - pos).
 * @throws std::out_of_range if pos > size().
 */
size_type copy(pointer dest, size_type count, size_type pos = 0) const {
  if (pos > size_) {
    throw std::out_of_range(
        "basic_string_view::copy: out of range. Index: " +
        c10::guts::to_string(pos) +
        ", size: " + c10::guts::to_string(size()));
  }
  // pos <= size_ here, so size_ - pos cannot underflow.
  const size_type copy_length = guts::min(count, size_ - pos);
  std::copy_n(begin() + pos, copy_length, dest);
  return copy_length;
}
/**
 * Returns a view of up to `count` characters starting at index `pos`.
 * The default arguments give the whole view.
 *
 * Outside CUDA device code, throws std::out_of_range if pos > size();
 * pos == size() is allowed. In CUDA device code no check is made.
 * The slicing itself is done by substr_, which is defined elsewhere in this
 * class. Presumably substr_ clamps `count` to size() - pos; confirm against
 * substr_.
 *
 * The throw sits inside a comma expression in a conditional, so the body
 * stays a single return statement, presumably for C++11 constexpr rules.
 * The trailing `substr_()` is never evaluated; it only supplies that
 * branch's type.
 */
constexpr basic_string_view substr(size_type pos = 0, size_type count = npos)
const {
#if !defined( \
__CUDA_ARCH__) // CUDA doesn't like std::out_of_range in device code
return (pos > size_)
? (throw std::out_of_range(
"basic_string_view::substr parameter out of bounds. Index: " +
c10::guts::to_string(pos) + ", size: " + c10::guts::to_string(size())),
substr_())
: substr_(pos, count);
#else
return substr_(pos, count);
#endif
}
/**
 * Lexicographically compares this view with `rhs`.
 *
 * Characters are compared pairwise with the built-in `<` and `>` on CharT;
 * the first mismatching pair decides. If one view is a prefix of the other,
 * the shorter one orders first.
 *
 * @return exactly -1 if *this orders before rhs, 1 if after, 0 if equal.
 *         (std::string_view only promises the sign.)
 *
 * Two implementations are selected by __cpp_constexpr:
 *  - C++14 constexpr: an iterative loop.
 *  - C++11 constexpr: a single recursive conditional expression that
 *    re-compares substr_(1) of both sides. Its recursion depth grows with
 *    the length of the common prefix.
 *
 * NOTE(review): for CharT = char on platforms where char is signed, the
 * built-in comparison orders bytes >= 0x80 before ASCII characters.
 * std::char_traits<char>, which std::string_view uses, compares as unsigned
 * char. Orderings of non-ASCII strings may therefore differ from
 * std::string_view and between platforms. Confirm whether callers rely on
 * this before changing it.
 */
constexpr int compare(basic_string_view rhs) const noexcept {
#if __cpp_constexpr >= 201304
// if we are in C++14, write it iteratively. This is faster.
for (size_t i = 0, end = guts::min(size(), rhs.size()); i < end; ++i) {
if (at_(i) < rhs.at_(i)) {
return -1;
} else if (at_(i) > rhs.at_(i)) {
return 1;
}
}
// All characters of the common prefix matched: the shorter view orders
// first, and equal lengths mean the views are equal.
if (size() < rhs.size()) {
return -1;
} else if (size() > rhs.size()) {
return 1;
}
return 0;
#else
// if we are in C++11, we need to do it recursively because of constexpr
// restrictions.
return (size() == 0 && rhs.size() == 0) ? 0
: (size() == 0) ? -1
: (rhs.size() == 0)
? 1
: (front() < rhs.front()) ? -1
: (front() > rhs.front())
? 1
: substr_(1).compare(rhs.substr_(1));
#endif
}
// Convenience overloads. Each one first narrows *this (and, where given,
// the other operand) with substr(), then defers to
// compare(basic_string_view). Any substr() call throws std::out_of_range
// if its starting position is past the end (outside CUDA device code).

// Compares substr(pos1, count1) with v.
constexpr int compare(size_type pos1, size_type count1, basic_string_view v)
    const {
  return this->substr(pos1, count1).compare(v);
}

// Compares substr(pos1, count1) with v.substr(pos2, count2).
constexpr int compare(
    size_type pos1,
    size_type count1,
    basic_string_view v,
    size_type pos2,
    size_type count2) const {
  return this->substr(pos1, count1).compare(v.substr(pos2, count2));
}

// Compares the whole view with the null-terminated string s.
constexpr int compare(const_pointer s) const {
  return this->compare(basic_string_view(s));
}

// Compares substr(pos1, count1) with the null-terminated string s.
constexpr int compare(size_type pos1, size_type count1, const_pointer s)
    const {
  return this->substr(pos1, count1).compare(basic_string_view(s));
}

// Compares substr(pos1, count1) with the first count2 characters at s.
constexpr int compare(
    size_type pos1,
    size_type count1,
    const_pointer s,
    size_type count2) const {
  return this->substr(pos1, count1).compare(basic_string_view(s, count2));
}
// Relational operators. Equality goes through equals_; every ordering
// operator is expressed directly in terms of compare().
friend constexpr bool operator==(
    basic_string_view lhs,
    basic_string_view rhs) noexcept {
  return lhs.equals_(rhs);
}
friend constexpr bool operator!=(
    basic_string_view lhs,
    basic_string_view rhs) noexcept {
  return !lhs.equals_(rhs);
}
friend constexpr bool operator<(
    basic_string_view lhs,
    basic_string_view rhs) noexcept {
  return lhs.compare(rhs) < 0;
}
friend constexpr bool operator>=(
    basic_string_view lhs,
    basic_string_view rhs) noexcept {
  return lhs.compare(rhs) >= 0;
}
friend constexpr bool operator>(
    basic_string_view lhs,
    basic_string_view rhs) noexcept {
  return rhs.compare(lhs) < 0;
}
friend constexpr bool operator<=(
    basic_string_view lhs,
    basic_string_view rhs) noexcept {
  return rhs.compare(lhs) >= 0;
}
// True if the view begins with the characters of `prefix`. A prefix longer
// than the view never matches.
constexpr bool starts_with(basic_string_view prefix) const noexcept {
  return prefix.size() <= size() &&
      prefix.equals_(substr_(0, prefix.size()));
}
// True if the view is non-empty and its first character equals `prefix`.
constexpr bool starts_with(CharT prefix) const noexcept {
  return size() != 0 && front() == prefix;
}
// Null-terminated-string overload. It is not noexcept because it measures
// `prefix` via the const_pointer constructor.
constexpr bool starts_with(const_pointer prefix) const {
  return this->starts_with(basic_string_view(prefix));
}
// True if the view ends with the characters of `suffix`. A suffix longer
// than the view never matches.
constexpr bool ends_with(basic_string_view suffix) const noexcept {
  return suffix.size() <= size() &&
      suffix.equals_(substr_(size() - suffix.size(), suffix.size()));
}
// True if the view is non-empty and its last character equals `suffix`.
constexpr bool ends_with(CharT suffix) const noexcept {
  return size() != 0 && back() == suffix;
}
// Null-terminated-string overload. It is not noexcept because it measures
// `suffix` via the const_pointer constructor.
constexpr bool ends_with(const_pointer suffix) const {
  return this->ends_with(basic_string_view(suffix));
}
constexpr size_type find(basic_string_view v, size_type pos = 0) const
noexcept {
#if __cpp_constexpr >= 201304
// if we are in C++14, write it iteratively. This is faster.
if (v.size() == 0) {
return pos <= size() ? pos : npos;
}
if (pos + v.size() <= size()) {
for (size_type cur = pos, end = size() - v.size(); cur <= end; ++cur) {
if (v.at_(0) == at_(cur) &&
v.substr_(1).equals_(substr_(cur + 1, v.size() - 1))) {
return cur;
}
}
}
return npos;
Loading ...