From 652ab6c386b8afb54e542f6ba1d9865e5fb955d9 Mon Sep 17 00:00:00 2001 From: chenzomi Date: Mon, 13 Apr 2020 15:39:51 +0800 Subject: [PATCH] add test case for aware quantization --- .../gpu/quant/batchnorm_fold2_gpu_kernel.h | 14 +-- .../quant/batchnorm_fold2_grad_gpu_kernel.h | 14 +-- .../gpu/quant/batchnorm_fold_gpu_kernel.h | 14 +-- .../quant/batchnorm_fold_grad_gpu_kernel.h | 14 ++- .../gpu/quant/correction_mul_gpu_kernel.h | 14 +-- .../quant/correction_mul_grad_gpu_kernel.h | 16 +-- mindspore/nn/layer/activation.py | 2 +- mindspore/ops/operations/nn_ops.py | 2 +- tests/st/ops/gpu/test_batchnorm_fold2_op.py | 89 ++++++++++++++ .../st/ops/gpu/test_batchnorm_fold_grad_op.py | 96 +++++++++++++++ tests/st/ops/gpu/test_batchnorm_fold_op.py | 116 ++++++++++++++++++ tests/st/ops/gpu/test_conv2d_op.py | 2 +- .../st/ops/gpu/test_correction_mul_grad_op.py | 55 +++++++++ tests/st/ops/gpu/test_correction_mul_op.py | 52 ++++++++ 14 files changed, 456 insertions(+), 44 deletions(-) create mode 100644 tests/st/ops/gpu/test_batchnorm_fold2_op.py create mode 100644 tests/st/ops/gpu/test_batchnorm_fold_grad_op.py create mode 100644 tests/st/ops/gpu/test_batchnorm_fold_op.py create mode 100644 tests/st/ops/gpu/test_correction_mul_grad_op.py create mode 100644 tests/st/ops/gpu/test_correction_mul_op.py diff --git a/mindspore/ccsrc/kernel/gpu/quant/batchnorm_fold2_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/quant/batchnorm_fold2_gpu_kernel.h index ada4eabd862..3e246f18f64 100644 --- a/mindspore/ccsrc/kernel/gpu/quant/batchnorm_fold2_gpu_kernel.h +++ b/mindspore/ccsrc/kernel/gpu/quant/batchnorm_fold2_gpu_kernel.h @@ -38,14 +38,14 @@ class BatchNormFold2GpuKernel : public GpuKernel { ~BatchNormFold2GpuKernel() override { DestroyResource(); } - const std::vector &GetInputSizeList() const { return input_size_list_; } + const std::vector &GetInputSizeList() const override { return input_size_list_; } - const std::vector &GetOutputSizeList() const { return output_size_list_; } + const 
std::vector &GetOutputSizeList() const override { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const { return workspace_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, uintptr_t stream_ptr) { + const std::vector &outputs, uintptr_t stream_ptr) override { if (is_null_input_) { return true; } @@ -66,7 +66,7 @@ class BatchNormFold2GpuKernel : public GpuKernel { return true; } - bool Init(const CNodePtr &kernel_node) { + bool Init(const CNodePtr &kernel_node) override { InitResource(); size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); @@ -98,9 +98,9 @@ class BatchNormFold2GpuKernel : public GpuKernel { } protected: - void InitResource() { cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); } + void InitResource() override { cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); } - void InitSizeLists() { + void InitSizeLists() override { size_t input_size = batch_size_ * channel_ * height_ * width_ * sizeof(T); size_t weight_size = channel_ * sizeof(T); input_size_list_.push_back(input_size); diff --git a/mindspore/ccsrc/kernel/gpu/quant/batchnorm_fold2_grad_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/quant/batchnorm_fold2_grad_gpu_kernel.h index ef9611f2586..099960e7fae 100644 --- a/mindspore/ccsrc/kernel/gpu/quant/batchnorm_fold2_grad_gpu_kernel.h +++ b/mindspore/ccsrc/kernel/gpu/quant/batchnorm_fold2_grad_gpu_kernel.h @@ -38,14 +38,14 @@ class BatchNormFold2GradGpuKernel : public GpuKernel { ~BatchNormFold2GradGpuKernel() override { DestroyResource(); } - const std::vector &GetInputSizeList() const { return input_size_list_; } + const std::vector &GetInputSizeList() const override { return input_size_list_; } - const std::vector &GetOutputSizeList() const { return output_size_list_; } + const std::vector &GetOutputSizeList() 
const override { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const { return workspace_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, uintptr_t stream_ptr) { + const std::vector &outputs, uintptr_t stream_ptr) override { if (is_null_input_) { return true; } @@ -88,7 +88,7 @@ class BatchNormFold2GradGpuKernel : public GpuKernel { return true; } - bool Init(const CNodePtr &kernel_node) { + bool Init(const CNodePtr &kernel_node) override { InitResource(); size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); @@ -120,9 +120,9 @@ class BatchNormFold2GradGpuKernel : public GpuKernel { } protected: - void InitResource() { cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); } + void InitResource() override { cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); } - void InitSizeLists() { + void InitSizeLists() override { size_t input_size = batch_size_ * channel_ * height_ * width_ * sizeof(T); size_t weight_size = channel_ * sizeof(T); size_t workspace_size = batch_size_ * channel_ * sizeof(T); diff --git a/mindspore/ccsrc/kernel/gpu/quant/batchnorm_fold_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/quant/batchnorm_fold_gpu_kernel.h index e90fac27920..3e8c1ca52b6 100644 --- a/mindspore/ccsrc/kernel/gpu/quant/batchnorm_fold_gpu_kernel.h +++ b/mindspore/ccsrc/kernel/gpu/quant/batchnorm_fold_gpu_kernel.h @@ -46,14 +46,14 @@ class BatchNormFoldGpuKernel : public GpuKernel { ~BatchNormFoldGpuKernel() override { DestroyResource(); } - const std::vector &GetInputSizeList() const { return input_size_list_; } + const std::vector &GetInputSizeList() const override { return input_size_list_; } - const std::vector &GetOutputSizeList() const { return output_size_list_; } + const std::vector &GetOutputSizeList() const override { return 
output_size_list_; } - const std::vector &GetWorkspaceSizeList() const { return workspace_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, uintptr_t stream_ptr) { + const std::vector &outputs, uintptr_t stream_ptr) override { (void)workspace; auto x = reinterpret_cast(inputs[0]->addr); auto mean = reinterpret_cast(inputs[1]->addr); @@ -104,7 +104,7 @@ class BatchNormFoldGpuKernel : public GpuKernel { return true; } - bool Init(const CNodePtr &kernel_node) { + bool Init(const CNodePtr &kernel_node) override { InitResource(); size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); if (input_num != 4) { @@ -152,7 +152,7 @@ class BatchNormFoldGpuKernel : public GpuKernel { } protected: - void InitSizeLists() { + void InitSizeLists() override { // x, mean, variance, current_step input_size_list_.push_back(input_size_); input_size_list_.push_back(output_size_); @@ -169,7 +169,7 @@ class BatchNormFoldGpuKernel : public GpuKernel { workspace_size_list_.push_back(input_size_); } - void InitResource() { + void InitResource() override { handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&x_desc_), "Create x desc failed"); CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&scale_bias_mean_var_desc_), "Create para desc failed"); diff --git a/mindspore/ccsrc/kernel/gpu/quant/batchnorm_fold_grad_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/quant/batchnorm_fold_grad_gpu_kernel.h index 830f6dc2438..ec845fbb9e9 100644 --- a/mindspore/ccsrc/kernel/gpu/quant/batchnorm_fold_grad_gpu_kernel.h +++ b/mindspore/ccsrc/kernel/gpu/quant/batchnorm_fold_grad_gpu_kernel.h @@ -42,11 +42,12 @@ class BatchNormFoldGradGpuKernel : public GpuKernel { width_(0) {} ~BatchNormFoldGradGpuKernel() = default; - const std::vector &GetInputSizeList() const { return 
input_size_list_; } - const std::vector &GetOutputSizeList() const { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const { return workspace_size_list_; } + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, uintptr_t stream_ptr) { + const std::vector &outputs, uintptr_t stream_ptr) override { (void)workspace; // 'd_batch_mean', 'd_batch_std', 'x', 'batch_mean', 'batch_std', 'current_step' T *d_batch_mean = GetDeviceAddress(inputs, 0); @@ -92,7 +93,8 @@ class BatchNormFoldGradGpuKernel : public GpuKernel { reinterpret_cast(stream_ptr)); return true; } - bool Init(const CNodePtr &kernel_node) { + + bool Init(const CNodePtr &kernel_node) override { size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); if (input_num != 6) { MS_LOG(ERROR) << "Input number is " << input_num << ", but BatchNormFoldGrad GpuKernel OP needs 6 input."; @@ -128,7 +130,7 @@ class BatchNormFoldGradGpuKernel : public GpuKernel { } protected: - void InitSizeLists() { + void InitSizeLists() override { // 'd_batch_mean', 'd_batch_std', 'x', 'batch_mean', 'batch_std', 'current_step' input_size_list_.push_back(channel_size_); input_size_list_.push_back(channel_size_); diff --git a/mindspore/ccsrc/kernel/gpu/quant/correction_mul_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/quant/correction_mul_gpu_kernel.h index af23d7732a5..7608ae5d3c5 100644 --- a/mindspore/ccsrc/kernel/gpu/quant/correction_mul_gpu_kernel.h +++ b/mindspore/ccsrc/kernel/gpu/quant/correction_mul_gpu_kernel.h @@ -30,11 +30,11 @@ class CorrectionMulGpuKernel : public GpuKernel { CorrectionMulGpuKernel() : batch_size_(0), channel_(0), height_(0), width_(0) {} ~CorrectionMulGpuKernel() override { 
DestroyResource(); } - const std::vector &GetInputSizeList() const { return input_size_list_; } - const std::vector &GetOutputSizeList() const { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const { return workspace_size_list_; } + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, uintptr_t stream_ptr) { + const std::vector &outputs, uintptr_t stream_ptr) override { auto *weight = GetDeviceAddress(inputs, 0); auto *gamma = GetDeviceAddress(inputs, 1); auto *running_std = GetDeviceAddress(inputs, 2); @@ -44,7 +44,7 @@ class CorrectionMulGpuKernel : public GpuKernel { reinterpret_cast(stream_ptr)); return true; } - bool Init(const CNodePtr &kernel_node) { + bool Init(const CNodePtr &kernel_node) override { InitResource(); size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); @@ -69,7 +69,7 @@ class CorrectionMulGpuKernel : public GpuKernel { } protected: - void InitSizeLists() { + void InitSizeLists() override { size_t input_size = batch_size_ * channel_ * height_ * width_ * sizeof(T); size_t weight_size = batch_size_ * sizeof(T); input_size_list_.push_back(input_size); // weight @@ -79,7 +79,7 @@ class CorrectionMulGpuKernel : public GpuKernel { output_size_list_.push_back(input_size); workspace_size_list_.push_back(workspace_size); } - void InitResource() {} + void InitResource() override {} private: void DestroyResource() noexcept {} diff --git a/mindspore/ccsrc/kernel/gpu/quant/correction_mul_grad_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/quant/correction_mul_grad_gpu_kernel.h index f20c6278c05..2439826cc39 100644 --- a/mindspore/ccsrc/kernel/gpu/quant/correction_mul_grad_gpu_kernel.h +++ 
b/mindspore/ccsrc/kernel/gpu/quant/correction_mul_grad_gpu_kernel.h @@ -30,11 +30,12 @@ class CorrectionMulGradGpuKernel : public GpuKernel { CorrectionMulGradGpuKernel() : batch_size_(0), channel_(0), height_(0), width_(0) {} ~CorrectionMulGradGpuKernel() override { DestroyResource(); } - const std::vector &GetInputSizeList() const { return input_size_list_; } - const std::vector &GetOutputSizeList() const { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const { return workspace_size_list_; } + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, uintptr_t stream_ptr) { + const std::vector &outputs, uintptr_t stream_ptr) override { auto *d_out = GetDeviceAddress(inputs, 0); auto *weight = GetDeviceAddress(inputs, 1); auto *gamma = GetDeviceAddress(inputs, 2); @@ -49,7 +50,8 @@ class CorrectionMulGradGpuKernel : public GpuKernel { reinterpret_cast(stream_ptr)); return true; } - bool Init(const CNodePtr &kernel_node) { + + bool Init(const CNodePtr &kernel_node) override { InitResource(); size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); @@ -74,7 +76,7 @@ class CorrectionMulGradGpuKernel : public GpuKernel { } protected: - void InitSizeLists() { + void InitSizeLists() override { size_t input_size = batch_size_ * channel_ * height_ * width_ * sizeof(T); size_t weight_size = batch_size_ * sizeof(T); input_size_list_.push_back(input_size); // d_out @@ -85,7 +87,7 @@ class CorrectionMulGradGpuKernel : public GpuKernel { output_size_list_.push_back(weight_size); // d_gamma workspace_size_list_.push_back(input_size); // tmp d_out * weight } - void InitResource() {} + void InitResource() override {} private: void DestroyResource() 
noexcept {} diff --git a/mindspore/nn/layer/activation.py b/mindspore/nn/layer/activation.py index 12d6c74dcd4..6485e27228d 100644 --- a/mindspore/nn/layer/activation.py +++ b/mindspore/nn/layer/activation.py @@ -369,7 +369,7 @@ class HSigmoid(Cell): Hard sigmoid is defined as: .. math:: - \text{hsigmoid}(x_{i}) = max(0, min(1, \ftac{2 * x_{i} + 5}{10})), + \text{hsigmoid}(x_{i}) = max(0, min(1, \frac{2 * x_{i} + 5}{10})), where :math:`x_{i}` is the :math:`i`-th slice along the given dim of the input Tensor. diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index e82f59a05d0..0dab6b28a66 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -319,7 +319,7 @@ class HSigmoid(PrimitiveWithInfer): Hard sigmoid is defined as: .. math:: - \text{hsigmoid}(x_{i}) = max(0, min(1, \ftac{2 * x_{i} + 5}{10})), + \text{hsigmoid}(x_{i}) = max(0, min(1, \frac{2 * x_{i} + 5}{10})), where :math:`x_{i}` is the :math:`i`-th slice along the given dim of the input Tensor. diff --git a/tests/st/ops/gpu/test_batchnorm_fold2_op.py b/tests/st/ops/gpu/test_batchnorm_fold2_op.py new file mode 100644 index 00000000000..0440e92a8dc --- /dev/null +++ b/tests/st/ops/gpu/test_batchnorm_fold2_op.py @@ -0,0 +1,89 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ + +import numpy as np +import pytest +from mindspore import Tensor +from mindspore.ops import operations as P +import mindspore.nn as nn +from mindspore.common.api import ms_function +import mindspore.context as context + +context.set_context(device_target='GPU') + + +class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.op = P.BatchNormFold2(100000) + + @ms_function + def construct(self, x, beta, gamma, batch_std, batch_mean, running_std, running_mean, current_step): + return self.op(x, beta, gamma, batch_std, batch_mean, running_std, running_mean, current_step) + + +class Net_gnd(nn.Cell): + def __init__(self): + super(Net_gnd, self).__init__() + self.conv_mul = P.ConvMul(freeze_bn=100000) + self.correct_add = P.CorrectionAdd(freeze_bn=100000) + self.add_fold = P.AddFold() + + @ms_function + def construct(self, x, beta, gamma, batch_std, batch_mean, running_std, running_mean, current_step): + out = self.conv_mul(x, batch_std, running_std, current_step) + out = self.correct_add(out, gamma, batch_std, batch_mean, + running_std, running_mean, current_step) + out = self.add_fold(out, beta, gamma, batch_std, batch_mean) + return out + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_batchnrom_fold2(): + net = Net() + c = 64 + freeze_bn = 100000 + x = np.random.uniform(-1, 1, size=[3, c, 32, 32]).astype('float32') + beta = np.random.uniform(1, 2, size=[c]).astype('float32') + gamma = np.random.uniform(1, 2, size=[c]).astype('float32') + batch_std = np.random.uniform(1, 2, size=[c]).astype('float32') + batch_mean = np.random.uniform(1, 2, size=[c]).astype('float32') + running_std = np.random.uniform(1, 2, size=[c]).astype('float32') + running_mean = np.random.uniform(1, 2, size=[c]).astype('float32') + current_step = np.array([0]).astype('int32') + output = net(Tensor(x), Tensor(beta), Tensor(gamma), 
Tensor(batch_std), Tensor(batch_mean), + Tensor(running_std), Tensor(running_mean), Tensor(current_step)) + expect = (x + beta.reshape(-1, 1, 1) - (gamma * running_mean / running_std).reshape(-1, 1, + 1) if current_step >= freeze_bn else + x * (running_std / batch_std).reshape(-1, 1, 1) + (beta - gamma * batch_mean / batch_std).reshape(-1, 1, + 1)) + error = np.ones(shape=expect.shape) * 1.0e-6 + diff = output.asnumpy() - expect + assert np.all(diff < error) + assert np.all(diff > error * -1) + + current_step = np.array([100000]).astype('int32') + output = net(Tensor(x), Tensor(beta), Tensor(gamma), Tensor(batch_std), Tensor(batch_mean), Tensor(running_std), + Tensor(running_mean), Tensor(current_step)) + expect = (x + beta.reshape(-1, 1, 1) - (gamma * running_mean / running_std).reshape(-1, 1, + 1) if current_step >= freeze_bn else + x * (batch_std / running_std).reshape(-1, 1, 1) + (beta - gamma * batch_mean / batch_std).reshape(-1, 1, + 1)) + error = np.ones(shape=expect.shape) * 1.0e-6 + diff = output.asnumpy() - expect + assert np.all(diff < error) + assert np.all(diff > error * -1) diff --git a/tests/st/ops/gpu/test_batchnorm_fold_grad_op.py b/tests/st/ops/gpu/test_batchnorm_fold_grad_op.py new file mode 100644 index 00000000000..8e55f6a473f --- /dev/null +++ b/tests/st/ops/gpu/test_batchnorm_fold_grad_op.py @@ -0,0 +1,96 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ + +import numpy as np +import pytest +from mindspore import Tensor +from mindspore.ops import operations as P +import mindspore.nn as nn +from mindspore.common.api import ms_function +import mindspore.context as context + +context.set_context(device_target='GPU') + + +class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.op = P.BatchNormFoldGrad(freeze_bn=10) + + @ms_function + def construct(self, d_batch_mean, d_batch_std, x, batch_mean, batch_std, current_step): + dx = self.op(d_batch_mean, d_batch_std, x, batch_mean, batch_std, current_step) + return dx + + +def np_result(d_batch_mean, d_batch_std, x, batch_mean, batch_std): + n = x.shape[0] * x.shape[2] * x.shape[3] + dx = d_batch_mean.reshape(1, -1, 1, 1) / n + d_batch_std.reshape(1, -1, 1, 1) * ( + x - batch_mean.reshape(1, -1, 1, 1)) / batch_std.reshape(1, -1, 1, 1) / n + return dx + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_batchnorm_fold_grad1(): + net = Net() + c = 64 + x = np.random.uniform(1, 10, size=[3, c, 32, 32]).astype('float32') + d_batch_mean = np.random.uniform(1, 10, size=[c]).astype('float32') + d_batch_std = np.random.uniform(1, 10, size=[c]).astype('float32') + batch_mean = np.random.uniform(1, 10, size=[c]).astype('float32') + batch_std = np.random.uniform(1, 10, size=[c]).astype('float32') + current_step = np.array([0]).astype('int32') + dx = net(Tensor(d_batch_mean), Tensor(d_batch_std), Tensor(x), Tensor(batch_mean), Tensor(batch_std), + Tensor(current_step)) + expect = np_result(d_batch_mean, d_batch_std, x, batch_mean, batch_std) + assert np.allclose(dx.asnumpy(), expect, rtol=1.e-7, atol=1.e-7) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_batchnorm_fold_grad2(): + net = Net() + c = 64 + x = np.random.uniform(1, 10, size=[1, c, 256, 256]).astype('float32') + d_batch_mean 
= np.random.uniform(1, 10, size=[c]).astype('float32') + d_batch_std = np.random.uniform(1, 10, size=[c]).astype('float32') + batch_mean = np.random.uniform(1, 10, size=[c]).astype('float32') + batch_std = np.random.uniform(1, 10, size=[c]).astype('float32') + current_step = np.array([0]).astype('int32') + dx = net(Tensor(d_batch_mean), Tensor(d_batch_std), Tensor(x), Tensor(batch_mean), Tensor(batch_std), + Tensor(current_step)) + expect = np_result(d_batch_mean, d_batch_std, x, batch_mean, batch_std) + assert np.allclose(dx.asnumpy(), expect, rtol=1.e-7, atol=1.e-7) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_batchnorm_fold_grad_freeze(): + net = Net() + c = 64 + x = np.random.uniform(1, 10, size=[3, c, 32, 32]).astype('float32') + d_batch_mean = np.random.uniform(1, 10, size=[c]).astype('float32') + d_batch_std = np.random.uniform(1, 10, size=[c]).astype('float32') + batch_mean = np.random.uniform(1, 10, size=[c]).astype('float32') + batch_std = np.random.uniform(1, 10, size=[c]).astype('float32') + current_step = np.array([10]).astype('int32') + dx = net(Tensor(d_batch_mean), Tensor(d_batch_std), Tensor(x), Tensor(batch_mean), Tensor(batch_std), + Tensor(current_step)) + expect = np.zeros_like(x) + assert np.allclose(dx.asnumpy(), expect, rtol=1.e-7, atol=1.e-7) diff --git a/tests/st/ops/gpu/test_batchnorm_fold_op.py b/tests/st/ops/gpu/test_batchnorm_fold_op.py new file mode 100644 index 00000000000..c4abf152a62 --- /dev/null +++ b/tests/st/ops/gpu/test_batchnorm_fold_op.py @@ -0,0 +1,116 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest +from mindspore import Tensor +from mindspore.ops import operations as P +import mindspore.nn as nn +from mindspore.common.api import ms_function +import mindspore.context as context + +context.set_context(device_target='GPU') + + +class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.op = P.BatchNormFold(freeze_bn=10) + + @ms_function + def construct(self, x, mean, variance, current_step): + a, b, c, d = self.op(x, mean, variance, current_step) + return a, b, c, d + + +def np_result(x, mean, var, momentum, epsilon): + np_mean = x.mean(axis=(0, 2, 3)) + np_var = x.var(axis=(0, 2, 3)) + n = x.shape[0] * x.shape[2] * x.shape[3] + mean_update = momentum * np_mean + (1 - momentum) * mean + var_update = momentum * np_var * n / (n - 1) + (1 - momentum) * var + np_var = np.sqrt(np_var + epsilon) + delay_mean = mean.copy() + delay_std = np.sqrt(var + epsilon) + return np_mean, np_var, mean_update, var_update, delay_mean, delay_std + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_batchnorm_fold(): + net = Net() + c = 64 + x = np.random.uniform(1, 10, size=[3, c, 32, 32]).astype('float32') + mean = np.random.uniform(1, 10, size=[c]).astype('float32') + variance = np.random.uniform(1, 10, size=[c]).astype('float32') + current_step = np.array([0]).astype('int32') + ms_mean = Tensor(mean) + ms_var = Tensor(variance) + batch_mean, batch_var, delay_mean, delay_std = net(Tensor(x), 
ms_mean, ms_var, + Tensor(current_step)) + + expect1, expect2, expect3, expect4, expect5, expect6 = np_result(x, mean, variance, 0.9, 1e-12) + assert np.allclose(batch_mean.asnumpy(), expect1, rtol=1.e-7, atol=1.e-5) + assert np.allclose(batch_var.asnumpy(), expect2, rtol=1.e-7, atol=1.e-5) + assert np.allclose(ms_mean.asnumpy(), expect3, rtol=1.e-7, atol=1.e-5) + assert np.allclose(ms_var.asnumpy(), expect4, rtol=1.e-7, atol=1.e-5) + assert np.allclose(delay_mean.asnumpy(), expect5, rtol=1.e-7, atol=1.e-5) + assert np.allclose(delay_std.asnumpy(), expect6, rtol=1.e-7, atol=1.e-5) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_batchnorm_fold2(): + net = Net() + c = 64 + x = np.random.uniform(1, 10, size=[3, c, 512, 512]).astype('float32') + mean = np.random.uniform(1, 10, size=[c]).astype('float32') + variance = np.random.uniform(1, 10, size=[c]).astype('float32') + current_step = np.array([0]).astype('int32') + ms_mean = Tensor(mean) + ms_var = Tensor(variance) + batch_mean, batch_var, delay_mean, delay_std = net(Tensor(x), ms_mean, ms_var, + Tensor(current_step)) + expect1, expect2, expect3, expect4, expect5, expect6 = np_result(x, mean, variance, 0.9, 1e-12) + assert np.allclose(batch_mean.asnumpy(), expect1, rtol=1.e-7, atol=1.e-5) + assert np.allclose(batch_var.asnumpy(), expect2, rtol=1.e-7, atol=1.e-5) + assert np.allclose(ms_mean.asnumpy(), expect3, rtol=1.e-7, atol=1.e-5) + assert np.allclose(delay_mean.asnumpy(), expect5, rtol=1.e-7, atol=1.e-5) + assert np.allclose(delay_std.asnumpy(), expect6, rtol=1.e-7, atol=1.e-5) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_batchnorm_fold_freeze(): + net = Net() + c = 64 + x = np.random.uniform(1, 10, size=[3, c, 32, 32]).astype('float32') + mean = np.random.uniform(1, 10, size=[c]).astype('float32') + variance = np.random.uniform(1, 10, size=[c]).astype('float32') + current_step = 
np.array([10]).astype('int32') + ms_mean = Tensor(mean) + ms_var = Tensor(variance) + batch_mean, batch_var, delay_mean, delay_std = net(Tensor(x), ms_mean, ms_var, + Tensor(current_step)) + expect1, expect2, expect3, expect4, expect5, expect6 = np_result(x, mean, variance, 0.9, 1e-12) + assert np.allclose(batch_mean.asnumpy(), np.zeros_like(mean), rtol=1.e-7, atol=1.e-5) + assert np.allclose(batch_var.asnumpy(), np.ones_like(mean), rtol=1.e-7, atol=1.e-5) + assert np.allclose(ms_mean.asnumpy(), mean, rtol=1.e-7, atol=1.e-5) + assert np.allclose(ms_var.asnumpy(), variance, rtol=1.e-7, atol=1.e-5) + assert np.allclose(delay_mean.asnumpy(), expect5, rtol=1.e-7, atol=1.e-5) + assert np.allclose(delay_std.asnumpy(), expect6, rtol=1.e-7, atol=1.e-5) diff --git a/tests/st/ops/gpu/test_conv2d_op.py b/tests/st/ops/gpu/test_conv2d_op.py index d724f6f6c84..1bac156c373 100644 --- a/tests/st/ops/gpu/test_conv2d_op.py +++ b/tests/st/ops/gpu/test_conv2d_op.py @@ -14,10 +14,10 @@ # ============================================================================ import pytest +import numpy as np from mindspore import Tensor from mindspore.ops import operations as P import mindspore.nn as nn -import numpy as np import mindspore.context as context diff --git a/tests/st/ops/gpu/test_correction_mul_grad_op.py b/tests/st/ops/gpu/test_correction_mul_grad_op.py new file mode 100644 index 00000000000..88b391a77a6 --- /dev/null +++ b/tests/st/ops/gpu/test_correction_mul_grad_op.py @@ -0,0 +1,55 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest +import os +from mindspore import Tensor +from mindspore.ops import operations as P +import mindspore.nn as nn +from mindspore.common.api import ms_function +import mindspore.context as context + + +context.set_context(device_target='GPU') + + +class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.op_w = P.CorrectionMulGrad() + + @ms_function + def construct(self, dy, x, batch_std, running_std): + dx, d_batch_std = self.op_w(dy, x, batch_std, running_std) + return dx, d_batch_std + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_correction_mul_grad(): + net = Net() + co, ci, h, w = 64, 1, 32, 32 + dout = np.random.uniform(-0.1, 0.1, size=[co, ci, h, w]).astype('float32') + x = np.random.uniform(1, 1, size=[co, ci, h, w]).astype('float32') + batch_std = np.random.uniform(1, 10, size=[co]).astype('float32') + running_std = np.random.uniform(1, 10, size=[co]).astype('float32') + output = net(Tensor(dout), Tensor(x), Tensor(batch_std), Tensor(running_std)) + expect = [0, 0] + expect[0] = (dout * np.reshape(batch_std / running_std, (co, 1, 1, 1))) + expect[1] = (np.sum(dout * x, (1, 2, 3)) / running_std) + for i, v in enumerate(output): + assert (np.allclose(output[i].asnumpy(), expect[i], rtol=1.e-5, atol=1.e-5)) diff --git a/tests/st/ops/gpu/test_correction_mul_op.py b/tests/st/ops/gpu/test_correction_mul_op.py new file mode 100644 index 00000000000..01389e148cc --- /dev/null +++ b/tests/st/ops/gpu/test_correction_mul_op.py @@ -0,0 +1,52 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest +from mindspore import Tensor +from mindspore.ops import operations as P +import mindspore.nn as nn +from mindspore.common.api import ms_function +import mindspore.context as context + +context.set_context(device_target='GPU') + + +class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.op = P.CorrectionMul() + + @ms_function + def construct(self, x, batch_var, moving_var): + return self.op(x, batch_var, moving_var) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_correction_mul(): + net = Net() + co = 64 + x = np.random.uniform(-1, 1, size=[co, 64, 32, 32]).astype('float32') + bv = np.random.uniform(1, 2, size=[co]).astype('float32') + mv = np.random.uniform(1, 2, size=[co]).astype('float32') + output = net(Tensor(x), Tensor(bv), Tensor(mv)) + expect = x * np.reshape(bv, (co, 1, 1, 1)) / np.reshape(mv, (co, 1, 1, 1)) + error = np.ones(shape=expect.shape) * 1.0e-5 + diff = output.asnumpy() - expect + assert np.all(diff < error) + assert np.all(diff > error * -1) + assert (output.shape() == expect.shape)