add test case for aware quantizaiton

This commit is contained in:
chenzomi 2020-04-13 15:39:51 +08:00
parent 60958d6b25
commit 652ab6c386
14 changed files with 456 additions and 44 deletions

View File

@ -38,14 +38,14 @@ class BatchNormFold2GpuKernel : public GpuKernel {
~BatchNormFold2GpuKernel() override { DestroyResource(); }
const std::vector<size_t> &GetInputSizeList() const { return input_size_list_; }
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const { return output_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const { return workspace_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, uintptr_t stream_ptr) {
const std::vector<AddressPtr> &outputs, uintptr_t stream_ptr) override {
if (is_null_input_) {
return true;
}
@ -66,7 +66,7 @@ class BatchNormFold2GpuKernel : public GpuKernel {
return true;
}
bool Init(const CNodePtr &kernel_node) {
bool Init(const CNodePtr &kernel_node) override {
InitResource();
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
@ -98,9 +98,9 @@ class BatchNormFold2GpuKernel : public GpuKernel {
}
protected:
void InitResource() { cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); }
void InitResource() override { cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); }
void InitSizeLists() {
void InitSizeLists() override {
size_t input_size = batch_size_ * channel_ * height_ * width_ * sizeof(T);
size_t weight_size = channel_ * sizeof(T);
input_size_list_.push_back(input_size);

View File

@ -38,14 +38,14 @@ class BatchNormFold2GradGpuKernel : public GpuKernel {
~BatchNormFold2GradGpuKernel() override { DestroyResource(); }
const std::vector<size_t> &GetInputSizeList() const { return input_size_list_; }
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const { return output_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const { return workspace_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, uintptr_t stream_ptr) {
const std::vector<AddressPtr> &outputs, uintptr_t stream_ptr) override {
if (is_null_input_) {
return true;
}
@ -88,7 +88,7 @@ class BatchNormFold2GradGpuKernel : public GpuKernel {
return true;
}
bool Init(const CNodePtr &kernel_node) {
bool Init(const CNodePtr &kernel_node) override {
InitResource();
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
@ -120,9 +120,9 @@ class BatchNormFold2GradGpuKernel : public GpuKernel {
}
protected:
void InitResource() { cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); }
void InitResource() override { cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); }
void InitSizeLists() {
void InitSizeLists() override {
size_t input_size = batch_size_ * channel_ * height_ * width_ * sizeof(T);
size_t weight_size = channel_ * sizeof(T);
size_t workspace_size = batch_size_ * channel_ * sizeof(T);

View File

@ -46,14 +46,14 @@ class BatchNormFoldGpuKernel : public GpuKernel {
~BatchNormFoldGpuKernel() override { DestroyResource(); }
const std::vector<size_t> &GetInputSizeList() const { return input_size_list_; }
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const { return output_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const { return workspace_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, uintptr_t stream_ptr) {
const std::vector<AddressPtr> &outputs, uintptr_t stream_ptr) override {
(void)workspace;
auto x = reinterpret_cast<T *>(inputs[0]->addr);
auto mean = reinterpret_cast<T *>(inputs[1]->addr);
@ -104,7 +104,7 @@ class BatchNormFoldGpuKernel : public GpuKernel {
return true;
}
bool Init(const CNodePtr &kernel_node) {
bool Init(const CNodePtr &kernel_node) override {
InitResource();
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 4) {
@ -152,7 +152,7 @@ class BatchNormFoldGpuKernel : public GpuKernel {
}
protected:
void InitSizeLists() {
void InitSizeLists() override {
// x, mean, variance, current_step
input_size_list_.push_back(input_size_);
input_size_list_.push_back(output_size_);
@ -169,7 +169,7 @@ class BatchNormFoldGpuKernel : public GpuKernel {
workspace_size_list_.push_back(input_size_);
}
void InitResource() {
void InitResource() override {
handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle();
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&x_desc_), "Create x desc failed");
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&scale_bias_mean_var_desc_), "Create para desc failed");

View File

@ -42,11 +42,12 @@ class BatchNormFoldGradGpuKernel : public GpuKernel {
width_(0) {}
~BatchNormFoldGradGpuKernel() = default;
const std::vector<size_t> &GetInputSizeList() const { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const { return output_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const { return workspace_size_list_; }
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, uintptr_t stream_ptr) {
const std::vector<AddressPtr> &outputs, uintptr_t stream_ptr) override {
(void)workspace;
// 'd_batch_mean', 'd_batch_std', 'x', 'batch_mean', 'batch_std', 'current_step'
T *d_batch_mean = GetDeviceAddress<T>(inputs, 0);
@ -92,7 +93,8 @@ class BatchNormFoldGradGpuKernel : public GpuKernel {
reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
bool Init(const CNodePtr &kernel_node) {
bool Init(const CNodePtr &kernel_node) override {
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 6) {
MS_LOG(ERROR) << "Input number is " << input_num << ", but BatchNormFoldGrad GpuKernel OP needs 6 input.";
@ -128,7 +130,7 @@ class BatchNormFoldGradGpuKernel : public GpuKernel {
}
protected:
void InitSizeLists() {
void InitSizeLists() override {
// 'd_batch_mean', 'd_batch_std', 'x', 'batch_mean', 'batch_std', 'current_step'
input_size_list_.push_back(channel_size_);
input_size_list_.push_back(channel_size_);

View File

@ -30,11 +30,11 @@ class CorrectionMulGpuKernel : public GpuKernel {
CorrectionMulGpuKernel() : batch_size_(0), channel_(0), height_(0), width_(0) {}
~CorrectionMulGpuKernel() override { DestroyResource(); }
const std::vector<size_t> &GetInputSizeList() const { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const { return output_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const { return workspace_size_list_; }
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, uintptr_t stream_ptr) {
const std::vector<AddressPtr> &outputs, uintptr_t stream_ptr) override {
auto *weight = GetDeviceAddress<T>(inputs, 0);
auto *gamma = GetDeviceAddress<T>(inputs, 1);
auto *running_std = GetDeviceAddress<T>(inputs, 2);
@ -44,7 +44,7 @@ class CorrectionMulGpuKernel : public GpuKernel {
reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
bool Init(const CNodePtr &kernel_node) {
bool Init(const CNodePtr &kernel_node) override {
InitResource();
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
@ -69,7 +69,7 @@ class CorrectionMulGpuKernel : public GpuKernel {
}
protected:
void InitSizeLists() {
void InitSizeLists() override {
size_t input_size = batch_size_ * channel_ * height_ * width_ * sizeof(T);
size_t weight_size = batch_size_ * sizeof(T);
input_size_list_.push_back(input_size); // weight
@ -79,7 +79,7 @@ class CorrectionMulGpuKernel : public GpuKernel {
output_size_list_.push_back(input_size);
workspace_size_list_.push_back(workspace_size);
}
void InitResource() {}
void InitResource() override {}
private:
void DestroyResource() noexcept {}

View File

@ -30,11 +30,12 @@ class CorrectionMulGradGpuKernel : public GpuKernel {
CorrectionMulGradGpuKernel() : batch_size_(0), channel_(0), height_(0), width_(0) {}
~CorrectionMulGradGpuKernel() override { DestroyResource(); }
const std::vector<size_t> &GetInputSizeList() const { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const { return output_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const { return workspace_size_list_; }
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, uintptr_t stream_ptr) {
const std::vector<AddressPtr> &outputs, uintptr_t stream_ptr) override {
auto *d_out = GetDeviceAddress<T>(inputs, 0);
auto *weight = GetDeviceAddress<T>(inputs, 1);
auto *gamma = GetDeviceAddress<T>(inputs, 2);
@ -49,7 +50,8 @@ class CorrectionMulGradGpuKernel : public GpuKernel {
reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
bool Init(const CNodePtr &kernel_node) {
bool Init(const CNodePtr &kernel_node) override {
InitResource();
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
@ -74,7 +76,7 @@ class CorrectionMulGradGpuKernel : public GpuKernel {
}
protected:
void InitSizeLists() {
void InitSizeLists() override {
size_t input_size = batch_size_ * channel_ * height_ * width_ * sizeof(T);
size_t weight_size = batch_size_ * sizeof(T);
input_size_list_.push_back(input_size); // d_out
@ -85,7 +87,7 @@ class CorrectionMulGradGpuKernel : public GpuKernel {
output_size_list_.push_back(weight_size); // d_gamma
workspace_size_list_.push_back(input_size); // tmp d_out * weight
}
void InitResource() {}
void InitResource() override {}
private:
void DestroyResource() noexcept {}

View File

@ -369,7 +369,7 @@ class HSigmoid(Cell):
Hard sigmoid is defined as:
.. math::
\text{hsigmoid}(x_{i}) = max(0, min(1, \ftac{2 * x_{i} + 5}{10})),
\text{hsigmoid}(x_{i}) = max(0, min(1, \frac{2 * x_{i} + 5}{10})),
where :math:`x_{i}` is the :math:`i`-th slice along the given dim of the input Tensor.

View File

@ -319,7 +319,7 @@ class HSigmoid(PrimitiveWithInfer):
Hard sigmoid is defined as:
.. math::
\text{hsigmoid}(x_{i}) = max(0, min(1, \ftac{2 * x_{i} + 5}{10})),
\text{hsigmoid}(x_{i}) = max(0, min(1, \frac{2 * x_{i} + 5}{10})),
where :math:`x_{i}` is the :math:`i`-th slice along the given dim of the input Tensor.

View File

@ -0,0 +1,89 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
from mindspore import Tensor
from mindspore.ops import operations as P
import mindspore.nn as nn
from mindspore.common.api import ms_function
import mindspore.context as context
context.set_context(device_target='GPU')
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.op = P.BatchNormFold2(100000)
@ms_function
def construct(self, x, beta, gamma, batch_std, batch_mean, running_std, running_mean, current_step):
return self.op(x, beta, gamma, batch_std, batch_mean, running_std, running_mean, current_step)
class Net_gnd(nn.Cell):
def __init__(self):
super(Net_gnd, self).__init__()
self.conv_mul = P.ConvMul(freeze_bn=100000)
self.correct_add = P.CorrectionAdd(freeze_bn=100000)
self.add_fold = P.AddFold()
@ms_function
def construct(self, x, beta, gamma, batch_std, batch_mean, running_std, running_mean, current_step):
out = self.conv_mul(x, batch_std, running_std, current_step)
out = self.correct_add(out, gamma, batch_std, batch_mean,
running_std, running_mean, current_step)
out = self.add_fold(out, beta, gamma, batch_std, batch_mean)
return out
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_batchnrom_fold2():
net = Net()
c = 64
freeze_bn = 100000
x = np.random.uniform(-1, 1, size=[3, c, 32, 32]).astype('float32')
beta = np.random.uniform(1, 2, size=[c]).astype('float32')
gamma = np.random.uniform(1, 2, size=[c]).astype('float32')
batch_std = np.random.uniform(1, 2, size=[c]).astype('float32')
batch_mean = np.random.uniform(1, 2, size=[c]).astype('float32')
running_std = np.random.uniform(1, 2, size=[c]).astype('float32')
running_mean = np.random.uniform(1, 2, size=[c]).astype('float32')
current_step = np.array([0]).astype('int32')
output = net(Tensor(x), Tensor(beta), Tensor(gamma), Tensor(batch_std), Tensor(batch_mean),
Tensor(running_std), Tensor(running_mean), Tensor(current_step))
expect = (x + beta.reshape(-1, 1, 1) - (gamma * running_mean / running_std).reshape(-1, 1,
1) if current_step >= freeze_bn else
x * (running_std / batch_std).reshape(-1, 1, 1) + (beta - gamma * batch_mean / batch_std).reshape(-1, 1,
1))
error = np.ones(shape=expect.shape) * 1.0e-6
diff = output.asnumpy() - expect
assert np.all(diff < error)
assert np.all(diff > error * -1)
current_step = np.array([100000]).astype('int32')
output = net(Tensor(x), Tensor(beta), Tensor(gamma), Tensor(batch_std), Tensor(batch_mean), Tensor(running_std),
Tensor(running_mean), Tensor(current_step))
expect = (x + beta.reshape(-1, 1, 1) - (gamma * running_mean / running_std).reshape(-1, 1,
1) if current_step >= freeze_bn else
x * (batch_std / running_std).reshape(-1, 1, 1) + (beta - gamma * batch_mean / batch_std).reshape(-1, 1,
1))
error = np.ones(shape=expect.shape) * 1.0e-6
diff = output.asnumpy() - expect
assert np.all(diff < error)
assert np.all(diff > error * -1)

View File

@ -0,0 +1,96 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
from mindspore import Tensor
from mindspore.ops import operations as P
import mindspore.nn as nn
from mindspore.common.api import ms_function
import mindspore.context as context
context.set_context(device_target='GPU')
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.op = P.BatchNormFoldGrad(freeze_bn=10)
@ms_function
def construct(self, d_batch_mean, d_batch_std, x, batch_mean, batch_std, current_step):
dx = self.op(d_batch_mean, d_batch_std, x, batch_mean, batch_std, current_step)
return dx
def np_result(d_batch_mean, d_batch_std, x, batch_mean, batch_std):
n = x.shape[0] * x.shape[2] * x.shape[3]
dx = d_batch_mean.reshape(1, -1, 1, 1) / n + d_batch_std.reshape(1, -1, 1, 1) * (
x - batch_mean.reshape(1, -1, 1, 1)) / batch_std.reshape(1, -1, 1, 1) / n
return dx
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_batchnorm_fold_grad1():
net = Net()
c = 64
x = np.random.uniform(1, 10, size=[3, c, 32, 32]).astype('float32')
d_batch_mean = np.random.uniform(1, 10, size=[c]).astype('float32')
d_batch_std = np.random.uniform(1, 10, size=[c]).astype('float32')
batch_mean = np.random.uniform(1, 10, size=[c]).astype('float32')
batch_std = np.random.uniform(1, 10, size=[c]).astype('float32')
current_step = np.array([0]).astype('int32')
dx = net(Tensor(d_batch_mean), Tensor(d_batch_std), Tensor(x), Tensor(batch_mean), Tensor(batch_std),
Tensor(current_step))
expect = np_result(d_batch_mean, d_batch_std, x, batch_mean, batch_std)
assert np.allclose(dx.asnumpy(), expect, rtol=1.e-7, atol=1.e-7)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_batchnorm_fold_grad2():
net = Net()
c = 64
x = np.random.uniform(1, 10, size=[1, c, 256, 256]).astype('float32')
d_batch_mean = np.random.uniform(1, 10, size=[c]).astype('float32')
d_batch_std = np.random.uniform(1, 10, size=[c]).astype('float32')
batch_mean = np.random.uniform(1, 10, size=[c]).astype('float32')
batch_std = np.random.uniform(1, 10, size=[c]).astype('float32')
current_step = np.array([0]).astype('int32')
dx = net(Tensor(d_batch_mean), Tensor(d_batch_std), Tensor(x), Tensor(batch_mean), Tensor(batch_std),
Tensor(current_step))
expect = np_result(d_batch_mean, d_batch_std, x, batch_mean, batch_std)
assert np.allclose(dx.asnumpy(), expect, rtol=1.e-7, atol=1.e-7)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_batchnorm_fold_grad_freeze():
net = Net()
c = 64
x = np.random.uniform(1, 10, size=[3, c, 32, 32]).astype('float32')
d_batch_mean = np.random.uniform(1, 10, size=[c]).astype('float32')
d_batch_std = np.random.uniform(1, 10, size=[c]).astype('float32')
batch_mean = np.random.uniform(1, 10, size=[c]).astype('float32')
batch_std = np.random.uniform(1, 10, size=[c]).astype('float32')
current_step = np.array([10]).astype('int32')
dx = net(Tensor(d_batch_mean), Tensor(d_batch_std), Tensor(x), Tensor(batch_mean), Tensor(batch_std),
Tensor(current_step))
expect = np.zeros_like(x)
assert np.allclose(dx.asnumpy(), expect, rtol=1.e-7, atol=1.e-7)

View File

@ -0,0 +1,116 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
from mindspore import Tensor
from mindspore.ops import operations as P
import mindspore.nn as nn
from mindspore.common.api import ms_function
import mindspore.context as context
context.set_context(device_target='GPU')
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.op = P.BatchNormFold(freeze_bn=10)
@ms_function
def construct(self, x, mean, variance, current_step):
a, b, c, d = self.op(x, mean, variance, current_step)
return a, b, c, d
def np_result(x, mean, var, momentum, epsilon):
np_mean = x.mean(axis=(0, 2, 3))
np_var = x.var(axis=(0, 2, 3))
n = x.shape[0] * x.shape[2] * x.shape[3]
mean_update = momentum * np_mean + (1 - momentum) * mean
var_update = momentum * np_var * n / (n - 1) + (1 - momentum) * var
np_var = np.sqrt(np_var + epsilon)
delay_mean = mean.copy()
delay_std = np.sqrt(var + epsilon)
return np_mean, np_var, mean_update, var_update, delay_mean, delay_std
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_batchnorm_fold():
net = Net()
c = 64
x = np.random.uniform(1, 10, size=[3, c, 32, 32]).astype('float32')
mean = np.random.uniform(1, 10, size=[c]).astype('float32')
variance = np.random.uniform(1, 10, size=[c]).astype('float32')
current_step = np.array([0]).astype('int32')
ms_mean = Tensor(mean)
ms_var = Tensor(variance)
batch_mean, batch_var, delay_mean, delay_std = net(Tensor(x), ms_mean, ms_var,
Tensor(current_step))
expect1, expect2, expect3, expect4, expect5, expect6 = np_result(x, mean, variance, 0.9, 1e-12)
assert np.allclose(batch_mean.asnumpy(), expect1, rtol=1.e-7, atol=1.e-5)
assert np.allclose(batch_var.asnumpy(), expect2, rtol=1.e-7, atol=1.e-5)
assert np.allclose(ms_mean.asnumpy(), expect3, rtol=1.e-7, atol=1.e-5)
assert np.allclose(ms_var.asnumpy(), expect4, rtol=1.e-7, atol=1.e-5)
assert np.allclose(delay_mean.asnumpy(), expect5, rtol=1.e-7, atol=1.e-5)
assert np.allclose(delay_std.asnumpy(), expect6, rtol=1.e-7, atol=1.e-5)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_batchnorm_fold2():
net = Net()
c = 64
x = np.random.uniform(1, 10, size=[3, c, 512, 512]).astype('float32')
mean = np.random.uniform(1, 10, size=[c]).astype('float32')
variance = np.random.uniform(1, 10, size=[c]).astype('float32')
current_step = np.array([0]).astype('int32')
ms_mean = Tensor(mean)
ms_var = Tensor(variance)
batch_mean, batch_var, delay_mean, delay_std = net(Tensor(x), ms_mean, ms_var,
Tensor(current_step))
expect1, expect2, expect3, expect4, expect5, expect6 = np_result(x, mean, variance, 0.9, 1e-12)
assert np.allclose(batch_mean.asnumpy(), expect1, rtol=1.e-7, atol=1.e-5)
assert np.allclose(batch_var.asnumpy(), expect2, rtol=1.e-7, atol=1.e-5)
assert np.allclose(ms_mean.asnumpy(), expect3, rtol=1.e-7, atol=1.e-5)
assert np.allclose(delay_mean.asnumpy(), expect5, rtol=1.e-7, atol=1.e-5)
assert np.allclose(delay_std.asnumpy(), expect6, rtol=1.e-7, atol=1.e-5)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_batchnorm_fold_freeze():
net = Net()
c = 64
x = np.random.uniform(1, 10, size=[3, c, 32, 32]).astype('float32')
mean = np.random.uniform(1, 10, size=[c]).astype('float32')
variance = np.random.uniform(1, 10, size=[c]).astype('float32')
current_step = np.array([10]).astype('int32')
ms_mean = Tensor(mean)
ms_var = Tensor(variance)
batch_mean, batch_var, delay_mean, delay_std = net(Tensor(x), ms_mean, ms_var,
Tensor(current_step))
expect1, expect2, expect3, expect4, expect5, expect6 = np_result(x, mean, variance, 0.9, 1e-12)
assert np.allclose(batch_mean.asnumpy(), np.zeros_like(mean), rtol=1.e-7, atol=1.e-5)
assert np.allclose(batch_var.asnumpy(), np.ones_like(mean), rtol=1.e-7, atol=1.e-5)
assert np.allclose(ms_mean.asnumpy(), mean, rtol=1.e-7, atol=1.e-5)
assert np.allclose(ms_var.asnumpy(), variance, rtol=1.e-7, atol=1.e-5)
assert np.allclose(delay_mean.asnumpy(), expect5, rtol=1.e-7, atol=1.e-5)
assert np.allclose(delay_std.asnumpy(), expect6, rtol=1.e-7, atol=1.e-5)

View File

@ -14,10 +14,10 @@
# ============================================================================
import pytest
import numpy as np
from mindspore import Tensor
from mindspore.ops import operations as P
import mindspore.nn as nn
import numpy as np
import mindspore.context as context

View File

@ -0,0 +1,55 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import os
from mindspore import Tensor
from mindspore.ops import operations as P
import mindspore.nn as nn
from mindspore.common.api import ms_function
import mindspore.context as context
context.set_context(device_target='GPU')
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.op_w = P.CorrectionMulGrad()
@ms_function
def construct(self, dy, x, batch_std, running_std):
dx, d_batch_std = self.op_w(dy, x, batch_std, running_std)
return dx, d_batch_std
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_correction_mul_grad():
net = Net()
co, ci, h, w = 64, 1, 32, 32
dout = np.random.uniform(-0.1, 0.1, size=[co, ci, h, w]).astype('float32')
x = np.random.uniform(1, 1, size=[co, ci, h, w]).astype('float32')
batch_std = np.random.uniform(1, 10, size=[co]).astype('float32')
running_std = np.random.uniform(1, 10, size=[co]).astype('float32')
output = net(Tensor(dout), Tensor(x), Tensor(batch_std), Tensor(running_std))
expect = [0, 0]
expect[0] = (dout * np.reshape(batch_std / running_std, (co, 1, 1, 1)))
expect[1] = (np.sum(dout * x, (1, 2, 3)) / running_std)
for i, v in enumerate(output):
assert (np.allclose(output[i].asnumpy(), expect[i], rtol=1.e-5, atol=1.e-5))

View File

@ -0,0 +1,52 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
from mindspore import Tensor
from mindspore.ops import operations as P
import mindspore.nn as nn
from mindspore.common.api import ms_function
import mindspore.context as context
context.set_context(device_target='GPU')
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.op = P.CorrectionMul()
@ms_function
def construct(self, x, batch_var, moving_var):
return self.op(x, batch_var, moving_var)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_correction_mul():
net = Net()
co = 64
x = np.random.uniform(-1, 1, size=[co, 64, 32, 32]).astype('float32')
bv = np.random.uniform(1, 2, size=[co]).astype('float32')
mv = np.random.uniform(1, 2, size=[co]).astype('float32')
output = net(Tensor(x), Tensor(bv), Tensor(mv))
expect = x * np.reshape(bv, (co, 1, 1, 1)) / np.reshape(mv, (co, 1, 1, 1))
error = np.ones(shape=expect.shape) * 1.0e-5
diff = output.asnumpy() - expect
assert np.all(diff < error)
assert np.all(diff > error * -1)
assert (output.shape() == expect.shape)