batchnorm1d gpu op

This commit is contained in:
tom__chen 2021-04-20 17:40:39 -04:00
parent a9ad992e43
commit 7ea4bd0593
4 changed files with 172 additions and 14 deletions

View File

@ -97,11 +97,6 @@ class BatchNormGpuKernel : public GpuKernel {
InitResource();
is_train_ = GetAttr<bool>(kernel_node, "is_training");
if (is_train_) {
mode_ = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
} else {
mode_ = CUDNN_BATCHNORM_SPATIAL;
}
epsilon_ = GetAttr<float>(kernel_node, "epsilon");
exp_avg_factor_ = GetAttr<float>(kernel_node, "momentum");
@ -118,8 +113,8 @@ class BatchNormGpuKernel : public GpuKernel {
}
auto shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
if (shape.size() != 4) {
MS_LOG(EXCEPTION) << "tensor shape is " << shape.size() << ", BatchNormGpuKernel should be 4";
if (shape.size() != 4 && shape.size() != 2) {
MS_LOG(EXCEPTION) << "tensor shape is " << shape.size() << ", BatchNormGpuKernel should be 2D or 4D";
}
is_null_input_ = CHECK_NULL_INPUT(shape);
if (is_null_input_) {
@ -127,6 +122,15 @@ class BatchNormGpuKernel : public GpuKernel {
InitSizeLists();
return true;
}
if (shape.size() == 2) {
mode_ = CUDNN_BATCHNORM_PER_ACTIVATION;
} else if (is_train_) {
mode_ = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
} else {
mode_ = CUDNN_BATCHNORM_SPATIAL;
}
auto format = AnfAlgo::GetInputFormat(kernel_node, 0);
auto format_attr = GetAttr<std::string>(kernel_node, "format");
if (format_attr == kOpFormat_NHWC) {
@ -242,7 +246,13 @@ class BatchNormGpuKernel : public GpuKernel {
void SetTensorDescriptor(const std::string &format, const std::vector<size_t> &shape) {
cudnnTensorFormat_t cudnn_format;
int batch, channel, height, width;
if (format == kOpFormat_NHWC) {
if (shape.size() == 2) {
batch = SizeToInt(shape[0]);
channel = SizeToInt(shape[1]);
height = 1;
width = 1;
cudnn_format = CUDNN_TENSOR_NCHW;
} else if (format == kOpFormat_NHWC) {
batch = SizeToInt(shape[0]);
height = SizeToInt(shape[1]);
width = SizeToInt(shape[2]);

View File

@ -124,7 +124,6 @@ class BatchNormGradGpuKernel : public GpuKernel {
}
InitResource();
mode_ = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
epsilon_ = GetAttr<float>(kernel_node, "epsilon");
cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0)));
@ -140,8 +139,8 @@ class BatchNormGradGpuKernel : public GpuKernel {
}
auto shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
if (shape.size() != 4) {
MS_LOG(EXCEPTION) << "tensor shape is " << shape.size() << ", BatchNormGradGpuKernel should be 4";
if (shape.size() != 4 && shape.size() != 2) {
MS_LOG(EXCEPTION) << "tensor shape is " << shape.size() << ", BatchNormGradGpuKernel should be 2D or 4D";
}
is_null_input_ = CHECK_NULL_INPUT(shape);
if (is_null_input_) {
@ -149,6 +148,12 @@ class BatchNormGradGpuKernel : public GpuKernel {
InitSizeLists();
return true;
}
if (shape.size() == 2) {
mode_ = CUDNN_BATCHNORM_PER_ACTIVATION;
} else {
mode_ = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
}
std::string format = AnfAlgo::GetInputFormat(kernel_node, 0);
auto format_attr = GetAttr<std::string>(kernel_node, "format");
if (format_attr == kOpFormat_NHWC) {
@ -234,7 +239,13 @@ class BatchNormGradGpuKernel : public GpuKernel {
private:
void SetTensorDescriptor(const std::string &format, const std::vector<size_t> &shape) {
cudnnTensorFormat_t cudnn_format;
if (format == kOpFormat_NHWC) {
if (shape.size() == 2) {
batch_ = SizeToInt(shape[0]);
channel_ = SizeToInt(shape[1]);
height_ = 1;
width_ = 1;
cudnn_format = CUDNN_TENSOR_NCHW;
} else if (format == kOpFormat_NHWC) {
batch_ = SizeToInt(shape[0]);
height_ = SizeToInt(shape[1]);
width_ = SizeToInt(shape[2]);

View File

@ -288,7 +288,7 @@ class BatchNorm1d(_BatchNorm):
Tensor, the normalized, scaled, offset tensor, of shape :math:`(N, C_{out})`.
Supported Platforms:
``Ascend``
``Ascend`` ``GPU``
Raises:
TypeError: If `num_features` is not an int.

View File

@ -18,7 +18,8 @@ import pytest
import mindspore.context as context
from mindspore.common.tensor import Tensor
from mindspore.nn import BatchNorm2d
from mindspore.common.parameter import ParameterTuple
from mindspore.nn import BatchNorm2d, BatchNorm1d, SGD
from mindspore.nn import Cell
from mindspore.ops import composite as C
@ -201,3 +202,139 @@ def test_infer_backward():
ms_grad = Grad(ms_net)
ms_out_grad_np = ms_grad(ms_input, Tensor(input_grad_np))
assert np.allclose(ms_out_grad_np[0].asnumpy(), expect_output)
class BatchNorm1d_Net(Cell):
    """Single-layer network that wraps one ``BatchNorm1d`` over two features.

    ``num_features`` (2), ``eps`` (1e-5) and ``momentum`` (0.1) are fixed here;
    every constructor argument is forwarded unchanged to ``BatchNorm1d``.

    Args:
        affine: Forwarded to ``BatchNorm1d(affine=...)``.
        gamma_init: Forwarded to ``BatchNorm1d(gamma_init=...)``.
        beta_init: Forwarded to ``BatchNorm1d(beta_init=...)``.
        moving_mean_init: Forwarded to ``BatchNorm1d(moving_mean_init=...)``.
        moving_var_init: Forwarded to ``BatchNorm1d(moving_var_init=...)``.
        use_batch_statistics: Forwarded to ``BatchNorm1d(use_batch_statistics=...)``.
    """

    def __init__(self, affine=True, gamma_init='ones', beta_init='zeros', moving_mean_init='zeros',
                 moving_var_init='ones', use_batch_statistics=None):
        super(BatchNorm1d_Net, self).__init__()
        # Collect the pass-through options once so the layer construction stays short.
        layer_options = dict(
            affine=affine,
            gamma_init=gamma_init,
            beta_init=beta_init,
            moving_mean_init=moving_mean_init,
            moving_var_init=moving_var_init,
            use_batch_statistics=use_batch_statistics,
        )
        self.bn1 = BatchNorm1d(2, eps=1e-5, momentum=0.1, **layer_options)

    def construct(self, x):
        # The network is just the batch-norm layer applied to the input.
        return self.bn1(x)
class GradByListNet(Cell):
    """Cell that differentiates ``network`` with an explicit output sensitivity.

    The gradient operation is built with ``get_all=True``, ``get_by_list=True``
    and ``sens_param=True`` and is applied to ``network`` together with its
    trainable parameters.  In ``test_1d_train`` the result is indexed as
    ``output[0][0]`` for the gradient of the input and ``output[1]`` for the
    parameter gradients handed to the optimizer.

    Args:
        network: The cell to differentiate; its ``trainable_params()`` are
            captured once, at construction time.
    """

    def __init__(self, network):
        super(GradByListNet, self).__init__()
        self.grad = C.GradOperation(get_all=True, sens_param=True, get_by_list=True)
        self.network = network
        self.params = ParameterTuple(network.trainable_params())

    def construct(self, x, dy):
        # `dy` is passed as the extra trailing argument required by sens_param=True.
        return self.grad(self.network, self.params)(x, dy)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_1d_train():
    """Run two training steps of ``BatchNorm1d`` on GPU and compare with reference values.

    For each step the test checks, in this order: the gradient of the input,
    then (after one SGD update) gamma, beta, the moving mean and the moving
    variance.  All comparisons use an absolute tolerance of 1e-4.
    """
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")

    def f32(values):
        # All fixtures are float32 arrays built from literal Python floats.
        return np.array(values, dtype=np.float32)

    tolerance = 1.0e-4

    bn_net = BatchNorm1d_Net(use_batch_statistics=None)
    grad_net = GradByListNet(bn_net)
    optimizer = SGD(bn_net.trainable_params(), learning_rate=0.01, momentum=0.9)
    bn_net.set_train(True)

    # One tuple per training step:
    # (input, output sensitivity, expected dx, gamma, beta, moving mean, moving variance)
    first_step = (
        f32([[1.6243454, -0.6117564],
             [-0.5281718, -1.0729686],
             [0.86540765, -2.3015387],
             [1.7448118, -0.7612069],
             [0.3190391, -0.24937038]]),
        f32([[1.4621079, -2.0601406],
             [-0.3224172, -0.38405436],
             [1.1337694, -1.0998913],
             [-0.1724282, -0.8778584],
             [0.04221375, 0.58281523]]),
        f32([[0.8120, -2.0371],
             [-0.2202, 0.5837],
             [0.8040, 0.1950],
             [-1.1823, -0.2786],
             [-0.2135, 1.5371]]),
        f32([0.9821, 0.9873]),
        f32([-0.0214, 0.0384]),
        f32([0.7246, -0.8994]),
        f32([0.9036, 0.6559]),
    )
    second_step = (
        f32([[-0.19183555, -0.887629],
             [-0.7471583, 1.6924546],
             [0.05080776, -0.6369957],
             [0.19091548, 2.1002553],
             [0.12015896, 0.6172031]]),
        f32([[0.30017033, -0.35224986],
             [-1.1425182, -0.34934273],
             [-0.20889424, 0.5866232],
             [0.8389834, 0.9311021],
             [0.2855873, 0.8851412]]),
        f32([[1.1955, -0.4247],
             [-0.2425, -0.6789],
             [-1.4563, 0.3237],
             [0.8752, 0.3351],
             [-0.3719, 0.4448]]),
        f32([0.9370, 0.9687]),
        f32([-0.0415, 0.0559]),
        f32([-0.0314, 0.4294]),
        f32([0.2213, 1.6822]),
    )

    for x, dy, want_dx, want_gamma, want_beta, want_mean, want_variance in (first_step, second_step):
        output = grad_net(Tensor(x), Tensor(dy))
        input_grads = output[0]
        param_grads = output[1]

        # The input gradient is checked before the optimizer touches the parameters.
        assert np.allclose(input_grads[0].asnumpy(), want_dx, atol=tolerance)

        optimizer(param_grads)

        layer = bn_net.bn1
        assert np.allclose(layer.gamma.asnumpy(), want_gamma, atol=tolerance)
        assert np.allclose(layer.beta.asnumpy(), want_beta, atol=tolerance)
        assert np.allclose(layer.moving_mean.asnumpy(), want_mean, atol=tolerance)
        assert np.allclose(layer.moving_variance.asnumpy(), want_variance, atol=tolerance)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_1d_eval():
    """Run ``BatchNorm1d`` inference on GPU with preset parameters and statistics.

    The layer is built with explicit gamma, beta, moving-mean and
    moving-variance initial values, switched to evaluation mode with
    ``set_train(False)``, and its output for two input batches is compared
    with reference values using an absolute tolerance of 1e-4.
    """
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")

    def f32(values):
        # All fixtures are float32 arrays built from literal Python floats.
        return np.array(values, dtype=np.float32)

    gamma_init = Tensor(f32([0.93700373, 0.96870345]))
    beta_init = Tensor(f32([-0.04145495, 0.05593072]))
    mean_init = Tensor(f32([-0.03142229, 0.4294087]))
    variance_init = Tensor(f32([0.2212921, 1.6822311]))

    bn_net = BatchNorm1d_Net(
        affine=False,
        gamma_init=gamma_init,
        beta_init=beta_init,
        moving_mean_init=mean_init,
        moving_var_init=variance_init,
        use_batch_statistics=None,
    )
    bn_net.set_train(False)

    # Each case pairs an input batch with the output expected from the layer.
    cases = [
        (f32([[-1.1006192, 1.1447237],
              [0.9015907, 0.50249434],
              [0.90085596, -0.68372786],
              [-0.12289023, -0.93576944],
              [-0.26788807, 0.53035545]]),
         f32([[-2.1711, 0.5902],
              [1.8169, 0.1105],
              [1.8155, -0.7754],
              [-0.2236, -0.9637],
              [-0.5125, 0.1313]])),
        (f32([[-0.7543979, 1.2528682],
              [0.5129298, -0.29809284],
              [0.48851815, -0.07557172],
              [1.1316293, 1.5198169],
              [2.1855755, -1.3964963]]),
         f32([[-1.4815, 0.6710],
              [1.0428, -0.4874],
              [0.9942, -0.3212],
              [2.2751, 0.8703],
              [4.3744, -1.3078]])),
    ]

    for batch, expected in cases:
        predicted = bn_net(Tensor(batch))
        assert np.allclose(predicted.asnumpy(), expected, atol=1.0e-4)