Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

neilisaac / torch   python

Repository URL to install this package:

Version: 1.8.0 

/ include / ATen / native / Repeat.h

#pragma once

#include <ATen/ATen.h>

namespace at { namespace native {

// Shared driver for repeat_interleave(repeats) on a 1-D Long tensor.
//
// Validates `repeats`, allocates the output, and delegates filling it to the
// `compute` callback supplied as a template argument.
//
// Template parameter:
//   compute(repeat_ptr, cumsum_ptr, result_ptr, size) -- receives the raw
//     int64_t data of a contiguous copy of `repeats`, its inclusive prefix
//     sum (cumsum[i] == repeats[0] + ... + repeats[i]), the uninitialized
//     output buffer of length cumsum[size - 1], and the number of entries in
//     `repeats`. It is expected to populate the output buffer.
//     NOTE(review): presumably it writes index i repeats[i] times and must be
//     device-aware for non-CPU tensors -- confirm against the callers.
//
// Params:
//   repeats -- 1-D tensor of dtype Long with non-negative entries.
// Returns:
//   A 1-D tensor with the same options as `repeats`, with sum(repeats)
//   elements, filled by `compute`. If `repeats` has no elements, returns an
//   empty tensor created with empty_like(repeats).
// Throws (via TORCH_CHECK) if `repeats` is not 1-D, is not Long, or contains
//   a negative value.
template <void compute(int64_t *, int64_t *, int64_t *, int64_t)>
static inline Tensor repeat_interleave_common(const Tensor &repeats) {
    TORCH_CHECK(repeats.dim() == 1, "repeat_interleave only accept 1D vector as repeat");
    TORCH_CHECK(repeats.scalar_type() == at::kLong, "repeats has to be Long tensor");
    TORCH_CHECK((repeats >= 0).all().item<uint8_t>(), "repeats can not be negative");
    const int64_t num_repeats = repeats.size(0);
    if (num_repeats == 0) {
        return at::empty_like(repeats, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
    }
    // compute() walks raw pointers, so the buffers it receives must be dense.
    // Take the prefix sum from the contiguous copy as well, so that both input
    // buffers derive from the same contiguous data instead of computing
    // cumsum on the (possibly strided) original tensor.
    Tensor repeats_ = repeats.contiguous();
    Tensor cumsum = repeats_.cumsum(0);
    // Inclusive prefix sum: the last entry is the total output length.
    int64_t total = cumsum[-1].item<int64_t>();
    Tensor result = at::empty({total}, repeats.options());
    int64_t *repeat_ptr = repeats_.data_ptr<int64_t>();
    int64_t *cumsum_ptr = cumsum.data_ptr<int64_t>();
    int64_t *result_ptr = result.data_ptr<int64_t>();
    compute(repeat_ptr, cumsum_ptr, result_ptr, num_repeats);
    return result;
}

}}