change shape types from int to size_t/int64_t for slice ops

TFbunny 2021-01-06 15:52:17 -05:00
parent e540ac46a3
commit 804a9b67b4
6 changed files with 152 additions and 176 deletions
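
The diff removes the int narrowing that used to happen when slice attributes were read from the kernel node: begin/end/size/strides now stay std::vector<int64_t>, and shape values are widened to size_t. As a standalone illustration of why that narrowing matters (this is not MindSpore code; the 3'000'000'000 dimension is made up), forcing a large int64_t value through static_cast<int> silently wraps:

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // A dimension that fits in int64_t but not in 32-bit int.
  std::vector<int64_t> shape = {3'000'000'000LL, 4};

  // Old pattern: narrow every value to int (what the removed std::transform lambdas did).
  std::vector<int> narrowed;
  for (const int64_t &v : shape) {
    narrowed.push_back(static_cast<int>(v));  // wraps / implementation-defined outside int range
  }

  // New pattern: keep the values at their native int64_t width.
  std::vector<int64_t> kept(shape.begin(), shape.end());

  std::cout << "narrowed[0] = " << narrowed[0] << "\n";  // not 3000000000
  std::cout << "kept[0]     = " << kept[0] << "\n";      // 3000000000
  return 0;
}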


@ -107,12 +107,9 @@ class SliceGpuFwdKernel : public GpuKernel {
MS_LOG(ERROR) << "Input dims is " << input_shape.size() << ", scalar is not supported.";
return false;
}
std::vector<int64_t> size_me = GetAttr<std::vector<int64_t>>(kernel_node, "size");
std::vector<int64_t> begin_me = GetAttr<std::vector<int64_t>>(kernel_node, "begin");
(void)std::transform(size_me.begin(), size_me.end(), std::back_inserter(size_),
[](const int64_t &value) { return static_cast<int>(value); });
(void)std::transform(begin_me.begin(), begin_me.end(), std::back_inserter(begin_),
[](const int64_t &value) { return static_cast<int>(value); });
size_ = GetAttr<std::vector<int64_t>>(kernel_node, "size");
begin_ = GetAttr<std::vector<int64_t>>(kernel_node, "begin");
for (size_t i = 0; i < input_shape.size(); i++) {
if (input_shape[i] <= 0 || size_[i] <= 0) {
MS_LOG(WARNING) << "Slice output is null.";
@ -121,8 +118,8 @@ class SliceGpuFwdKernel : public GpuKernel {
}
return true;
}
std::vector<int> begin_;
std::vector<int> size_;
std::vector<int64_t> begin_;
std::vector<int64_t> size_;
std::vector<size_t> input_shape_;
std::vector<size_t> input_size_list_;


@ -51,38 +51,27 @@ class SliceGradGpuKernel : public GpuKernel {
auto kernel_name = AnfAlgo::GetCNodeName(kernel_node);
if (kernel_name == "StridedSliceGrad") {
is_strided_slice_ = true;
std::vector<int> shapex;
std::vector<int64_t> shapex_me = GetAttr<std::vector<int64_t>>(kernel_node, "shapex");
(void)std::transform(shapex_me.begin(), shapex_me.end(), std::back_inserter(shapex),
[](const int64_t &value) { return static_cast<int>(value); });
std::vector<int64_t> shapex = GetAttr<std::vector<int64_t>>(kernel_node, "shapex");
for (auto x : shapex) {
input_shape_.push_back(IntToSize(x));
input_shape_.push_back(static_cast<size_t>(x));
}
for (auto i = input_shape_.size(); i < 4; i++) {
(void)input_shape_.insert(input_shape_.begin(), 1);
}
std::vector<int64_t> strides_me = GetAttr<std::vector<int64_t>>(kernel_node, "strides");
(void)std::transform(strides_me.begin(), strides_me.end(), std::back_inserter(strides_),
[](const int64_t &value) { return static_cast<int>(value); });
strides_ = GetAttr<std::vector<int64_t>>(kernel_node, "strides");
for (auto i = strides_.size(); i < 4; i++) {
(void)strides_.insert(strides_.begin(), 1);
}
std::vector<int64_t> size_me = GetAttr<std::vector<int64_t>>(kernel_node, "end");
(void)std::transform(size_me.begin(), size_me.end(), std::back_inserter(size_),
[](const int64_t &value) { return static_cast<int>(value); });
size_ = GetAttr<std::vector<int64_t>>(kernel_node, "end");
} else {
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
ShapeNdTo4d(input_shape, &input_shape_);
std::vector<int64_t> size_me = GetAttr<std::vector<int64_t>>(kernel_node, "size");
(void)std::transform(size_me.begin(), size_me.end(), std::back_inserter(size_),
[](const int64_t &value) { return static_cast<int>(value); });
size_ = GetAttr<std::vector<int64_t>>(kernel_node, "size");
}
auto dy_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
ShapeNdTo4d(dy_shape, &dy_shape_);
std::vector<int64_t> begin_me = GetAttr<std::vector<int64_t>>(kernel_node, "begin");
(void)std::transform(begin_me.begin(), begin_me.end(), std::back_inserter(begin_),
[](const int64_t &value) { return static_cast<int>(value); });
begin_ = GetAttr<std::vector<int64_t>>(kernel_node, "begin");
DealParam();
input_size_ = input_shape_[0] * input_shape_[1] * input_shape_[2] * input_shape_[3] * sizeof(T);
@ -137,9 +126,9 @@ class SliceGradGpuKernel : public GpuKernel {
}
}
}
std::vector<int> begin_;
std::vector<int> size_;
std::vector<int> strides_;
std::vector<int64_t> begin_;
std::vector<int64_t> size_;
std::vector<int64_t> strides_;
std::vector<size_t> input_shape_;
std::vector<size_t> dy_shape_;
std::vector<size_t> input_size_list_;


@ -26,7 +26,7 @@
namespace mindspore {
namespace kernel {
constexpr int MAX_DIMS = 8;
constexpr size_t MAX_DIMS = 8;
template <typename T>
class StridedSliceGpuKernel : public GpuKernel {
public:
@ -48,6 +48,7 @@ class StridedSliceGpuKernel : public GpuKernel {
reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
bool Init(const CNodePtr &kernel_node) override {
input_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
if (input_shape_.size() > MAX_DIMS) {
@ -82,27 +83,21 @@ class StridedSliceGpuKernel : public GpuKernel {
private:
void FillEmptyDims(const CNodePtr &kernel_node) {
std::vector<int64_t> begin_me = GetAttr<std::vector<int64_t>>(kernel_node, "begin");
std::vector<int64_t> end_me = GetAttr<std::vector<int64_t>>(kernel_node, "end");
std::vector<int64_t> strides_me = GetAttr<std::vector<int64_t>>(kernel_node, "strides");
(void)std::transform(begin_me.begin(), begin_me.end(), std::back_inserter(begin_),
[](const int64_t &value) { return static_cast<int>(value); });
(void)std::transform(end_me.begin(), end_me.end(), std::back_inserter(end_),
[](const int64_t &value) { return static_cast<int>(value); });
(void)std::transform(strides_me.begin(), strides_me.end(), std::back_inserter(strides_),
[](const int64_t &value) { return static_cast<int>(value); });
begin_ = GetAttr<std::vector<int64_t>>(kernel_node, "begin");
end_ = GetAttr<std::vector<int64_t>>(kernel_node, "end");
strides_ = GetAttr<std::vector<int64_t>>(kernel_node, "strides");
for (size_t i = 0; i < MAX_DIMS; i++) {
if (i < begin_.size()) {
int dim = SizeToInt(input_shape_[i]);
begin_[i] = std::min(begin_[i] < 0 ? std::max(begin_[i] + dim, 0) : begin_[i], dim - 1);
int64_t dim = input_shape_[i];
begin_[i] = std::min(begin_[i] < 0 ? std::max(begin_[i] + dim, static_cast<int64_t>(0)) : begin_[i], dim - 1);
} else {
begin_.push_back(0);
}
if (i < end_.size()) {
int dim = SizeToInt(input_shape_[i]);
end_[i] = std::max(end_[i] < 0 ? end_[i] + dim : std::min(end_[i], dim), -1);
int64_t dim = input_shape_[i];
end_[i] = std::max(end_[i] < 0 ? end_[i] + dim : std::min(end_[i], dim), static_cast<int64_t>(-1));
} else {
end_.push_back(i < input_shape_.size() ? input_shape_[i] : 1);
}
@ -164,11 +159,11 @@ class StridedSliceGpuKernel : public GpuKernel {
}
}
std::vector<bool> Dec2Bin(const int &mask) {
std::vector<bool> Dec2Bin(const int64_t &mask) {
auto mask_str = std::bitset<MAX_DIMS>(mask).to_string();
int dim_idx = 0;
int64_t dim_idx = 0;
std::vector<bool> result = {false, false, false, false};
for (int i = mask_str.size() - 1; i >= 0; i--) {
for (int64_t i = mask_str.size() - 1; i >= 0; i--) {
if (mask_str[i] == '1') {
result[dim_idx] = true;
}
@ -178,7 +173,7 @@ class StridedSliceGpuKernel : public GpuKernel {
}
void FillOutputDim() {
for (int i = 0; i < MAX_DIMS; i++) {
for (size_t i = 0; i < MAX_DIMS; i++) {
if (begin_[i] <= end_[i] && strides_[i] > 0) {
output_shape_.push_back((end_[i] - 1 - begin_[i]) / strides_[i] + 1);
} else if (begin_[i] > end_[i] && strides_[i] < 0) {
@ -190,7 +185,7 @@ class StridedSliceGpuKernel : public GpuKernel {
}
bool IsNullOutput() {
for (int i = 0; i < MAX_DIMS; i++) {
for (size_t i = 0; i < MAX_DIMS; i++) {
if (begin_[i] >= end_[i] && strides_[i] > 0) {
return true;
}
@ -201,12 +196,12 @@ class StridedSliceGpuKernel : public GpuKernel {
return false;
}
std::vector<int> begin_;
std::vector<int> end_;
std::vector<int> strides_;
std::vector<int64_t> begin_;
std::vector<int64_t> end_;
std::vector<int64_t> strides_;
std::vector<size_t> input_shape_;
std::vector<size_t> output_shape_;
int null_output_;
bool null_output_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
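
A standalone sketch of the begin/end normalization that FillEmptyDims performs above, now done entirely in int64_t. The values of dim, begin, and end are examples, not kernel attributes:

#include <algorithm>
#include <cstdint>
#include <iostream>

int main() {
  const int64_t dim = 5;  // size of the dimension being sliced

  // begin: map a negative index to begin + dim, then clamp into [0, dim - 1].
  int64_t begin = -1;
  begin = std::min(begin < 0 ? std::max(begin + dim, static_cast<int64_t>(0)) : begin, dim - 1);

  // end: map a negative index to end + dim, clamp positives to dim, floor at -1.
  int64_t end = -6;
  end = std::max(end < 0 ? end + dim : std::min(end, dim), static_cast<int64_t>(-1));

  std::cout << "begin = " << begin << "\n";  // 4: the last element
  std::cout << "end   = " << end << "\n";    // -1: one step past index 0 when walking backwards
  return 0;
}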


@ -26,7 +26,7 @@
namespace mindspore {
namespace kernel {
constexpr int MAX_DIMS = 7;
constexpr size_t MAX_DIMS = 7;
template <typename T>
class StridedSliceGradGpuKernel : public GpuKernel {
public:
@ -50,12 +50,9 @@ class StridedSliceGradGpuKernel : public GpuKernel {
return true;
}
bool Init(const CNodePtr &kernel_node) override {
std::vector<int> shapex;
std::vector<int64_t> shapex_me = GetAttr<std::vector<int64_t>>(kernel_node, "shapex");
(void)std::transform(shapex_me.begin(), shapex_me.end(), std::back_inserter(shapex),
[](const int64_t &value) { return static_cast<int>(value); });
std::vector<int64_t> shapex = GetAttr<std::vector<int64_t>>(kernel_node, "shapex");
for (auto x : shapex) {
input_shape_.push_back(IntToSize(x));
input_shape_.push_back(static_cast<size_t>(x));
}
if (input_shape_.size() > MAX_DIMS) {
MS_LOG(ERROR) << "StridedSliceGrad support support dims less than " << input_shape_.size();
@ -87,27 +84,21 @@ class StridedSliceGradGpuKernel : public GpuKernel {
private:
void FillEmptyDims(const CNodePtr &kernel_node) {
std::vector<int64_t> begin_me = GetAttr<std::vector<int64_t>>(kernel_node, "begin");
std::vector<int64_t> end_me = GetAttr<std::vector<int64_t>>(kernel_node, "end");
std::vector<int64_t> strides_me = GetAttr<std::vector<int64_t>>(kernel_node, "strides");
(void)std::transform(begin_me.begin(), begin_me.end(), std::back_inserter(begin_),
[](const int64_t &value) { return static_cast<int>(value); });
(void)std::transform(end_me.begin(), end_me.end(), std::back_inserter(end_),
[](const int64_t &value) { return static_cast<int>(value); });
(void)std::transform(strides_me.begin(), strides_me.end(), std::back_inserter(strides_),
[](const int64_t &value) { return static_cast<int>(value); });
begin_ = GetAttr<std::vector<int64_t>>(kernel_node, "begin");
end_ = GetAttr<std::vector<int64_t>>(kernel_node, "end");
strides_ = GetAttr<std::vector<int64_t>>(kernel_node, "strides");
for (size_t i = 0; i < MAX_DIMS; i++) {
if (i < begin_.size()) {
int dim = SizeToInt(input_shape_[i]);
begin_[i] = std::min(begin_[i] < 0 ? std::max(begin_[i] + dim, 0) : begin_[i], dim - 1);
int64_t dim = input_shape_[i];
begin_[i] = std::min(begin_[i] < 0 ? std::max(begin_[i] + dim, static_cast<int64_t>(0)) : begin_[i], dim - 1);
} else {
begin_.push_back(0);
}
if (i < end_.size()) {
int dim = SizeToInt(input_shape_[i]);
end_[i] = std::max(end_[i] < 0 ? end_[i] + dim : std::min(end_[i], dim), -1);
int64_t dim = input_shape_[i];
end_[i] = std::max(end_[i] < 0 ? end_[i] + dim : std::min(end_[i], dim), static_cast<int64_t>(-1));
} else {
end_.push_back(i < input_shape_.size() ? input_shape_[i] : 1);
}
@ -169,11 +160,11 @@ class StridedSliceGradGpuKernel : public GpuKernel {
}
}
std::vector<bool> Dec2Bin(const int &mask) {
std::vector<bool> Dec2Bin(const int64_t &mask) {
auto mask_str = std::bitset<MAX_DIMS>(mask).to_string();
int dim_idx = 0;
int64_t dim_idx = 0;
std::vector<bool> result = {false, false, false, false};
for (int i = mask_str.size() - 1; i >= 0; i--) {
for (int64_t i = mask_str.size() - 1; i >= 0; i--) {
if (mask_str[i] == '1') {
result[dim_idx] = true;
}
@ -183,7 +174,7 @@ class StridedSliceGradGpuKernel : public GpuKernel {
}
void FillOutputDim() {
for (int i = 0; i < MAX_DIMS; i++) {
for (size_t i = 0; i < MAX_DIMS; i++) {
if (begin_[i] <= end_[i] && strides_[i] > 0) {
output_shape_.push_back((end_[i] - 1 - begin_[i]) / strides_[i] + 1);
} else if (begin_[i] > end_[i] && strides_[i] < 0) {
@ -195,7 +186,7 @@ class StridedSliceGradGpuKernel : public GpuKernel {
}
bool IsNullOutput() {
for (int i = 0; i < MAX_DIMS; i++) {
for (size_t i = 0; i < MAX_DIMS; i++) {
if (begin_[i] >= end_[i] && strides_[i] > 0) {
return true;
}
@ -206,12 +197,12 @@ class StridedSliceGradGpuKernel : public GpuKernel {
return false;
}
std::vector<int> begin_;
std::vector<int> end_;
std::vector<int> strides_;
std::vector<int64_t> begin_;
std::vector<int64_t> end_;
std::vector<int64_t> strides_;
std::vector<size_t> input_shape_;
std::vector<size_t> output_shape_;
int null_output_;
bool null_output_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
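
A self-contained sketch of the bitset decoding that Dec2Bin performs for the mask attributes, which are now received as int64_t. kMaxDims and the example mask are illustrative; the kernels use MAX_DIMS of 7 or 8 and a fixed four-element result vector:

#include <bitset>
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

constexpr size_t kMaxDims = 4;  // illustrative only

// Decode a StridedSlice-style mask: bit i set means the flag applies to dimension i.
std::vector<bool> Dec2Bin(const int64_t &mask) {
  std::vector<bool> result(kMaxDims, false);
  const std::string mask_str = std::bitset<kMaxDims>(mask).to_string();
  size_t dim_idx = 0;
  // Walk from the least significant bit (the end of the string) outward.
  for (auto it = mask_str.rbegin(); it != mask_str.rend() && dim_idx < kMaxDims; ++it, ++dim_idx) {
    result[dim_idx] = (*it == '1');
  }
  return result;
}

int main() {
  // mask 0b0101: the flag is set for dimensions 0 and 2.
  for (bool b : Dec2Bin(0b0101)) std::cout << b << ' ';
  std::cout << '\n';  // prints: 1 0 1 0
  return 0;
}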


@ -21,10 +21,9 @@
#include "backend/kernel_compiler/gpu/cuda_impl/slice_impl.cuh"
template <typename T>
__global__ void Slice4D(const size_t s1, const size_t s2, const size_t s3, const size_t s4,
const size_t l1, const size_t l2, const size_t l3, const size_t l4,
const size_t d1, const size_t d2, const size_t d3, const size_t d4,
const T *input, T *output) {
__global__ void Slice4D(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t l1,
const size_t l2, const size_t l3, const size_t l4, const size_t d1, const size_t d2,
const size_t d3, const size_t d4, const T *input, T *output) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (l1 * l2 * l3 * l4); pos += blockDim.x * gridDim.x) {
size_t i = pos / (l2 * l3 * l4) % l1;
size_t j = pos / (l3 * l4) % l2;
@ -36,7 +35,7 @@ __global__ void Slice4D(const size_t s1, const size_t s2, const size_t s3, const
}
}
template <typename T>
__global__ void SliceGrad(const T *dy, int p, int start, int length, T *output) {
__global__ void SliceGrad(const T *dy, int64_t p, int64_t start, int64_t length, T *output) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (length); pos += blockDim.x * gridDim.x) {
output[start + pos] = dy[p + pos];
}
@ -57,24 +56,24 @@ void FillDeviceArray(const size_t input_size, T *addr, const float value, cudaSt
return;
}
template <typename T>
void Slice4DKernel(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t l1,
const size_t l2, const size_t l3, const size_t l4, const size_t d1, const size_t d2,
const size_t d3, const size_t d4, const T *input, T *output, cudaStream_t stream) {
void Slice4DKernel(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t l1, const size_t l2,
const size_t l3, const size_t l4, const size_t d1, const size_t d2, const size_t d3, const size_t d4,
const T *input, T *output, cudaStream_t stream) {
Slice4D<<<GET_BLOCKS(l1 * l2 * l3 * l4), GET_THREADS, 0, stream>>>(s1, s2, s3, s4, l1, l2, l3, l4, d1, d2, d3, d4,
input, output);
}
template <typename T>
void CalSliceGrad(const size_t input_size, const T *dy, const std::vector<size_t> in_shape,
const std::vector<int> begin, const std::vector<int> size, T *output,
const std::vector<int64_t> begin, const std::vector<int64_t> size, T *output,
cudaStream_t cuda_stream) {
size_t block = in_shape[1] * in_shape[2] * in_shape[3];
size_t map = in_shape[2] * in_shape[3];
size_t w = in_shape[3];
int length = size[3];
int p = 0;
for (int i = begin[0]; i < size[0] + begin[0]; i++) {
for (int j = begin[1]; j < size[1] + begin[1]; j++) {
for (int k = begin[2]; k < size[2] + begin[2]; k++) {
int64_t length = size[3];
int64_t p = 0;
for (int64_t i = begin[0]; i < size[0] + begin[0]; i++) {
for (int64_t j = begin[1]; j < size[1] + begin[1]; j++) {
for (int64_t k = begin[2]; k < size[2] + begin[2]; k++) {
SliceGrad<<<GET_BLOCKS(input_size), GET_THREADS, 0, cuda_stream>>>(
dy, p, i * block + j * map + k * w + begin[3], length, output);
p = p + size[3];
@ -89,9 +88,9 @@ __global__ void StridedSliceKernel(const size_t b0, const size_t b1, const size_
const size_t s3, const size_t s4, const size_t s5, const size_t s6, const size_t i0,
const size_t i1, const size_t i2, const size_t i3, const size_t i4, const size_t i5,
const size_t i6, const size_t o0, const size_t o1, const size_t o2, const size_t o3,
const size_t o4, const size_t o5, const size_t o6,
const T *input_addr, T *output_addr) {
int output_num = o0 * o1 * o2 * o3 * o4 * o5 * o6;
const size_t o4, const size_t o5, const size_t o6, const T *input_addr,
T *output_addr) {
size_t output_num = o0 * o1 * o2 * o3 * o4 * o5 * o6;
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < output_num; pos += blockDim.x * gridDim.x) {
size_t i = pos / (o1 * o2 * o3 * o4 * o5 * o6) % o0;
size_t j = pos / (o2 * o3 * o4 * o5 * o6) % o1;
@ -101,25 +100,24 @@ __global__ void StridedSliceKernel(const size_t b0, const size_t b1, const size_
size_t n = pos / (o6) % o5;
size_t o = pos % o6;
size_t input_idx = (i * s0 + b0) * i1 * i2 * i3 * i4 * i5 * i6 + (j * s1 + b1) * i2 * i3 * i4 * i5 * i6 \
+ (k * s2 + b2) * i3 * i4 * i5 * i6 + (l * s3 + b3) * i4 * i5 * i6 + (m * s4 + b4) * i5 * i6 \
+ (n * s5 + b5) * i6 + (o * s6 + b6);
size_t input_idx = (i * s0 + b0) * i1 * i2 * i3 * i4 * i5 * i6 + (j * s1 + b1) * i2 * i3 * i4 * i5 * i6 +
(k * s2 + b2) * i3 * i4 * i5 * i6 + (l * s3 + b3) * i4 * i5 * i6 + (m * s4 + b4) * i5 * i6 +
(n * s5 + b5) * i6 + (o * s6 + b6);
output_addr[pos] = input_addr[input_idx];
}
}
template <typename T>
void StridedSlice(const std::vector<size_t> &input_shape, const std::vector<int> &begin,
const std::vector<int> &strides, const std::vector<size_t> &output_shape, const T *input, T *output,
cudaStream_t cuda_stream) {
size_t size = output_shape[0] * output_shape[1] * output_shape[2] * output_shape[3] \
* output_shape[4] * output_shape[5] * output_shape[6];
void StridedSlice(const std::vector<size_t> &input_shape, const std::vector<int64_t> &begin,
const std::vector<int64_t> &strides, const std::vector<size_t> &output_shape, const T *input,
T *output, cudaStream_t cuda_stream) {
size_t size = output_shape[0] * output_shape[1] * output_shape[2] * output_shape[3] * output_shape[4] *
output_shape[5] * output_shape[6];
StridedSliceKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(
begin[0], begin[1], begin[2], begin[3], begin[4], begin[5], begin[6],
strides[0], strides[1], strides[2], strides[3], strides[4], strides[5], strides[6],
input_shape[0], input_shape[1], input_shape[2], input_shape[3], input_shape[4], input_shape[5], input_shape[6],
output_shape[0], output_shape[1], output_shape[2], output_shape[3], output_shape[4], output_shape[5],
output_shape[6], input, output);
begin[0], begin[1], begin[2], begin[3], begin[4], begin[5], begin[6], strides[0], strides[1], strides[2],
strides[3], strides[4], strides[5], strides[6], input_shape[0], input_shape[1], input_shape[2], input_shape[3],
input_shape[4], input_shape[5], input_shape[6], output_shape[0], output_shape[1], output_shape[2], output_shape[3],
output_shape[4], output_shape[5], output_shape[6], input, output);
}
template <typename T>
@ -129,9 +127,9 @@ __global__ void StridedSliceGradKernel(const size_t b0, const size_t b1, const s
const size_t s5, const size_t s6, const size_t i0, const size_t i1,
const size_t i2, const size_t i3, const size_t i4, const size_t i5,
const size_t i6, const size_t o0, const size_t o1, const size_t o2,
const size_t o3, const size_t o4, const size_t o5, const size_t o6,
const T *dy, T *dx) {
int output_num = o0 * o1 * o2 * o3 * o4 * o5 * o6;
const size_t o3, const size_t o4, const size_t o5, const size_t o6, const T *dy,
T *dx) {
size_t output_num = o0 * o1 * o2 * o3 * o4 * o5 * o6;
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < output_num; pos += blockDim.x * gridDim.x) {
size_t i = pos / (o1 * o2 * o3 * o4 * o5 * o6) % o0;
size_t j = pos / (o2 * o3 * o4 * o5 * o6) % o1;
@ -141,24 +139,23 @@ __global__ void StridedSliceGradKernel(const size_t b0, const size_t b1, const s
size_t n = pos / (o6) % o5;
size_t o = pos % o6;
size_t input_idx = (i * s0 + b0) * i1 * i2 * i3 * i4 * i5 * i6 + (j * s1 + b1) * i2 * i3 * i4 * i5 * i6 \
+ (k * s2 + b2) * i3 * i4 * i5 * i6 + (l * s3 + b3) * i4 * i5 * i6 + (m * s4 + b4) * i5 * i6 \
+ (n * s5 + b5) * i6 + (o * s6 + b6);
size_t input_idx = (i * s0 + b0) * i1 * i2 * i3 * i4 * i5 * i6 + (j * s1 + b1) * i2 * i3 * i4 * i5 * i6 +
(k * s2 + b2) * i3 * i4 * i5 * i6 + (l * s3 + b3) * i4 * i5 * i6 + (m * s4 + b4) * i5 * i6 +
(n * s5 + b5) * i6 + (o * s6 + b6);
dx[input_idx] = dy[pos];
}
return;
}
template <typename T>
void StridedSliceGrad(const std::vector<size_t> &dy_shape, const std::vector<int> &begin,
const std::vector<int> &strides, const std::vector<size_t> &dx_shape,
const T *dy, T *dx, cudaStream_t cuda_stream) {
void StridedSliceGrad(const std::vector<size_t> &dy_shape, const std::vector<int64_t> &begin,
const std::vector<int64_t> &strides, const std::vector<size_t> &dx_shape, const T *dy, T *dx,
cudaStream_t cuda_stream) {
size_t size = dy_shape[0] * dy_shape[1] * dy_shape[2] * dy_shape[3] * dy_shape[4] * dy_shape[5] * dy_shape[6];
StridedSliceGradKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(
begin[0], begin[1], begin[2], begin[3], begin[4], begin[5], begin[6],
strides[0], strides[1], strides[2], strides[3], strides[4], strides[5], strides[6],
dx_shape[0], dx_shape[1], dx_shape[2], dx_shape[3], dx_shape[4], dx_shape[5], dx_shape[6],
dy_shape[0], dy_shape[1], dy_shape[2], dy_shape[3], dy_shape[4], dy_shape[5], dy_shape[6],
begin[0], begin[1], begin[2], begin[3], begin[4], begin[5], begin[6], strides[0], strides[1], strides[2],
strides[3], strides[4], strides[5], strides[6], dx_shape[0], dx_shape[1], dx_shape[2], dx_shape[3], dx_shape[4],
dx_shape[5], dx_shape[6], dy_shape[0], dy_shape[1], dy_shape[2], dy_shape[3], dy_shape[4], dy_shape[5], dy_shape[6],
dy, dx);
}
@ -167,7 +164,7 @@ template void Slice4DKernel(const size_t s1, const size_t s2, const size_t s3, c
const size_t l2, const size_t l3, const size_t l4, const size_t d1, const size_t d2,
const size_t d3, const size_t d4, const float *input, float *output, cudaStream_t stream);
template void CalSliceGrad<float>(const size_t input_size, const float *dy, const std::vector<size_t> in_shape,
const std::vector<int> begin, const std::vector<int> size, float *output,
const std::vector<int64_t> begin, const std::vector<int64_t> size, float *output,
cudaStream_t cuda_stream);
template void FillDeviceArray<half>(const size_t input_size, half *addr, const float value, cudaStream_t cuda_stream);
@ -175,7 +172,7 @@ template void Slice4DKernel(const size_t s1, const size_t s2, const size_t s3, c
const size_t l2, const size_t l3, const size_t l4, const size_t d1, const size_t d2,
const size_t d3, const size_t d4, const half *input, half *output, cudaStream_t stream);
template void CalSliceGrad<half>(const size_t input_size, const half *dy, const std::vector<size_t> in_shape,
const std::vector<int> begin, const std::vector<int> size, half *output,
const std::vector<int64_t> begin, const std::vector<int64_t> size, half *output,
cudaStream_t cuda_stream);
template void FillDeviceArray<int>(const size_t input_size, int *addr, const float value, cudaStream_t cuda_stream);
@ -183,16 +180,20 @@ template void Slice4DKernel(const size_t s1, const size_t s2, const size_t s3, c
const size_t l2, const size_t l3, const size_t l4, const size_t d1, const size_t d2,
const size_t d3, const size_t d4, const int *input, int *output, cudaStream_t stream);
template void CalSliceGrad<int>(const size_t input_size, const int *dy, const std::vector<size_t> in_shape,
const std::vector<int> begin, const std::vector<int> size, int *output,
const std::vector<int64_t> begin, const std::vector<int64_t> size, int *output,
cudaStream_t cuda_stream);
template void FillDeviceArray<short>(const size_t input_size, short *addr, const float value, cudaStream_t cuda_stream); // NOLINT
template void FillDeviceArray<short>(const size_t input_size, short *addr, const float value, // NOLINT
cudaStream_t cuda_stream);
template void Slice4DKernel(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t l1,
const size_t l2, const size_t l3, const size_t l4, const size_t d1, const size_t d2,
const size_t d3, const size_t d4, const short *input, short *output, cudaStream_t stream); // NOLINT
template void CalSliceGrad<short>(const size_t input_size, const short *dy, const std::vector<size_t> in_shape, // NOLINT
const std::vector<int> begin, const std::vector<int> size, short *output, // NOLINT
cudaStream_t cuda_stream);
const size_t d3, const size_t d4, const short *input, short *output, // NOLINT
cudaStream_t stream);
template void CalSliceGrad<short>(const size_t input_size, const short *dy, // NOLINT
const std::vector<size_t> in_shape, const std::vector<int64_t> begin,
const std::vector<int64_t> size,
short *output, // NOLINT
cudaStream_t cuda_stream);
template void FillDeviceArray<unsigned char>(const size_t input_size, unsigned char *addr, const float value,
cudaStream_t cuda_stream);
@ -201,8 +202,9 @@ template void Slice4DKernel(const size_t s1, const size_t s2, const size_t s3, c
const size_t d3, const size_t d4, const unsigned char *input, unsigned char *output,
cudaStream_t stream);
template void CalSliceGrad<unsigned char>(const size_t input_size, const unsigned char *dy,
const std::vector<size_t> in_shape, const std::vector<int> begin,
const std::vector<int> size, unsigned char *output, cudaStream_t cuda_stream);
const std::vector<size_t> in_shape, const std::vector<int64_t> begin,
const std::vector<int64_t> size, unsigned char *output,
cudaStream_t cuda_stream);
template void FillDeviceArray<int64_t>(const size_t input_size, int64_t *addr, const float value,
cudaStream_t cuda_stream);
@ -211,57 +213,58 @@ template void Slice4DKernel(const size_t s1, const size_t s2, const size_t s3, c
const size_t d3, const size_t d4, const int64_t *input, int64_t *output,
cudaStream_t stream);
template void CalSliceGrad<int64_t>(const size_t input_size, const int64_t *dy, const std::vector<size_t> in_shape,
const std::vector<int> begin, const std::vector<int> size, int64_t *output,
cudaStream_t cuda_stream);
const std::vector<int64_t> begin, const std::vector<int64_t> size, int64_t *output,
cudaStream_t cuda_stream);
template void FillDeviceArray<bool>(const size_t input_size, bool *addr, const float value, cudaStream_t cuda_stream);
template void Slice4DKernel(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t l1,
const size_t l2, const size_t l3, const size_t l4, const size_t d1, const size_t d2,
const size_t d3, const size_t d4, const bool *input, bool *output, cudaStream_t stream);
template void CalSliceGrad<bool>(const size_t input_size, const bool *dy, const std::vector<size_t> in_shape,
const std::vector<int> begin, const std::vector<int> size, bool *output,
cudaStream_t cuda_stream);
const std::vector<int64_t> begin, const std::vector<int64_t> size, bool *output,
cudaStream_t cuda_stream);
template void StridedSlice(const std::vector<size_t> &input_shape, const std::vector<int> &begin,
const std::vector<int> &strides, const std::vector<size_t> &output_shape, const float *input,
float *output, cudaStream_t cuda_stream);
template void StridedSlice(const std::vector<size_t> &input_shape, const std::vector<int> &begin,
const std::vector<int> &strides, const std::vector<size_t> &output_shape, const half *input,
half *output, cudaStream_t cuda_stream);
template void StridedSlice(const std::vector<size_t> &input_shape, const std::vector<int> &begin,
const std::vector<int> &strides, const std::vector<size_t> &output_shape, const int *input,
int *output, cudaStream_t cuda_stream);
template void StridedSlice(const std::vector<size_t> &input_shape, const std::vector<int> &begin,
const std::vector<int> &strides, const std::vector<size_t> &output_shape,
template void StridedSlice(const std::vector<size_t> &input_shape, const std::vector<int64_t> &begin,
const std::vector<int64_t> &strides, const std::vector<size_t> &output_shape,
const float *input, float *output, cudaStream_t cuda_stream);
template void StridedSlice(const std::vector<size_t> &input_shape, const std::vector<int64_t> &begin,
const std::vector<int64_t> &strides, const std::vector<size_t> &output_shape,
const half *input, half *output, cudaStream_t cuda_stream);
template void StridedSlice(const std::vector<size_t> &input_shape, const std::vector<int64_t> &begin,
const std::vector<int64_t> &strides, const std::vector<size_t> &output_shape,
const int *input, int *output, cudaStream_t cuda_stream);
template void StridedSlice(const std::vector<size_t> &input_shape, const std::vector<int64_t> &begin,
const std::vector<int64_t> &strides, const std::vector<size_t> &output_shape,
const short *input, short *output, cudaStream_t cuda_stream); // NOLINT
template void StridedSlice(const std::vector<size_t> &input_shape, const std::vector<int> &begin,
const std::vector<int> &strides, const std::vector<size_t> &output_shape,
template void StridedSlice(const std::vector<size_t> &input_shape, const std::vector<int64_t> &begin,
const std::vector<int64_t> &strides, const std::vector<size_t> &output_shape,
const unsigned char *input, unsigned char *output, cudaStream_t cuda_stream);
template void StridedSlice(const std::vector<size_t> &input_shape, const std::vector<int> &begin,
const std::vector<int> &strides, const std::vector<size_t> &output_shape, const bool *input,
bool *output, cudaStream_t cuda_stream);
template void StridedSlice(const std::vector<size_t> &input_shape, const std::vector<int> &begin,
const std::vector<int> &strides, const std::vector<size_t> &output_shape,
template void StridedSlice(const std::vector<size_t> &input_shape, const std::vector<int64_t> &begin,
const std::vector<int64_t> &strides, const std::vector<size_t> &output_shape,
const bool *input, bool *output, cudaStream_t cuda_stream);
template void StridedSlice(const std::vector<size_t> &input_shape, const std::vector<int64_t> &begin,
const std::vector<int64_t> &strides, const std::vector<size_t> &output_shape,
const int64_t *input, int64_t *output, cudaStream_t cuda_stream);
template void StridedSliceGrad(const std::vector<size_t> &dy_shape, const std::vector<int> &begin,
const std::vector<int> &strides, const std::vector<size_t> &dx_shape, const float *dy,
float *dx, cudaStream_t cuda_stream);
template void StridedSliceGrad(const std::vector<size_t> &dy_shape, const std::vector<int> &begin,
const std::vector<int> &strides, const std::vector<size_t> &dx_shape, const half *dy,
template void StridedSliceGrad(const std::vector<size_t> &dy_shape, const std::vector<int64_t> &begin,
const std::vector<int64_t> &strides, const std::vector<size_t> &dx_shape,
const float *dy, float *dx, cudaStream_t cuda_stream);
template void StridedSliceGrad(const std::vector<size_t> &dy_shape, const std::vector<int64_t> &begin,
const std::vector<int64_t> &strides, const std::vector<size_t> &dx_shape, const half *dy,
half *dx, cudaStream_t cuda_stream);
template void StridedSliceGrad(const std::vector<size_t> &dy_shape, const std::vector<int> &begin,
const std::vector<int> &strides, const std::vector<size_t> &dx_shape, const int *dy,
template void StridedSliceGrad(const std::vector<size_t> &dy_shape, const std::vector<int64_t> &begin,
const std::vector<int64_t> &strides, const std::vector<size_t> &dx_shape, const int *dy,
int *dx, cudaStream_t cuda_stream);
template void StridedSliceGrad(const std::vector<size_t> &dy_shape, const std::vector<int> &begin,
const std::vector<int> &strides, const std::vector<size_t> &dx_shape, const short *dy, // NOLINT
template void StridedSliceGrad(const std::vector<size_t> &dy_shape, const std::vector<int64_t> &begin,
const std::vector<int64_t> &strides, const std::vector<size_t> &dx_shape,
const short *dy, // NOLINT
short *dx, cudaStream_t cuda_stream); // NOLINT
template void StridedSliceGrad(const std::vector<size_t> &dy_shape, const std::vector<int> &begin,
const std::vector<int> &strides, const std::vector<size_t> &dx_shape,
template void StridedSliceGrad(const std::vector<size_t> &dy_shape, const std::vector<int64_t> &begin,
const std::vector<int64_t> &strides, const std::vector<size_t> &dx_shape,
const unsigned char *dy, unsigned char *dx, cudaStream_t cuda_stream);
template void StridedSliceGrad(const std::vector<size_t> &dy_shape, const std::vector<int> &begin,
const std::vector<int> &strides, const std::vector<size_t> &dx_shape, const bool *dy,
template void StridedSliceGrad(const std::vector<size_t> &dy_shape, const std::vector<int64_t> &begin,
const std::vector<int64_t> &strides, const std::vector<size_t> &dx_shape, const bool *dy,
bool *dx, cudaStream_t cuda_stream);
template void StridedSliceGrad(const std::vector<size_t> &dy_shape, const std::vector<int> &begin,
const std::vector<int> &strides, const std::vector<size_t> &dx_shape, const int64_t *dy,
int64_t *dx, cudaStream_t cuda_stream);
template void StridedSliceGrad(const std::vector<size_t> &dy_shape, const std::vector<int64_t> &begin,
const std::vector<int64_t> &strides, const std::vector<size_t> &dx_shape,
const int64_t *dy, int64_t *dx, cudaStream_t cuda_stream);
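
For orientation, the host loop in CalSliceGrad above maps every (i, j, k) slab of a 4-D slice to a flat offset before launching SliceGrad. A CPU-only sketch of the same offset arithmetic with int64_t indices; the shapes and slice bounds are example values:

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Mirrors in_shape / begin / size in CalSliceGrad.
  const std::vector<size_t> in_shape = {2, 3, 4, 5};
  const std::vector<int64_t> begin = {0, 1, 1, 2};
  const std::vector<int64_t> size = {2, 2, 2, 3};

  const int64_t block = static_cast<int64_t>(in_shape[1] * in_shape[2] * in_shape[3]);  // stride of dim 0
  const int64_t map = static_cast<int64_t>(in_shape[2] * in_shape[3]);                  // stride of dim 1
  const int64_t w = static_cast<int64_t>(in_shape[3]);                                  // stride of dim 2
  const int64_t length = size[3];                                                       // contiguous run per slab

  int64_t p = 0;  // read offset into the flattened gradient dy
  for (int64_t i = begin[0]; i < size[0] + begin[0]; i++) {
    for (int64_t j = begin[1]; j < size[1] + begin[1]; j++) {
      for (int64_t k = begin[2]; k < size[2] + begin[2]; k++) {
        const int64_t start = i * block + j * map + k * w + begin[3];  // write offset into dx
        std::cout << "copy " << length << " elements: dy[" << p << "..] -> dx[" << start << "..]\n";
        p += size[3];
      }
    }
  }
  return 0;
}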


@ -27,14 +27,15 @@ void Slice4DKernel(const size_t s1, const size_t s2, const size_t s3, const size
const T *input, T *output, cudaStream_t stream);
template <typename T>
void CalSliceGrad(const size_t input_size, const T *input, const std::vector<size_t> in_shape,
const std::vector<int> begin, const std::vector<int> size, T *output, cudaStream_t cuda_stream);
template <typename T>
void StridedSlice(const std::vector<size_t> &input_shape, const std::vector<int> &begin,
const std::vector<int> &strides, const std::vector<size_t> &output_shape, const T *input, T *output,
const std::vector<int64_t> begin, const std::vector<int64_t> size, T *output,
cudaStream_t cuda_stream);
template <typename T>
void StridedSliceGrad(const std::vector<size_t> &dy_shape, const std::vector<int> &begin,
const std::vector<int> &strides, const std::vector<size_t> &dx_shape, const T *dy, T *dx,
void StridedSlice(const std::vector<size_t> &input_shape, const std::vector<int64_t> &begin,
const std::vector<int64_t> &strides, const std::vector<size_t> &output_shape, const T *input,
T *output, cudaStream_t cuda_stream);
template <typename T>
void StridedSliceGrad(const std::vector<size_t> &dy_shape, const std::vector<int64_t> &begin,
const std::vector<int64_t> &strides, const std::vector<size_t> &dx_shape, const T *dy, T *dx,
cudaStream_t cuda_stream);
template <typename T>
void FillDeviceArray(const size_t input_size, T *addr, const float value, cudaStream_t cuda_stream);
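
Finally, a small sketch of how an output extent follows from normalized begin/end/stride values in int64_t. The positive-stride branch matches FillOutputDim above; the negative-stride branch is an assumed symmetric counterpart, since that line is cut off in the hunk:

#include <cstdint>
#include <iostream>

// Number of elements visited when stepping from begin toward end (exclusive) by stride.
int64_t SliceExtent(int64_t begin, int64_t end, int64_t stride) {
  if (begin <= end && stride > 0) {
    return (end - 1 - begin) / stride + 1;  // branch shown in FillOutputDim
  }
  if (begin > end && stride < 0) {
    return (end + 1 - begin) / stride + 1;  // assumed symmetric branch
  }
  return 0;  // begin/end and stride point in opposite directions: empty slice
}

int main() {
  std::cout << SliceExtent(0, 5, 2) << "\n";   // visits 0, 2, 4       -> 3
  std::cout << SliceExtent(4, -1, -1) << "\n"; // visits 4, 3, 2, 1, 0 -> 5
  std::cout << SliceExtent(0, 5, -1) << "\n";  // direction mismatch   -> 0
  return 0;
}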