forked from mindspore-Ecosystem/mindspore
add gpu cumsum
This commit is contained in:
parent
b5d8dad47d
commit
4a8724e0ac
|
@ -30,5 +30,9 @@ MS_REG_GPU_KERNEL_ONE(ReduceSum, KernelAttr().AddInputAttr(kNumberTypeFloat32).A
|
|||
ArrayReduceGpuKernel, float)
|
||||
MS_REG_GPU_KERNEL_ONE(ReduceSum, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
||||
ArrayReduceGpuKernel, half)
|
||||
MS_REG_GPU_KERNEL_ONE(ReduceMin, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
ArrayReduceGpuKernel, float)
|
||||
MS_REG_GPU_KERNEL_ONE(ReduceMin, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
||||
ArrayReduceGpuKernel, half)
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -29,6 +29,7 @@ const std::map<std::string, cudnnReduceTensorOp_t> kReduceTypeMap = {
|
|||
{"ReduceMax", CUDNN_REDUCE_TENSOR_MAX},
|
||||
{"ReduceMean", CUDNN_REDUCE_TENSOR_AVG},
|
||||
{"ReduceSum", CUDNN_REDUCE_TENSOR_ADD},
|
||||
{"ReduceMin", CUDNN_REDUCE_TENSOR_MIN},
|
||||
};
|
||||
template <typename T>
|
||||
class ArrayReduceGpuKernel : public GpuKernel {
|
||||
|
|
|
@ -0,0 +1,50 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "cumsum_impl.cuh"
|
||||
#include "runtime/device/gpu/cuda_common.h"
|
||||
|
||||
template <typename T>
|
||||
__global__ void CumSumKernel(T *input, T *output, size_t dim0, size_t dim1, size_t dim2, size_t stride,
|
||||
size_t stride2) {
|
||||
size_t num = dim0 * dim2;
|
||||
size_t i, k, offset;
|
||||
for (size_t write_index = blockIdx.x * blockDim.x + threadIdx.x; write_index < num;
|
||||
write_index += blockDim.x * gridDim.x) {
|
||||
i = write_index / dim2 % dim0;
|
||||
k = write_index % dim2;
|
||||
offset = i * stride + k;
|
||||
for (size_t j = 0; j < dim1; ++j) {
|
||||
size_t read_index = j * stride2 + offset;
|
||||
if (j == 0) {
|
||||
output[read_index] = input[read_index];
|
||||
} else {
|
||||
size_t read_index2 = (j - 1) * stride2 + offset;
|
||||
output[read_index] = output[read_index2] + input[read_index];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
template <typename T>
|
||||
void CumSum(T *input, T *output, size_t dim0, size_t dim1, size_t dim2, size_t stride, size_t stride2,
|
||||
cudaStream_t stream) {
|
||||
int size = dim0 * dim2;
|
||||
CumSumKernel<<<GET_BLOCKS(size), GET_THREADS, 0, stream>>>(input, output, dim0, dim1, dim2, stride, stride2);
|
||||
return;
|
||||
}
|
||||
|
||||
template void CumSum<float>(float *input, float *output, size_t dim0, size_t dim1, size_t dim2, size_t stride,
|
||||
size_t stride2, cudaStream_t stream);
|
|
@ -0,0 +1,22 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUMSUM_IMPL_CUH_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUMSUM_IMPL_CUH_
|
||||
template <typename T>
|
||||
void CumSum(T *input, T *output, size_t dim0, size_t dim1, size_t dim2, size_t stride, size_t stride2,
|
||||
cudaStream_t stream);
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUMSUM_IMPL_CUH_
|
|
@ -0,0 +1,24 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "backend/kernel_compiler/gpu/math/cumsum_gpu_kernel.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
MS_REG_GPU_KERNEL_ONE(CumSum, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
CumSumGpuKernel, float)
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,102 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUMSUM_GPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUMSUM_GPU_KERNEL_H_
|
||||
|
||||
#include <vector>
|
||||
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
|
||||
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
|
||||
#include "backend/kernel_compiler/gpu/cuda_impl/cumsum_impl.cuh"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
template <typename T>
|
||||
class CumSumGpuKernel : public GpuKernel {
|
||||
public:
|
||||
CumSumGpuKernel() : axis_(0), input_size_0_(0), stride_(0), stride2_(0) {}
|
||||
~CumSumGpuKernel() = default;
|
||||
|
||||
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
|
||||
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
|
||||
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
|
||||
T *input_addr = GetDeviceAddress<T>(inputs, 0);
|
||||
T *output_addr = GetDeviceAddress<T>(outputs, 0);
|
||||
CumSum(input_addr, output_addr, dims_[0], dims_[1], dims_[2], stride_, stride2_,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
return true;
|
||||
}
|
||||
bool Init(const CNodePtr &kernel_node) override {
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
if (input_num != 1) {
|
||||
MS_LOG(EXCEPTION) << "Argument number is " << input_num << ", but CumSumGpuKernel needs 1.";
|
||||
return false;
|
||||
}
|
||||
input_size_0_ = sizeof(T);
|
||||
shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
axis_ = GetAttr<int>(kernel_node, "axis");
|
||||
int input_dim_length = SizeToInt(shape_.size());
|
||||
if (axis_ >= input_dim_length) {
|
||||
MS_LOG(EXCEPTION) << "Axis out of bounds.";
|
||||
}
|
||||
while (axis_ < 0) {
|
||||
axis_ += input_dim_length;
|
||||
}
|
||||
for (size_t i = 0; i < shape_.size(); i++) {
|
||||
input_size_0_ *= shape_[i];
|
||||
}
|
||||
Reshape();
|
||||
InitSizeLists();
|
||||
return true;
|
||||
}
|
||||
|
||||
protected:
|
||||
void InitSizeLists() override {
|
||||
input_size_list_.push_back(input_size_0_);
|
||||
output_size_list_.push_back(input_size_0_);
|
||||
}
|
||||
|
||||
private:
|
||||
void Reshape() {
|
||||
dims_[0] = 1;
|
||||
dims_[1] = shape_[IntToSize(axis_)];
|
||||
dims_[2] = 1;
|
||||
for (size_t i = 0; i < IntToSize(axis_); i++) {
|
||||
dims_[0] *= shape_[i];
|
||||
}
|
||||
for (size_t i = IntToSize(axis_) + 1; i < shape_.size(); i++) {
|
||||
dims_[2] *= shape_[i];
|
||||
}
|
||||
stride_ = dims_[1] * dims_[2];
|
||||
stride2_ = dims_[2];
|
||||
return;
|
||||
}
|
||||
int axis_;
|
||||
size_t input_size_0_;
|
||||
size_t stride_;
|
||||
size_t stride2_;
|
||||
size_t dims_[3] = {};
|
||||
std::vector<size_t> shape_;
|
||||
std::vector<size_t> input_size_list_;
|
||||
std::vector<size_t> output_size_list_;
|
||||
std::vector<size_t> workspace_size_list_;
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUMSUM_GPU_KERNEL_H_
|
|
@ -0,0 +1,132 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore.common.api import ms_function
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
x0 = np.random.rand(2, 3, 4, 4).astype(np.float32)
|
||||
axis0 = 3
|
||||
|
||||
x1 = np.random.rand(2, 3, 4, 4).astype(np.float32)
|
||||
axis1 = 3
|
||||
|
||||
x2 = np.random.rand(2, 3, 1, 4).astype(np.float32)
|
||||
axis2 = 2
|
||||
|
||||
x3 = np.random.rand(2, 3, 1, 4).astype(np.float32)
|
||||
axis3 = 2
|
||||
|
||||
x4 = np.random.rand(2, 3, 4, 4).astype(np.float32)
|
||||
axis4 = 1
|
||||
|
||||
x5 = np.random.rand(2, 3).astype(np.float32)
|
||||
axis5 = 1
|
||||
|
||||
x6 = np.random.rand(1, 1, 1, 1).astype(np.float32)
|
||||
axis6 = 0
|
||||
|
||||
context.set_context(device_target='GPU')
|
||||
|
||||
|
||||
class CumSum(nn.Cell):
|
||||
def __init__(self):
|
||||
super(CumSum, self).__init__()
|
||||
|
||||
self.x0 = Tensor(x0)
|
||||
self.axis0 = axis0
|
||||
|
||||
self.x1 = Tensor(x1)
|
||||
self.axis1 = axis1
|
||||
|
||||
self.x2 = Tensor(x2)
|
||||
self.axis2 = axis2
|
||||
|
||||
self.x3 = Tensor(x3)
|
||||
self.axis3 = axis3
|
||||
|
||||
self.x4 = Tensor(x4)
|
||||
self.axis4 = axis4
|
||||
|
||||
self.x5 = Tensor(x5)
|
||||
self.axis5 = axis5
|
||||
|
||||
self.x6 = Tensor(x6)
|
||||
self.axis6 = axis6
|
||||
|
||||
@ms_function
|
||||
def construct(self):
|
||||
return (P.CumSum()(self.x0, self.axis0),
|
||||
P.CumSum()(self.x1, self.axis1),
|
||||
P.CumSum()(self.x2, self.axis2),
|
||||
P.CumSum()(self.x3, self.axis3),
|
||||
P.CumSum()(self.x4, self.axis4),
|
||||
P.CumSum()(self.x5, self.axis5),
|
||||
P.CumSum()(self.x6, self.axis6))
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_CumSum():
|
||||
cumsum = CumSum()
|
||||
output = cumsum()
|
||||
|
||||
expect0 = np.cumsum(x0, axis=axis0)
|
||||
diff0 = abs(output[0].asnumpy() - expect0)
|
||||
error0 = np.ones(shape=expect0.shape) * 1.0e-5
|
||||
assert np.all(diff0 < error0)
|
||||
assert output[0].shape == expect0.shape
|
||||
|
||||
expect1 = np.cumsum(x1, axis=axis1)
|
||||
diff1 = abs(output[1].asnumpy() - expect1)
|
||||
error1 = np.ones(shape=expect1.shape) * 1.0e-5
|
||||
assert np.all(diff1 < error1)
|
||||
assert output[1].shape == expect1.shape
|
||||
|
||||
expect2 = np.cumsum(x2, axis=axis2)
|
||||
diff2 = abs(output[2].asnumpy() - expect2)
|
||||
error2 = np.ones(shape=expect2.shape) * 1.0e-5
|
||||
assert np.all(diff2 < error2)
|
||||
assert output[2].shape == expect2.shape
|
||||
|
||||
expect3 = np.cumsum(x3, axis=axis3)
|
||||
diff3 = abs(output[3].asnumpy() - expect3)
|
||||
error3 = np.ones(shape=expect3.shape) * 1.0e-5
|
||||
assert np.all(diff3 < error3)
|
||||
assert output[3].shape == expect3.shape
|
||||
|
||||
expect4 = np.cumsum(x4, axis=axis4)
|
||||
diff4 = abs(output[4].asnumpy() - expect4)
|
||||
error4 = np.ones(shape=expect4.shape) * 1.0e-5
|
||||
assert np.all(diff4 < error4)
|
||||
assert output[4].shape == expect4.shape
|
||||
|
||||
expect5 = np.cumsum(x5, axis=axis5)
|
||||
diff5 = abs(output[5].asnumpy() - expect5)
|
||||
error5 = np.ones(shape=expect5.shape) * 1.0e-5
|
||||
assert np.all(diff5 < error5)
|
||||
assert output[5].shape == expect5.shape
|
||||
|
||||
expect6 = np.cumsum(x6, axis=axis6)
|
||||
diff6 = abs(output[6].asnumpy() - expect6)
|
||||
error6 = np.ones(shape=expect6.shape) * 1.0e-5
|
||||
assert np.all(diff6 < error6)
|
||||
assert output[6].shape == expect6.shape
|
|
@ -0,0 +1,177 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore.common.api import ms_function
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
x0 = np.random.rand(2, 3, 4, 4).astype(np.float32)
|
||||
axis0 = 3
|
||||
keep_dims0 = True
|
||||
|
||||
x1 = np.random.rand(2, 3, 4, 4).astype(np.float32)
|
||||
axis1 = 3
|
||||
keep_dims1 = False
|
||||
|
||||
x2 = np.random.rand(2, 3, 1, 4).astype(np.float32)
|
||||
axis2 = 2
|
||||
keep_dims2 = True
|
||||
|
||||
x3 = np.random.rand(2, 3, 1, 4).astype(np.float32)
|
||||
axis3 = 2
|
||||
keep_dims3 = False
|
||||
|
||||
x4 = np.random.rand(2, 3, 4, 4).astype(np.float32)
|
||||
axis4 = ()
|
||||
np_axis4 = None
|
||||
keep_dims4 = True
|
||||
|
||||
x5 = np.random.rand(2, 3, 4, 4).astype(np.float32)
|
||||
axis5 = ()
|
||||
np_axis5 = None
|
||||
keep_dims5 = False
|
||||
|
||||
x6 = np.random.rand(2, 3, 4, 4).astype(np.float32)
|
||||
axis6 = -2
|
||||
keep_dims6 = False
|
||||
|
||||
x7 = np.random.rand(2, 3, 4, 4).astype(np.float32)
|
||||
axis7 = (-2, -1)
|
||||
keep_dims7 = True
|
||||
|
||||
x8 = np.random.rand(1, 1, 1, 1).astype(np.float32)
|
||||
axis8 = ()
|
||||
np_axis8 = None
|
||||
keep_dims8 = True
|
||||
|
||||
context.set_context(device_target='GPU')
|
||||
|
||||
|
||||
class ReduceMin(nn.Cell):
|
||||
def __init__(self):
|
||||
super(ReduceMin, self).__init__()
|
||||
|
||||
self.x0 = Tensor(x0)
|
||||
self.axis0 = axis0
|
||||
self.keep_dims0 = keep_dims0
|
||||
|
||||
self.x1 = Tensor(x1)
|
||||
self.axis1 = axis1
|
||||
self.keep_dims1 = keep_dims1
|
||||
|
||||
self.x2 = Tensor(x2)
|
||||
self.axis2 = axis2
|
||||
self.keep_dims2 = keep_dims2
|
||||
|
||||
self.x3 = Tensor(x3)
|
||||
self.axis3 = axis3
|
||||
self.keep_dims3 = keep_dims3
|
||||
|
||||
self.x4 = Tensor(x4)
|
||||
self.axis4 = axis4
|
||||
self.keep_dims4 = keep_dims4
|
||||
|
||||
self.x5 = Tensor(x5)
|
||||
self.axis5 = axis5
|
||||
self.keep_dims5 = keep_dims5
|
||||
|
||||
self.x6 = Tensor(x6)
|
||||
self.axis6 = axis6
|
||||
self.keep_dims6 = keep_dims6
|
||||
|
||||
self.x7 = Tensor(x7)
|
||||
self.axis7 = axis7
|
||||
self.keep_dims7 = keep_dims7
|
||||
|
||||
self.x8 = Tensor(x8)
|
||||
self.axis8 = axis8
|
||||
self.keep_dims8 = keep_dims8
|
||||
|
||||
@ms_function
|
||||
def construct(self):
|
||||
return (P.ReduceMin(self.keep_dims0)(self.x0, self.axis0),
|
||||
P.ReduceMin(self.keep_dims1)(self.x1, self.axis1),
|
||||
P.ReduceMin(self.keep_dims2)(self.x2, self.axis2),
|
||||
P.ReduceMin(self.keep_dims3)(self.x3, self.axis3),
|
||||
P.ReduceMin(self.keep_dims4)(self.x4, self.axis4),
|
||||
P.ReduceMin(self.keep_dims5)(self.x5, self.axis5),
|
||||
P.ReduceMin(self.keep_dims6)(self.x6, self.axis6),
|
||||
P.ReduceMin(self.keep_dims7)(self.x7, self.axis7),
|
||||
P.ReduceMin(self.keep_dims8)(self.x8, self.axis8))
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_ReduceMin():
|
||||
reduce_min = ReduceMin()
|
||||
output = reduce_min()
|
||||
|
||||
expect0 = np.min(x0, axis=axis0, keepdims=keep_dims0)
|
||||
diff0 = abs(output[0].asnumpy() - expect0)
|
||||
error0 = np.ones(shape=expect0.shape) * 1.0e-5
|
||||
assert np.all(diff0 < error0)
|
||||
assert output[0].shape == expect0.shape
|
||||
|
||||
expect1 = np.min(x1, axis=axis1, keepdims=keep_dims1)
|
||||
diff1 = abs(output[1].asnumpy() - expect1)
|
||||
error1 = np.ones(shape=expect1.shape) * 1.0e-5
|
||||
assert np.all(diff1 < error1)
|
||||
assert output[1].shape == expect1.shape
|
||||
|
||||
expect2 = np.min(x2, axis=axis2, keepdims=keep_dims2)
|
||||
diff2 = abs(output[2].asnumpy() - expect2)
|
||||
error2 = np.ones(shape=expect2.shape) * 1.0e-5
|
||||
assert np.all(diff2 < error2)
|
||||
assert output[2].shape == expect2.shape
|
||||
|
||||
expect3 = np.min(x3, axis=axis3, keepdims=keep_dims3)
|
||||
diff3 = abs(output[3].asnumpy() - expect3)
|
||||
error3 = np.ones(shape=expect3.shape) * 1.0e-5
|
||||
assert np.all(diff3 < error3)
|
||||
assert output[3].shape == expect3.shape
|
||||
|
||||
expect4 = np.min(x4, axis=np_axis4, keepdims=keep_dims4)
|
||||
diff4 = abs(output[4].asnumpy() - expect4)
|
||||
error4 = np.ones(shape=expect4.shape) * 1.0e-5
|
||||
assert np.all(diff4 < error4)
|
||||
assert output[4].shape == expect4.shape
|
||||
|
||||
expect5 = np.min(x5, axis=np_axis5, keepdims=keep_dims5)
|
||||
diff5 = abs(output[5].asnumpy() - expect5)
|
||||
error5 = np.ones(shape=expect5.shape) * 1.0e-5
|
||||
assert np.all(diff5 < error5)
|
||||
assert output[5].shape == expect5.shape
|
||||
|
||||
expect6 = np.min(x6, axis=axis6, keepdims=keep_dims6)
|
||||
diff6 = abs(output[6].asnumpy() - expect6)
|
||||
error6 = np.ones(shape=expect6.shape) * 1.0e-5
|
||||
assert np.all(diff6 < error6)
|
||||
assert output[6].shape == expect6.shape
|
||||
|
||||
expect7 = np.min(x7, axis=axis7, keepdims=keep_dims7)
|
||||
diff7 = abs(output[7].asnumpy() - expect7)
|
||||
error7 = np.ones(shape=expect7.shape) * 1.0e-5
|
||||
assert np.all(diff7 < error7)
|
||||
|
||||
expect8 = np.min(x8, axis=np_axis8, keepdims=keep_dims8)
|
||||
diff8 = abs(output[8].asnumpy() - expect8)
|
||||
error8 = np.ones(shape=expect8.shape) * 1.0e-5
|
||||
assert np.all(diff8 < error8)
|
Loading…
Reference in New Issue