batchnorm1d gpu op

parent a9ad992e43
commit 7ea4bd0593
@@ -97,11 +97,6 @@ class BatchNormGpuKernel : public GpuKernel {
     InitResource();
     is_train_ = GetAttr<bool>(kernel_node, "is_training");
-    if (is_train_) {
-      mode_ = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
-    } else {
-      mode_ = CUDNN_BATCHNORM_SPATIAL;
-    }
     epsilon_ = GetAttr<float>(kernel_node, "epsilon");
     exp_avg_factor_ = GetAttr<float>(kernel_node, "momentum");

@@ -118,8 +113,8 @@ class BatchNormGpuKernel : public GpuKernel {
     }

     auto shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
-    if (shape.size() != 4) {
-      MS_LOG(EXCEPTION) << "tensor shape is " << shape.size() << ", BatchNormGpuKernel should be 4";
+    if (shape.size() != 4 && shape.size() != 2) {
+      MS_LOG(EXCEPTION) << "tensor shape is " << shape.size() << ", BatchNormGpuKernel should be 2D or 4D";
     }
     is_null_input_ = CHECK_NULL_INPUT(shape);
     if (is_null_input_) {
@@ -127,6 +122,15 @@ class BatchNormGpuKernel : public GpuKernel {
       InitSizeLists();
       return true;
     }
+
+    if (shape.size() == 2) {
+      mode_ = CUDNN_BATCHNORM_PER_ACTIVATION;
+    } else if (is_train_) {
+      mode_ = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
+    } else {
+      mode_ = CUDNN_BATCHNORM_SPATIAL;
+    }
+
     auto format = AnfAlgo::GetInputFormat(kernel_node, 0);
     auto format_attr = GetAttr<std::string>(kernel_node, "format");
     if (format_attr == kOpFormat_NHWC) {
@@ -242,7 +246,13 @@ class BatchNormGpuKernel : public GpuKernel {
   void SetTensorDescriptor(const std::string &format, const std::vector<size_t> &shape) {
     cudnnTensorFormat_t cudnn_format;
     int batch, channel, height, width;
-    if (format == kOpFormat_NHWC) {
+    if (shape.size() == 2) {
+      batch = SizeToInt(shape[0]);
+      channel = SizeToInt(shape[1]);
+      height = 1;
+      width = 1;
+      cudnn_format = CUDNN_TENSOR_NCHW;
+    } else if (format == kOpFormat_NHWC) {
       batch = SizeToInt(shape[0]);
       height = SizeToInt(shape[1]);
       width = SizeToInt(shape[2]);

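Note on the change above: for a 2D input the kernel now describes the tensor to cuDNN as (N, C, 1, 1) in NCHW and selects CUDNN_BATCHNORM_PER_ACTIVATION, so each of the C features is normalized over the batch dimension only. A minimal NumPy sketch of the training-mode forward computation this corresponds to (the standard batch-norm formula, not the kernel source):

import numpy as np

def batchnorm1d_forward(x, gamma, beta, eps=1e-5):
    # x: (N, C); normalize each feature over the batch, which is what
    # per-activation batch norm does when H = W = 1
    mean = x.mean(axis=0)                    # per-feature batch mean, shape (C,)
    var = x.var(axis=0)                      # per-feature (biased) batch variance, shape (C,)
    x_hat = (x - mean) / np.sqrt(var + eps)  # normalized input
    return gamma * x_hat + beta, x_hat, mean, var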
@@ -124,7 +124,6 @@ class BatchNormGradGpuKernel : public GpuKernel {
     }

     InitResource();
-    mode_ = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
     epsilon_ = GetAttr<float>(kernel_node, "epsilon");

     cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0)));
@@ -140,8 +139,8 @@ class BatchNormGradGpuKernel : public GpuKernel {
     }

     auto shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
-    if (shape.size() != 4) {
-      MS_LOG(EXCEPTION) << "tensor shape is " << shape.size() << ", BatchNormGradGpuKernel should be 4";
+    if (shape.size() != 4 && shape.size() != 2) {
+      MS_LOG(EXCEPTION) << "tensor shape is " << shape.size() << ", BatchNormGradGpuKernel should be 2D or 4D";
     }
     is_null_input_ = CHECK_NULL_INPUT(shape);
     if (is_null_input_) {
@@ -149,6 +148,12 @@ class BatchNormGradGpuKernel : public GpuKernel {
       InitSizeLists();
       return true;
     }
+    if (shape.size() == 2) {
+      mode_ = CUDNN_BATCHNORM_PER_ACTIVATION;
+    } else {
+      mode_ = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
+    }
+
     std::string format = AnfAlgo::GetInputFormat(kernel_node, 0);
     auto format_attr = GetAttr<std::string>(kernel_node, "format");
     if (format_attr == kOpFormat_NHWC) {
@@ -234,7 +239,13 @@ class BatchNormGradGpuKernel : public GpuKernel {
  private:
   void SetTensorDescriptor(const std::string &format, const std::vector<size_t> &shape) {
     cudnnTensorFormat_t cudnn_format;
-    if (format == kOpFormat_NHWC) {
+    if (shape.size() == 2) {
+      batch_ = SizeToInt(shape[0]);
+      channel_ = SizeToInt(shape[1]);
+      height_ = 1;
+      width_ = 1;
+      cudnn_format = CUDNN_TENSOR_NCHW;
+    } else if (format == kOpFormat_NHWC) {
       batch_ = SizeToInt(shape[0]);
       height_ = SizeToInt(shape[1]);
       width_ = SizeToInt(shape[2]);

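For the 2D case the gradient kernel likewise pads the shape to (N, C, 1, 1) and runs the cuDNN batch-norm backward routine in per-activation mode. A NumPy sketch of the textbook batch-norm gradients it is expected to produce (a reference, not the kernel source):

import numpy as np

def batchnorm1d_backward(dy, x_hat, gamma, var, eps=1e-5):
    # dy, x_hat: (N, C); gamma, var: (C,)
    n = dy.shape[0]
    dbeta = dy.sum(axis=0)                   # gradient w.r.t. beta
    dgamma = (dy * x_hat).sum(axis=0)        # gradient w.r.t. gamma
    # compact form of the input gradient for batch norm with biased variance
    dx = (gamma / np.sqrt(var + eps)) * (dy - dbeta / n - x_hat * dgamma / n)
    return dx, dgamma, dbeta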
@@ -288,7 +288,7 @@ class BatchNorm1d(_BatchNorm):
         Tensor, the normalized, scaled, offset tensor, of shape :math:`(N, C_{out})`.

     Supported Platforms:
-        ``Ascend``
+        ``Ascend`` ``GPU``

     Raises:
         TypeError: If `num_features` is not an int.
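With GPU added to the supported platforms, BatchNorm1d can be used the same way as on Ascend; a minimal usage sketch (shapes and values here are illustrative only):

import numpy as np
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
bn = nn.BatchNorm1d(num_features=4)
bn.set_train(False)                       # inference: use the moving statistics
x = Tensor(np.random.randn(8, 4).astype(np.float32))
y = bn(x)                                 # output shape (8, 4)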
@@ -18,7 +18,8 @@ import pytest

 import mindspore.context as context
 from mindspore.common.tensor import Tensor
-from mindspore.nn import BatchNorm2d
+from mindspore.common.parameter import ParameterTuple
+from mindspore.nn import BatchNorm2d, BatchNorm1d, SGD
 from mindspore.nn import Cell
 from mindspore.ops import composite as C

@@ -201,3 +202,139 @@ def test_infer_backward():
     ms_grad = Grad(ms_net)
     ms_out_grad_np = ms_grad(ms_input, Tensor(input_grad_np))
     assert np.allclose(ms_out_grad_np[0].asnumpy(), expect_output)
+
+
+class BatchNorm1d_Net(Cell):
+    def __init__(self, affine=True, gamma_init='ones', beta_init='zeros', moving_mean_init='zeros',
+                 moving_var_init='ones', use_batch_statistics=None):
+        super(BatchNorm1d_Net, self).__init__()
+        self.bn1 = BatchNorm1d(2, eps=0.00001, momentum=0.1, affine=affine, gamma_init=gamma_init, beta_init=beta_init,
+                               moving_mean_init=moving_mean_init, moving_var_init=moving_var_init,
+                               use_batch_statistics=use_batch_statistics)
+
+    def construct(self, x):
+        x = self.bn1(x)
+        return x
+
+
+class GradByListNet(Cell):
+    def __init__(self, network):
+        super(GradByListNet, self).__init__()
+        self.grad = C.GradOperation(get_all=True, sens_param=True, get_by_list=True)
+        self.network = network
+        self.params = ParameterTuple(network.trainable_params())
+
+    def construct(self, x, dy):
+        grad_op = self.grad(self.network, self.params)
+        output = grad_op(x, dy)
+        return output
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_1d_train():
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    bn_net = BatchNorm1d_Net(use_batch_statistics=None)
+    grad_net = GradByListNet(bn_net)
+    optimizer = SGD(bn_net.trainable_params(), learning_rate=0.01, momentum=0.9)
+    bn_net.set_train(True)
+
+    x1 = np.array([[1.6243454, -0.6117564],
+                   [-0.5281718, -1.0729686],
+                   [0.86540765, -2.3015387],
+                   [1.7448118, -0.7612069],
+                   [0.3190391, -0.24937038]]).astype(np.float32)
+    dy1 = np.array([[1.4621079, -2.0601406],
+                    [-0.3224172, -0.38405436],
+                    [1.1337694, -1.0998913],
+                    [-0.1724282, -0.8778584],
+                    [0.04221375, 0.58281523]]).astype(np.float32)
+    x2 = np.array([[-0.19183555, -0.887629],
+                   [-0.7471583, 1.6924546],
+                   [0.05080776, -0.6369957],
+                   [0.19091548, 2.1002553],
+                   [0.12015896, 0.6172031]]).astype(np.float32)
+    dy2 = np.array([[0.30017033, -0.35224986],
+                    [-1.1425182, -0.34934273],
+                    [-0.20889424, 0.5866232],
+                    [0.8389834, 0.9311021],
+                    [0.2855873, 0.8851412]]).astype(np.float32)
+    x_train = [x1, x2]
+    dy_train = [dy1, dy2]
+
+    dx1 = np.array([[0.8120, -2.0371],
+                    [-0.2202, 0.5837],
+                    [0.8040, 0.1950],
+                    [-1.1823, -0.2786],
+                    [-0.2135, 1.5371]]).astype(np.float32)
+    gamma1 = np.array([0.9821, 0.9873]).astype(np.float32)
+    beta1 = np.array([-0.0214, 0.0384]).astype(np.float32)
+    mean1 = np.array([0.7246, -0.8994]).astype(np.float32)
+    variance1 = np.array([0.9036, 0.6559]).astype(np.float32)
+
+    dx2 = np.array([[1.1955, -0.4247],
+                    [-0.2425, -0.6789],
+                    [-1.4563, 0.3237],
+                    [0.8752, 0.3351],
+                    [-0.3719, 0.4448]]).astype(np.float32)
+    gamma2 = np.array([0.9370, 0.9687]).astype(np.float32)
+    beta2 = np.array([-0.0415, 0.0559]).astype(np.float32)
+    mean2 = np.array([-0.0314, 0.4294]).astype(np.float32)
+    variance2 = np.array([0.2213, 1.6822]).astype(np.float32)
+
+    exp_dx = [dx1, dx2]
+    exp_gamma = [gamma1, gamma2]
+    exp_beta = [beta1, beta2]
+    exp_mean = [mean1, mean2]
+    exp_variance = [variance1, variance2]
+
+    for data in zip(x_train, dy_train, exp_dx, exp_gamma, exp_beta, exp_mean, exp_variance):
+        output = grad_net(Tensor(data[0]), Tensor(data[1]))
+        assert np.allclose(output[0][0].asnumpy(), data[2], atol=1.0e-4)
+        optimizer(output[1])
+        assert np.allclose(bn_net.bn1.gamma.asnumpy(), data[3], atol=1.0e-4)
+        assert np.allclose(bn_net.bn1.beta.asnumpy(), data[4], atol=1.0e-4)
+        assert np.allclose(bn_net.bn1.moving_mean.asnumpy(), data[5], atol=1.0e-4)
+        assert np.allclose(bn_net.bn1.moving_variance.asnumpy(), data[6], atol=1.0e-4)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_1d_eval():
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    gamma_init = Tensor(np.array([0.93700373, 0.96870345]).astype(np.float32))
+    beta_init = Tensor(np.array([-0.04145495, 0.05593072]).astype(np.float32))
+    mean_init = Tensor(np.array([-0.03142229, 0.4294087]).astype(np.float32))
+    variance_init = Tensor(np.array([0.2212921, 1.6822311]).astype(np.float32))
+    bn_net = BatchNorm1d_Net(affine=False, gamma_init=gamma_init, beta_init=beta_init, moving_mean_init=mean_init,
+                             moving_var_init=variance_init, use_batch_statistics=None)
+    bn_net.set_train(False)
+
+    x1 = np.array([[-1.1006192, 1.1447237],
+                   [0.9015907, 0.50249434],
+                   [0.90085596, -0.68372786],
+                   [-0.12289023, -0.93576944],
+                   [-0.26788807, 0.53035545]]).astype(np.float32)
+    x2 = np.array([[-0.7543979, 1.2528682],
+                   [0.5129298, -0.29809284],
+                   [0.48851815, -0.07557172],
+                   [1.1316293, 1.5198169],
+                   [2.1855755, -1.3964963]]).astype(np.float32)
+    x_test = [x1, x2]
+
+    y1 = np.array([[-2.1711, 0.5902],
+                   [1.8169, 0.1105],
+                   [1.8155, -0.7754],
+                   [-0.2236, -0.9637],
+                   [-0.5125, 0.1313]]).astype(np.float32)
+    y2 = np.array([[-1.4815, 0.6710],
+                   [1.0428, -0.4874],
+                   [0.9942, -0.3212],
+                   [2.2751, 0.8703],
+                   [4.3744, -1.3078]]).astype(np.float32)
+    y_test = [y1, y2]
+
+    for x, y in zip(x_test, y_test):
+        y_pred = bn_net(Tensor(x))
+        assert np.allclose(y_pred.asnumpy(), y, atol=1.0e-4)
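The eval-mode expectations above follow directly from the inference formula y = gamma * (x - moving_mean) / sqrt(moving_var + eps) + beta. A short NumPy check against the first test input, with the values and eps = 1e-5 taken from test_1d_eval and BatchNorm1d_Net:

import numpy as np

gamma = np.array([0.93700373, 0.96870345], np.float32)
beta = np.array([-0.04145495, 0.05593072], np.float32)
mean = np.array([-0.03142229, 0.4294087], np.float32)
var = np.array([0.2212921, 1.6822311], np.float32)
x1 = np.array([[-1.1006192, 1.1447237],
               [0.9015907, 0.50249434],
               [0.90085596, -0.68372786],
               [-0.12289023, -0.93576944],
               [-0.26788807, 0.53035545]], np.float32)
# reproduces y1 from test_1d_eval to within ~1e-4
y1_ref = gamma * (x1 - mean) / np.sqrt(var + 1e-5) + beta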