Add CumSum Kernel for CPU
parent efa7d78b57
commit accf7c4671
@@ -0,0 +1,248 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <thread>
#include "backend/kernel_compiler/cpu/cumsum_cpu_kernel.h"
#include "runtime/device/cpu/cpu_device_address.h"

namespace mindspore {
namespace kernel {
void CumSumCPUKernel::InitKernel(const CNodePtr &kernel_node) {
  CheckParam(kernel_node);
  shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
  dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0);
  axis_ = static_cast<int>(AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "axis"));
  dst_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
  exclusive_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, "exclusive");
  reverse_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, "reverse");
  int input_dim_length = SizeToInt(shape_.size());
  if (axis_ >= input_dim_length) {
    MS_LOG(EXCEPTION) << "Axis out of bounds.";
  }
  while (axis_ < 0) {
    axis_ += input_dim_length;
  }
}
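
// The workspace requested below mirrors the input tensor size; the exclusive mode uses it to
// stage a shifted copy of the input before accumulating (see LaunchCumSum).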
template <typename T>
void CumSumCPUKernel::InitWorkspaceSize() {
  input_size_0_ = sizeof(T);
  for (size_t i = 0; i < shape_.size(); i++) {
    input_size_0_ *= shape_[i];
  }
  workspace_size_list_.emplace_back(input_size_0_);
}

void CumSumCPUKernel::InitInputOutputSize(const CNodePtr &kernel_node) {
  CPUKernel::InitInputOutputSize(kernel_node);
  if (dtype_ == kNumberTypeFloat32) {
    InitWorkspaceSize<float_t>();
  } else if (dtype_ == kNumberTypeFloat16) {
    InitWorkspaceSize<float16>();
  } else if (dtype_ == kNumberTypeInt32) {
    InitWorkspaceSize<int32_t>();
  } else if (dtype_ == kNumberTypeInt8) {
    InitWorkspaceSize<int8_t>();
  } else if (dtype_ == kNumberTypeUInt8) {
    InitWorkspaceSize<uint8_t>();
  }
}

bool CumSumCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
                             const std::vector<kernel::AddressPtr> &workspace,
                             const std::vector<kernel::AddressPtr> &outputs) {
  Reshape();
  if (dtype_ == kNumberTypeFloat32) {
    LaunchKernel<float_t>(inputs, workspace, outputs);
  } else if (dtype_ == kNumberTypeFloat16) {
    LaunchKernel<float16>(inputs, workspace, outputs);
  } else if (dtype_ == kNumberTypeInt32) {
    LaunchKernel<int32_t>(inputs, workspace, outputs);
  } else if (dtype_ == kNumberTypeInt8) {
    LaunchKernel<int8_t>(inputs, workspace, outputs);
  } else if (dtype_ == kNumberTypeUInt8) {
    LaunchKernel<uint8_t>(inputs, workspace, outputs);
  }
  return true;
}
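
// Reshape flattens the input into a (dims_[0], dims_[1], dims_[2]) view: the product of the
// dimensions before the axis, the axis dimension itself, and the product of the dimensions after
// it. stride_ and stride2_ are the element strides of dims_[0] and dims_[1] in that view.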
void CumSumCPUKernel::Reshape() {
  dims_[0] = 1;
  dims_[1] = shape_[IntToSize(axis_)];
  dims_[2] = 1;
  for (size_t i = 0; i < IntToSize(axis_); i++) {
    dims_[0] *= shape_[i];
  }
  for (size_t i = IntToSize(axis_) + 1; i < shape_.size(); i++) {
    dims_[2] *= shape_[i];
  }
  stride_ = dims_[1] * dims_[2];
  stride2_ = dims_[2];
  return;
}
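
// LeftMove shifts each slice one step forward along the axis and writes a leading zero, i.e.
// output[..., j, ...] = input[..., j - 1, ...] with output[..., 0, ...] = 0. This is the first
// step of the exclusive (non-reverse) cumulative sum.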
template <typename T>
void CumSumCPUKernel::LeftMove(const T *input, T *output, size_t dim0, size_t dim1, size_t dim2, size_t stride,
                               size_t stride2, size_t start, size_t end) {
  for (size_t i = start; i < end; i++) {
    size_t k1 = i / dim2 % dim0;
    size_t k2 = i % dim2;
    size_t offset = k1 * stride + k2;
    for (size_t j = 0; j < dim1; ++j) {
      size_t read_index = j * stride2 + offset;
      if (j == 0) {
        output[read_index] = (T)0;
      } else {
        size_t read_index2 = (j - 1) * stride2 + offset;
        output[read_index] = input[read_index2];
      }
    }
  }
}
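
// RightMove is the mirror of LeftMove for the reverse direction: each slice is shifted one step
// backward along the axis and the trailing position is zeroed.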
template <typename T>
void CumSumCPUKernel::RightMove(const T *input, T *output, size_t dim0, size_t dim1, size_t dim2, size_t stride,
                                size_t stride2, size_t start, size_t end) {
  for (size_t i = start; i < end; i++) {
    size_t k1 = i / dim2 % dim0;
    size_t k2 = i % dim2;
    size_t offset = k1 * stride + k2;
    for (int j = SizeToInt(dim1 - 1); j >= 0; --j) {
      size_t read_index = j * stride2 + offset;
      if (j == SizeToInt(dim1 - 1)) {
        output[read_index] = (T)0;
      } else {
        size_t read_index2 = (j + 1) * stride2 + offset;
        output[read_index] = input[read_index2];
      }
    }
  }
}
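
// Copy moves the shifted values produced by LeftMove/RightMove from the output buffer into the
// workspace, so the accumulation pass can read from one buffer and write to the other.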
template <typename T>
void CumSumCPUKernel::Copy(T *input, T *output, size_t dim0, size_t dim1, size_t dim2, size_t stride, size_t stride2,
                           size_t start, size_t end) {
  for (size_t i = start; i < end; i++) {
    size_t k1 = i / dim2 % dim0;
    size_t k2 = i % dim2;
    size_t offset = k1 * stride + k2;
    for (size_t j = 0; j < dim1; ++j) {
      size_t read_index = j * stride2 + offset;
      input[read_index] = output[read_index];
    }
  }
}
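
// Running sum along the axis from the last element towards the first (reverse mode).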
template <typename T>
void CumSumCPUKernel::CumSumKernelReverse(const T *input, T *output, size_t dim0, size_t dim1, size_t dim2,
                                          size_t stride, size_t stride2, size_t start, size_t end) {
  for (size_t i = start; i < end; i++) {
    size_t k1 = i / dim2 % dim0;
    size_t k2 = i % dim2;
    size_t offset = k1 * stride + k2;
    for (int j = SizeToInt(dim1 - 1); j >= 0; --j) {
      size_t read_index = j * stride2 + offset;
      if (j == SizeToInt(dim1 - 1)) {
        output[read_index] = input[read_index];
      } else {
        size_t read_index2 = (j + 1) * stride2 + offset;
        output[read_index] = output[read_index2] + input[read_index];
      }
    }
  }
}
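
// Running sum along the axis from the first element towards the last (forward mode).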
template <typename T>
void CumSumCPUKernel::CumSumKernel(const T *input, T *output, size_t dim0, size_t dim1, size_t dim2, size_t stride,
                                   size_t stride2, size_t start, size_t end) {
  for (size_t i = start; i < end; i++) {
    size_t k1 = i / dim2 % dim0;
    size_t k2 = i % dim2;
    size_t offset = k1 * stride + k2;
    for (size_t j = 0; j < dim1; ++j) {
      size_t read_index = j * stride2 + offset;
      if (j == 0) {
        output[read_index] = input[read_index];
      } else {
        size_t read_index2 = (j - 1) * stride2 + offset;
        output[read_index] = output[read_index2] + input[read_index];
      }
    }
  }
}
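
// LaunchCumSum handles one contiguous range of (outer, inner) positions. The [start, end) range
// arrives as flat element indices, so it is first divided by dims_[1] to index positions over
// dims_[0] x dims_[2]. Exclusive mode shifts the input (LeftMove/RightMove), copies the shifted
// data into the workspace, and accumulates from there; inclusive mode accumulates directly.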
template <typename T>
void CumSumCPUKernel::LaunchCumSum(const T *input, T *output, T *workspace, size_t start, size_t end) {
  start = start / dims_[1];
  end = end / dims_[1];
  if (exclusive_) {
    if (reverse_) {
      RightMove(input, output, dims_[0], dims_[1], dims_[2], stride_, stride2_, start, end);
      Copy(workspace, output, dims_[0], dims_[1], dims_[2], stride_, stride2_, start, end);
      CumSumKernelReverse(workspace, output, dims_[0], dims_[1], dims_[2], stride_, stride2_, start, end);
    } else {
      LeftMove(input, output, dims_[0], dims_[1], dims_[2], stride_, stride2_, start, end);
      Copy(workspace, output, dims_[0], dims_[1], dims_[2], stride_, stride2_, start, end);
      CumSumKernel(workspace, output, dims_[0], dims_[1], dims_[2], stride_, stride2_, start, end);
    }
  } else {
    if (reverse_) {
      CumSumKernelReverse(input, output, dims_[0], dims_[1], dims_[2], stride_, stride2_, start, end);
    } else {
      CumSumKernel(input, output, dims_[0], dims_[1], dims_[2], stride_, stride2_, start, end);
    }
  }
  return;
}
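
// LaunchKernel splits the flattened element range across worker threads, aiming for roughly
// 128 elements per thread but never more threads than the hardware reports.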
template <typename T>
void CumSumCPUKernel::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
                                   const std::vector<kernel::AddressPtr> &workspace,
                                   const std::vector<kernel::AddressPtr> &outputs) {
  auto input = reinterpret_cast<T *>(inputs[0]->addr);
  auto ws = reinterpret_cast<T *>(workspace[0]->addr);
  auto output = reinterpret_cast<T *>(outputs[0]->addr);
  // multithreading
  size_t lens = inputs[0]->size > 0 ? static_cast<size_t>(inputs[0]->size / sizeof(T)) : 1;
  auto max_thread_num = std::thread::hardware_concurrency();
  size_t thread_num = lens < 128 * max_thread_num ? std::ceil(lens / 128.0) : max_thread_num;
  MS_LOG(INFO) << "Lens=" << lens << "; use thread_num=" << thread_num << "; max_thread_num: " << max_thread_num;
  std::vector<std::thread> threads;
  threads.reserve(thread_num);
  size_t start = 0;
  size_t once_compute_size = (lens + thread_num - 1) / thread_num;
  if (thread_num < 1 || once_compute_size < 1) {
    MS_LOG(ERROR) << "Invalid value: thread_num " << thread_num << "; once_compute_size " << once_compute_size;
    return;
  }
  while (start < lens) {
    size_t end = (start + once_compute_size) > lens ? lens : (start + once_compute_size);
    threads.emplace_back(std::thread(&CumSumCPUKernel::LaunchCumSum<T>, this, input, output, ws, start, end));
    start += once_compute_size;
  }
  for (size_t i = 0; i < threads.size(); ++i) {
    threads[i].join();
  }
  return;
}

void CumSumCPUKernel::CheckParam(const CNodePtr &kernel_node) {
  size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
  if (input_num != 1) {
    MS_LOG(EXCEPTION) << "Argument number is " << input_num << ", but CumSumCPUKernel needs 1.";
  }
}
}  // namespace kernel
}  // namespace mindspore

@@ -0,0 +1,95 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CUMSUM_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CUMSUM_CPU_KERNEL_H_

#include <memory>
#include <vector>
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"

namespace mindspore {
namespace kernel {
class CumSumCPUKernel : public CPUKernel {
 public:
  CumSumCPUKernel() = default;
  ~CumSumCPUKernel() override = default;

  void InitKernel(const CNodePtr &kernel_node) override;

  void InitInputOutputSize(const CNodePtr &kernel_node) override;

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs) override;

  template <typename T>
  void InitWorkspaceSize();

  template <typename T>
  void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
                    const std::vector<AddressPtr> &outputs);

  template <typename T>
  void LaunchCumSum(const T *input_addr, T *output_addr, T *ws_addr, size_t start, size_t end);

 private:
  void CheckParam(const CNodePtr &kernel_node);

  void Reshape();

  template <typename T>
  void LeftMove(const T *input, T *output, size_t dim0, size_t dim1, size_t dim2, size_t stride, size_t stride2,
                size_t start, size_t end);

  template <typename T>
  void RightMove(const T *input, T *output, size_t dim0, size_t dim1, size_t dim2, size_t stride, size_t stride2,
                 size_t start, size_t end);

  template <typename T>
  void Copy(T *input, T *output, size_t dim0, size_t dim1, size_t dim2, size_t stride, size_t stride2, size_t start,
            size_t end);

  template <typename T>
  void CumSumKernelReverse(const T *input, T *output, size_t dim0, size_t dim1, size_t dim2, size_t stride,
                           size_t stride2, size_t start, size_t end);

  template <typename T>
  void CumSumKernel(const T *input, T *output, size_t dim0, size_t dim1, size_t dim2, size_t stride, size_t stride2,
                    size_t start, size_t end);
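
  // shape_ is the full input shape; dims_ is its (outer, axis, inner) flattening produced by
  // Reshape(), with stride_ and stride2_ the corresponding element strides.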
  std::vector<size_t> shape_;
  std::vector<size_t> dst_shape;
  size_t input_size_0_;
  size_t stride_;
  size_t stride2_;
  size_t dims_[3] = {};
  bool exclusive_;
  bool reverse_;
  int axis_;
  TypeId dtype_{kTypeUnknown};
};
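
// Register the CPU kernel for each data type the implementation supports.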
MS_REG_CPU_KERNEL(CumSum, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
                  CumSumCPUKernel);
MS_REG_CPU_KERNEL(CumSum, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
                  CumSumCPUKernel);
MS_REG_CPU_KERNEL(CumSum, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), CumSumCPUKernel);
MS_REG_CPU_KERNEL(CumSum, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8), CumSumCPUKernel);
MS_REG_CPU_KERNEL(CumSum, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8), CumSumCPUKernel);
}  // namespace kernel
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CUMSUM_CPU_KERNEL_H_

@@ -844,7 +844,7 @@ class CumSum(PrimitiveWithInfer):
         Tensor, the shape of the output tensor is consistent with the input tensor's.
 
     Supported Platforms:
-        ``Ascend`` ``GPU``
+        ``Ascend`` ``GPU`` ``CPU``
 
     Examples:
         >>> input = Tensor(np.array([[3, 4, 6, 10],[1, 6, 7, 9],[4, 3, 8, 7],[1, 3, 7, 9]]).astype(np.float32))

@@ -0,0 +1,271 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import numpy as np
import pytest

import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P

context.set_context(mode=context.GRAPH_MODE, device_target="CPU")


axis0 = 0
axis1 = 1
axis2 = 2
axis3 = 3
axis4 = 4
axis5 = -1
axis6 = -2

x0 = np.random.rand(3, 3, 4, 5, 3).astype(np.float32)
x1 = np.random.rand(2, 3, 4, 5, 3).astype(np.float16)
x2 = np.random.randint(-10000, 10000, size=(2, 3, 4, 5, 3)).astype(np.int32)
x3 = np.random.randint(-5, 5, size=(2, 3, 4, 5, 3)).astype(np.int8)
x4 = np.random.randint(0, 10, size=(2, 3, 4, 5, 3)).astype(np.uint8)
x5 = np.random.rand(3).astype(np.float32)

list1 = [x0, x1, x2, x3, x4]
list2 = [axis0, axis1, axis2, axis3, axis4, axis5, axis6]
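

# The Cell below evaluates CumSum on every (input, axis) pair in a single construct call, so one
# compiled graph exercises all dtype/axis combinations; the tests compare each output with numpy.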
class CumSum(nn.Cell):
    def __init__(self, exclusive=False, reverse=False):
        super(CumSum, self).__init__()
        self.cumsum_op = P.CumSum(exclusive, reverse)

        self.x0 = Tensor(x0)
        self.axis0 = axis0
        self.x1 = Tensor(x0)
        self.axis1 = axis1
        self.x2 = Tensor(x0)
        self.axis2 = axis2
        self.x3 = Tensor(x0)
        self.axis3 = axis3
        self.x4 = Tensor(x0)
        self.axis4 = axis4
        self.x5 = Tensor(x0)
        self.axis5 = axis5
        self.x6 = Tensor(x0)
        self.axis6 = axis6

        self.x7 = Tensor(x1)
        self.axis7 = axis0
        self.x8 = Tensor(x1)
        self.axis8 = axis1
        self.x9 = Tensor(x1)
        self.axis9 = axis2
        self.x10 = Tensor(x1)
        self.axis10 = axis3
        self.x11 = Tensor(x1)
        self.axis11 = axis4
        self.x12 = Tensor(x1)
        self.axis12 = axis5
        self.x13 = Tensor(x1)
        self.axis13 = axis6

        self.x14 = Tensor(x2)
        self.axis14 = axis0
        self.x15 = Tensor(x2)
        self.axis15 = axis1
        self.x16 = Tensor(x2)
        self.axis16 = axis2
        self.x17 = Tensor(x2)
        self.axis17 = axis3
        self.x18 = Tensor(x2)
        self.axis18 = axis4
        self.x19 = Tensor(x2)
        self.axis19 = axis5
        self.x20 = Tensor(x2)
        self.axis20 = axis6

        self.x21 = Tensor(x3)
        self.axis21 = axis0
        self.x22 = Tensor(x3)
        self.axis22 = axis1
        self.x23 = Tensor(x3)
        self.axis23 = axis2
        self.x24 = Tensor(x3)
        self.axis24 = axis3
        self.x25 = Tensor(x3)
        self.axis25 = axis4
        self.x26 = Tensor(x3)
        self.axis26 = axis5
        self.x27 = Tensor(x3)
        self.axis27 = axis6

        self.x28 = Tensor(x4)
        self.axis28 = axis0
        self.x29 = Tensor(x4)
        self.axis29 = axis1
        self.x30 = Tensor(x4)
        self.axis30 = axis2
        self.x31 = Tensor(x4)
        self.axis31 = axis3
        self.x32 = Tensor(x4)
        self.axis32 = axis4
        self.x33 = Tensor(x4)
        self.axis33 = axis5
        self.x34 = Tensor(x4)
        self.axis34 = axis6

        self.x35 = Tensor(x5)
        self.axis35 = axis0

    def construct(self):
        return (self.cumsum_op(self.x0, self.axis0),
                self.cumsum_op(self.x1, self.axis1),
                self.cumsum_op(self.x2, self.axis2),
                self.cumsum_op(self.x3, self.axis3),
                self.cumsum_op(self.x4, self.axis4),
                self.cumsum_op(self.x5, self.axis5),
                self.cumsum_op(self.x6, self.axis6),
                self.cumsum_op(self.x7, self.axis7),
                self.cumsum_op(self.x8, self.axis8),
                self.cumsum_op(self.x9, self.axis9),
                self.cumsum_op(self.x10, self.axis10),
                self.cumsum_op(self.x11, self.axis11),
                self.cumsum_op(self.x12, self.axis12),
                self.cumsum_op(self.x13, self.axis13),
                self.cumsum_op(self.x14, self.axis14),
                self.cumsum_op(self.x15, self.axis15),
                self.cumsum_op(self.x16, self.axis16),
                self.cumsum_op(self.x17, self.axis17),
                self.cumsum_op(self.x18, self.axis18),
                self.cumsum_op(self.x19, self.axis19),
                self.cumsum_op(self.x20, self.axis20),
                self.cumsum_op(self.x21, self.axis21),
                self.cumsum_op(self.x22, self.axis22),
                self.cumsum_op(self.x23, self.axis23),
                self.cumsum_op(self.x24, self.axis24),
                self.cumsum_op(self.x25, self.axis25),
                self.cumsum_op(self.x26, self.axis26),
                self.cumsum_op(self.x27, self.axis27),
                self.cumsum_op(self.x28, self.axis28),
                self.cumsum_op(self.x29, self.axis29),
                self.cumsum_op(self.x30, self.axis30),
                self.cumsum_op(self.x31, self.axis31),
                self.cumsum_op(self.x32, self.axis32),
                self.cumsum_op(self.x33, self.axis33),
                self.cumsum_op(self.x34, self.axis34),
                self.cumsum_op(self.x35, self.axis35))


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_cumsum():
    cumsum = CumSum()
    output = cumsum()

    k = 0

    for i in list1:
        for j in list2:
            expect = np.cumsum(i, axis=j)
            diff = abs(output[k].asnumpy() - expect)
            error = np.ones(shape=expect.shape) * 1.0e-5
            assert np.all(diff < error)
            assert output[k].shape == expect.shape
            k += 1

    expect = np.cumsum(x5, axis=axis0)
    diff = abs(output[k].asnumpy() - expect)
    error = np.ones(shape=expect.shape) * 1.0e-5
    assert np.all(diff < error)
    assert output[k].shape == expect.shape
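

# reverse=True is emulated in numpy by flipping along the axis, taking cumsum, and flipping back.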
def test_cumsum2():
    cumsum = CumSum(exclusive=False, reverse=True)
    output = cumsum()

    k = 0

    for i in list1:
        for j in list2:
            result1 = np.flip(i, axis=j)
            result2 = np.cumsum(result1, axis=j)
            expect = np.flip(result2, axis=j)
            diff = abs(output[k].asnumpy() - expect)
            error = np.ones(shape=expect.shape) * 1.0e-5
            assert np.all(diff < error)
            assert output[k].shape == expect.shape
            k += 1

    result1 = np.flip(x5, axis=axis0)
    result2 = np.cumsum(result1, axis=axis0)
    expect = np.flip(result2, axis=axis0)
    diff = abs(output[k].asnumpy() - expect)
    error = np.ones(shape=expect.shape) * 1.0e-5
    assert np.all(diff < error)
    assert output[k].shape == expect.shape
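

# exclusive=True is emulated by prepending a zero slice along the axis and dropping the last
# slice before taking cumsum.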
def test_cumsum3():
    cumsum = CumSum(exclusive=True, reverse=False)
    output = cumsum()

    k = 0

    for i in list1:
        for j in list2:
            result1 = np.insert(i, 0, [0], axis=j)
            result2 = np.delete(result1, -1, axis=j)
            expect = np.cumsum(result2, axis=j)
            diff = abs(output[k].asnumpy() - expect)
            error = np.ones(shape=expect.shape) * 1.0e-5
            assert np.all(diff < error)
            assert output[k].shape == expect.shape
            k += 1

    result1 = np.insert(x5, 0, [0], axis=axis0)
    result2 = np.delete(result1, -1, axis=axis0)
    expect = np.cumsum(result2, axis=axis0)
    diff = abs(output[k].asnumpy() - expect)
    error = np.ones(shape=expect.shape) * 1.0e-5
    assert np.all(diff < error)
    assert output[k].shape == expect.shape
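

# exclusive=True together with reverse=True combines both constructions: flip, shift in a zero
# slice, cumsum, then flip back.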
def test_cumsum4():
    cumsum = CumSum(exclusive=True, reverse=True)
    output = cumsum()

    k = 0

    for i in list1:
        for j in list2:
            result1 = np.flip(i, axis=j)
            result2 = np.insert(result1, 0, [0], axis=j)
            result3 = np.delete(result2, -1, axis=j)
            result4 = np.cumsum(result3, axis=j)
            expect = np.flip(result4, axis=j)
            diff = abs(output[k].asnumpy() - expect)
            error = np.ones(shape=expect.shape) * 1.0e-5
            assert np.all(diff < error)
            assert output[k].shape == expect.shape
            k += 1

    result1 = np.flip(x5, axis=axis0)
    result2 = np.insert(result1, 0, [0], axis=axis0)
    result3 = np.delete(result2, -1, axis=axis0)
    result4 = np.cumsum(result3, axis=axis0)
    expect = np.flip(result4, axis=axis0)
    diff = abs(output[k].asnumpy() - expect)
    error = np.ones(shape=expect.shape) * 1.0e-5
    assert np.all(diff < error)
    assert output[k].shape == expect.shape