From 4a8724e0ac4d3daccd0d13ca0b34c01e27a0c00f Mon Sep 17 00:00:00 2001 From: baihuawei Date: Thu, 23 Jul 2020 15:43:04 +0800 Subject: [PATCH] add gpu cumsum --- .../gpu/arrays/array_reduce_gpu_kernel.cc | 4 + .../gpu/arrays/array_reduce_gpu_kernel.h | 1 + .../gpu/cuda_impl/cumsum_impl.cu | 50 +++++ .../gpu/cuda_impl/cumsum_impl.cuh | 22 +++ .../gpu/math/cumsum_gpu_kernel.cc | 24 +++ .../gpu/math/cumsum_gpu_kernel.h | 102 ++++++++++ tests/st/ops/gpu/test_cumsum_op.py | 132 +++++++++++++ tests/st/ops/gpu/test_reduce_min_op.py | 177 ++++++++++++++++++ 8 files changed, 512 insertions(+) create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/cumsum_impl.cu create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/cumsum_impl.cuh create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/math/cumsum_gpu_kernel.cc create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/math/cumsum_gpu_kernel.h create mode 100644 tests/st/ops/gpu/test_cumsum_op.py create mode 100644 tests/st/ops/gpu/test_reduce_min_op.py diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/array_reduce_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/array_reduce_gpu_kernel.cc index 5d34a1c9c2b..3e7cb788ea2 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/array_reduce_gpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/array_reduce_gpu_kernel.cc @@ -30,5 +30,9 @@ MS_REG_GPU_KERNEL_ONE(ReduceSum, KernelAttr().AddInputAttr(kNumberTypeFloat32).A ArrayReduceGpuKernel, float) MS_REG_GPU_KERNEL_ONE(ReduceSum, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), ArrayReduceGpuKernel, half) +MS_REG_GPU_KERNEL_ONE(ReduceMin, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + ArrayReduceGpuKernel, float) +MS_REG_GPU_KERNEL_ONE(ReduceMin, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + ArrayReduceGpuKernel, half) } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/array_reduce_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/array_reduce_gpu_kernel.h index 94af233d114..0f325199590 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/array_reduce_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/array_reduce_gpu_kernel.h @@ -29,6 +29,7 @@ const std::map kReduceTypeMap = { {"ReduceMax", CUDNN_REDUCE_TENSOR_MAX}, {"ReduceMean", CUDNN_REDUCE_TENSOR_AVG}, {"ReduceSum", CUDNN_REDUCE_TENSOR_ADD}, + {"ReduceMin", CUDNN_REDUCE_TENSOR_MIN}, }; template class ArrayReduceGpuKernel : public GpuKernel { diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/cumsum_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/cumsum_impl.cu new file mode 100644 index 00000000000..dfc62147b75 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/cumsum_impl.cu @@ -0,0 +1,50 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "cumsum_impl.cuh" +#include "runtime/device/gpu/cuda_common.h" + +template +__global__ void CumSumKernel(T *input, T *output, size_t dim0, size_t dim1, size_t dim2, size_t stride, + size_t stride2) { + size_t num = dim0 * dim2; + size_t i, k, offset; + for (size_t write_index = blockIdx.x * blockDim.x + threadIdx.x; write_index < num; + write_index += blockDim.x * gridDim.x) { + i = write_index / dim2 % dim0; + k = write_index % dim2; + offset = i * stride + k; + for (size_t j = 0; j < dim1; ++j) { + size_t read_index = j * stride2 + offset; + if (j == 0) { + output[read_index] = input[read_index]; + } else { + size_t read_index2 = (j - 1) * stride2 + offset; + output[read_index] = output[read_index2] + input[read_index]; + } + } + } +} +template +void CumSum(T *input, T *output, size_t dim0, size_t dim1, size_t dim2, size_t stride, size_t stride2, + cudaStream_t stream) { + int size = dim0 * dim2; + CumSumKernel<<>>(input, output, dim0, dim1, dim2, stride, stride2); + return; +} + +template void CumSum(float *input, float *output, size_t dim0, size_t dim1, size_t dim2, size_t stride, + size_t stride2, cudaStream_t stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/cumsum_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/cumsum_impl.cuh new file mode 100644 index 00000000000..85ca551643a --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/cumsum_impl.cuh @@ -0,0 +1,22 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUMSUM_IMPL_CUH_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUMSUM_IMPL_CUH_ +template +void CumSum(T *input, T *output, size_t dim0, size_t dim1, size_t dim2, size_t stride, size_t stride2, + cudaStream_t stream); +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUMSUM_IMPL_CUH_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/cumsum_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/cumsum_gpu_kernel.cc new file mode 100644 index 00000000000..deb5e39ff7c --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/cumsum_gpu_kernel.cc @@ -0,0 +1,24 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/math/cumsum_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE(CumSum, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + CumSumGpuKernel, float) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/cumsum_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/cumsum_gpu_kernel.h new file mode 100644 index 00000000000..92e92324162 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/cumsum_gpu_kernel.h @@ -0,0 +1,102 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUMSUM_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUMSUM_GPU_KERNEL_H_ + +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/cuda_impl/cumsum_impl.cuh" + +namespace mindspore { +namespace kernel { +template +class CumSumGpuKernel : public GpuKernel { + public: + CumSumGpuKernel() : axis_(0), input_size_0_(0), stride_(0), stride2_(0) {} + ~CumSumGpuKernel() = default; + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override { + T *input_addr = GetDeviceAddress(inputs, 0); + T *output_addr = GetDeviceAddress(outputs, 0); + CumSum(input_addr, output_addr, dims_[0], dims_[1], dims_[2], stride_, stride2_, + reinterpret_cast(stream_ptr)); + return true; + } + bool Init(const CNodePtr &kernel_node) override { + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 1) { + MS_LOG(EXCEPTION) << "Argument number is " << input_num << ", but CumSumGpuKernel needs 1."; + return false; + } + input_size_0_ = sizeof(T); + shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + axis_ = GetAttr(kernel_node, "axis"); + int input_dim_length = SizeToInt(shape_.size()); + if (axis_ >= input_dim_length) { + MS_LOG(EXCEPTION) << "Axis out of bounds."; + } + while (axis_ < 0) { + axis_ += input_dim_length; + } + for (size_t i = 0; i < shape_.size(); i++) { + input_size_0_ *= shape_[i]; + } + Reshape(); + InitSizeLists(); + return true; + } + + protected: + void InitSizeLists() override { + input_size_list_.push_back(input_size_0_); + output_size_list_.push_back(input_size_0_); + } + + private: + void Reshape() { + dims_[0] = 1; + dims_[1] = shape_[IntToSize(axis_)]; + dims_[2] = 1; + for (size_t i = 0; i < IntToSize(axis_); i++) { + dims_[0] *= shape_[i]; + } + for (size_t i = IntToSize(axis_) + 1; i < shape_.size(); i++) { + dims_[2] *= shape_[i]; + } + stride_ = dims_[1] * dims_[2]; + stride2_ = dims_[2]; + return; + } + int axis_; + size_t input_size_0_; + size_t stride_; + size_t stride2_; + size_t dims_[3] = {}; + std::vector shape_; + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; +} // namespace kernel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUMSUM_GPU_KERNEL_H_ diff --git a/tests/st/ops/gpu/test_cumsum_op.py b/tests/st/ops/gpu/test_cumsum_op.py new file mode 100644 index 00000000000..c639c2952d8 --- /dev/null +++ b/tests/st/ops/gpu/test_cumsum_op.py @@ -0,0 +1,132 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.common.api import ms_function +from mindspore.ops import operations as P + +x0 = np.random.rand(2, 3, 4, 4).astype(np.float32) +axis0 = 3 + +x1 = np.random.rand(2, 3, 4, 4).astype(np.float32) +axis1 = 3 + +x2 = np.random.rand(2, 3, 1, 4).astype(np.float32) +axis2 = 2 + +x3 = np.random.rand(2, 3, 1, 4).astype(np.float32) +axis3 = 2 + +x4 = np.random.rand(2, 3, 4, 4).astype(np.float32) +axis4 = 1 + +x5 = np.random.rand(2, 3).astype(np.float32) +axis5 = 1 + +x6 = np.random.rand(1, 1, 1, 1).astype(np.float32) +axis6 = 0 + +context.set_context(device_target='GPU') + + +class CumSum(nn.Cell): + def __init__(self): + super(CumSum, self).__init__() + + self.x0 = Tensor(x0) + self.axis0 = axis0 + + self.x1 = Tensor(x1) + self.axis1 = axis1 + + self.x2 = Tensor(x2) + self.axis2 = axis2 + + self.x3 = Tensor(x3) + self.axis3 = axis3 + + self.x4 = Tensor(x4) + self.axis4 = axis4 + + self.x5 = Tensor(x5) + self.axis5 = axis5 + + self.x6 = Tensor(x6) + self.axis6 = axis6 + + @ms_function + def construct(self): + return (P.CumSum()(self.x0, self.axis0), + P.CumSum()(self.x1, self.axis1), + P.CumSum()(self.x2, self.axis2), + P.CumSum()(self.x3, self.axis3), + P.CumSum()(self.x4, self.axis4), + P.CumSum()(self.x5, self.axis5), + P.CumSum()(self.x6, self.axis6)) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_CumSum(): + cumsum = CumSum() + output = cumsum() + + expect0 = np.cumsum(x0, axis=axis0) + diff0 = abs(output[0].asnumpy() - expect0) + error0 = np.ones(shape=expect0.shape) * 1.0e-5 + assert np.all(diff0 < error0) + assert output[0].shape == expect0.shape + + expect1 = np.cumsum(x1, axis=axis1) + diff1 = abs(output[1].asnumpy() - expect1) + error1 = np.ones(shape=expect1.shape) * 1.0e-5 + assert np.all(diff1 < error1) + assert output[1].shape == expect1.shape + + expect2 = np.cumsum(x2, axis=axis2) + diff2 = abs(output[2].asnumpy() - expect2) + error2 = np.ones(shape=expect2.shape) * 1.0e-5 + assert np.all(diff2 < error2) + assert output[2].shape == expect2.shape + + expect3 = np.cumsum(x3, axis=axis3) + diff3 = abs(output[3].asnumpy() - expect3) + error3 = np.ones(shape=expect3.shape) * 1.0e-5 + assert np.all(diff3 < error3) + assert output[3].shape == expect3.shape + + expect4 = np.cumsum(x4, axis=axis4) + diff4 = abs(output[4].asnumpy() - expect4) + error4 = np.ones(shape=expect4.shape) * 1.0e-5 + assert np.all(diff4 < error4) + assert output[4].shape == expect4.shape + + expect5 = np.cumsum(x5, axis=axis5) + diff5 = abs(output[5].asnumpy() - expect5) + error5 = np.ones(shape=expect5.shape) * 1.0e-5 + assert np.all(diff5 < error5) + assert output[5].shape == expect5.shape + + expect6 = np.cumsum(x6, axis=axis6) + diff6 = abs(output[6].asnumpy() - expect6) + error6 = np.ones(shape=expect6.shape) * 1.0e-5 + assert np.all(diff6 < error6) + assert output[6].shape == expect6.shape diff --git a/tests/st/ops/gpu/test_reduce_min_op.py b/tests/st/ops/gpu/test_reduce_min_op.py new file mode 100644 index 00000000000..c502c549ae1 --- /dev/null +++ b/tests/st/ops/gpu/test_reduce_min_op.py @@ -0,0 +1,177 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.common.api import ms_function +from mindspore.ops import operations as P + +x0 = np.random.rand(2, 3, 4, 4).astype(np.float32) +axis0 = 3 +keep_dims0 = True + +x1 = np.random.rand(2, 3, 4, 4).astype(np.float32) +axis1 = 3 +keep_dims1 = False + +x2 = np.random.rand(2, 3, 1, 4).astype(np.float32) +axis2 = 2 +keep_dims2 = True + +x3 = np.random.rand(2, 3, 1, 4).astype(np.float32) +axis3 = 2 +keep_dims3 = False + +x4 = np.random.rand(2, 3, 4, 4).astype(np.float32) +axis4 = () +np_axis4 = None +keep_dims4 = True + +x5 = np.random.rand(2, 3, 4, 4).astype(np.float32) +axis5 = () +np_axis5 = None +keep_dims5 = False + +x6 = np.random.rand(2, 3, 4, 4).astype(np.float32) +axis6 = -2 +keep_dims6 = False + +x7 = np.random.rand(2, 3, 4, 4).astype(np.float32) +axis7 = (-2, -1) +keep_dims7 = True + +x8 = np.random.rand(1, 1, 1, 1).astype(np.float32) +axis8 = () +np_axis8 = None +keep_dims8 = True + +context.set_context(device_target='GPU') + + +class ReduceMin(nn.Cell): + def __init__(self): + super(ReduceMin, self).__init__() + + self.x0 = Tensor(x0) + self.axis0 = axis0 + self.keep_dims0 = keep_dims0 + + self.x1 = Tensor(x1) + self.axis1 = axis1 + self.keep_dims1 = keep_dims1 + + self.x2 = Tensor(x2) + self.axis2 = axis2 + self.keep_dims2 = keep_dims2 + + self.x3 = Tensor(x3) + self.axis3 = axis3 + self.keep_dims3 = keep_dims3 + + self.x4 = Tensor(x4) + self.axis4 = axis4 + self.keep_dims4 = keep_dims4 + + self.x5 = Tensor(x5) + self.axis5 = axis5 + self.keep_dims5 = keep_dims5 + + self.x6 = Tensor(x6) + self.axis6 = axis6 + self.keep_dims6 = keep_dims6 + + self.x7 = Tensor(x7) + self.axis7 = axis7 + self.keep_dims7 = keep_dims7 + + self.x8 = Tensor(x8) + self.axis8 = axis8 + self.keep_dims8 = keep_dims8 + + @ms_function + def construct(self): + return (P.ReduceMin(self.keep_dims0)(self.x0, self.axis0), + P.ReduceMin(self.keep_dims1)(self.x1, self.axis1), + P.ReduceMin(self.keep_dims2)(self.x2, self.axis2), + P.ReduceMin(self.keep_dims3)(self.x3, self.axis3), + P.ReduceMin(self.keep_dims4)(self.x4, self.axis4), + P.ReduceMin(self.keep_dims5)(self.x5, self.axis5), + P.ReduceMin(self.keep_dims6)(self.x6, self.axis6), + P.ReduceMin(self.keep_dims7)(self.x7, self.axis7), + P.ReduceMin(self.keep_dims8)(self.x8, self.axis8)) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_ReduceMin(): + reduce_min = ReduceMin() + output = reduce_min() + + expect0 = np.min(x0, axis=axis0, keepdims=keep_dims0) + diff0 = abs(output[0].asnumpy() - expect0) + error0 = np.ones(shape=expect0.shape) * 1.0e-5 + assert np.all(diff0 < error0) + assert output[0].shape == expect0.shape + + expect1 = np.min(x1, axis=axis1, keepdims=keep_dims1) + diff1 = abs(output[1].asnumpy() - expect1) + error1 = np.ones(shape=expect1.shape) * 1.0e-5 + assert np.all(diff1 < error1) + assert output[1].shape == expect1.shape + + expect2 = np.min(x2, axis=axis2, keepdims=keep_dims2) + diff2 = abs(output[2].asnumpy() - expect2) + error2 = np.ones(shape=expect2.shape) * 1.0e-5 + assert np.all(diff2 < error2) + assert output[2].shape == expect2.shape + + expect3 = np.min(x3, axis=axis3, keepdims=keep_dims3) + diff3 = abs(output[3].asnumpy() - expect3) + error3 = np.ones(shape=expect3.shape) * 1.0e-5 + assert np.all(diff3 < error3) + assert output[3].shape == expect3.shape + + expect4 = np.min(x4, axis=np_axis4, keepdims=keep_dims4) + diff4 = abs(output[4].asnumpy() - expect4) + error4 = np.ones(shape=expect4.shape) * 1.0e-5 + assert np.all(diff4 < error4) + assert output[4].shape == expect4.shape + + expect5 = np.min(x5, axis=np_axis5, keepdims=keep_dims5) + diff5 = abs(output[5].asnumpy() - expect5) + error5 = np.ones(shape=expect5.shape) * 1.0e-5 + assert np.all(diff5 < error5) + assert output[5].shape == expect5.shape + + expect6 = np.min(x6, axis=axis6, keepdims=keep_dims6) + diff6 = abs(output[6].asnumpy() - expect6) + error6 = np.ones(shape=expect6.shape) * 1.0e-5 + assert np.all(diff6 < error6) + assert output[6].shape == expect6.shape + + expect7 = np.min(x7, axis=axis7, keepdims=keep_dims7) + diff7 = abs(output[7].asnumpy() - expect7) + error7 = np.ones(shape=expect7.shape) * 1.0e-5 + assert np.all(diff7 < error7) + + expect8 = np.min(x8, axis=np_axis8, keepdims=keep_dims8) + diff8 = abs(output[8].asnumpy() - expect8) + error8 = np.ones(shape=expect8.shape) * 1.0e-5 + assert np.all(diff8 < error8)