mindspore/tests/st/networks/test_cell_bprop.py


# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test_cell_bprop """
import numpy as np
import pytest

import mindspore as ms
import mindspore.common.dtype as mstype
import mindspore.nn as nn
from mindspore import Parameter, ParameterTuple
from mindspore import context
from mindspore.common.initializer import initializer
from mindspore.common.tensor import Tensor
from mindspore.ops import composite as C
from mindspore.ops import operations as P

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")

grad_all = C.GradOperation(get_all=True)
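# grad_all(net)(*inputs) returns a tuple with the gradient of the network's
# output with respect to each input.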


class MulAdd(nn.Cell):
    def construct(self, x, y):
        return 2 * x + y

    def bprop(self, x, y, out, dout):
        # In this test case, the user-defined bprop is deliberately wrong so
        # that its result can be distinguished from the AD result.
        return 2 * dout, 2 * y


@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_grad_mul_add():
    mul_add = MulAdd()
    x = Tensor(1, dtype=ms.int32)
    y = Tensor(2, dtype=ms.int32)
    # The custom bprop returns (2 * dout, 2 * y); with dout = 1, y = 2 this is (2, 4).
    assert grad_all(mul_add)(x, y) == (2, 4)
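

# A cell with a custom bprop can be called from another cell's construct;
# differentiating the outer cell still uses MulAdd's bprop for the inner call.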
class InlineMulADD(nn.Cell):
    def __init__(self):
        super(InlineMulADD, self).__init__()
        self.mul_add = MulAdd()
        self.param = 2

    def construct(self, x, y):
        return self.mul_add(x, y) + x + self.param * y


@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_grad_inline_mul_add():
    inline_mul_add = InlineMulADD()
    x = Tensor(1, dtype=ms.int32)
    y = Tensor(2, dtype=ms.int32)
    # MulAdd's bprop gives (2, 4); the extra x + 2 * y terms add (1, 2).
    assert grad_all(inline_mul_add)(x, y) == (3, 6)
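

# Differentiating a cell whose custom bprop references its Parameters
# (self.param1, self.param2) is rejected with a RuntimeError.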
class WithParameter(nn.Cell):
    def __init__(self):
        super(WithParameter, self).__init__()
        self.param1 = Parameter(1, 'param1')
        self.param2 = Parameter(2, 'param2')

    def construct(self, x, y):
        return self.param1 * self.param2 * x + y

    def bprop(self, x, y, out, dout):
        # In this test case, the user-defined bprop is deliberately wrong so
        # that its result can be distinguished from the AD result.
        return self.param1 * self.param2 * dout, 2 * y


@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_with_param():
    with_param = WithParameter()
    with pytest.raises(RuntimeError):
        grad_all(with_param)(1, 2)
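

# With no bprop defined, the gradients come from automatic differentiation:
# d(2x + y)/dx = 2 and d(2x + y)/dy = 1.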
class WithNoBprop(nn.Cell):
    def construct(self, x, y):
        return 2 * x + y


@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_with_no_bprop():
    with_no_bprop = WithNoBprop()
    x = Tensor(1, dtype=ms.int32)
    y = Tensor(2, dtype=ms.int32)
    assert grad_all(with_no_bprop)(x, y) == (2, 1)
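

# The next three tests call grad_all on a sub-cell from inside construct and
# bprop, i.e. gradients are computed while a gradient is being defined.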
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_grad_in_bprop_1():
    class GradInBprop_1(nn.Cell):
        def __init__(self):
            super(GradInBprop_1, self).__init__()
            self.relu = P.ReLU()

        def construct(self, x, y):
            return self.relu(x)

    class GradInBprop_2(nn.Cell):
        def __init__(self):
            super(GradInBprop_2, self).__init__()
            self.f = GradInBprop_1()

        def construct(self, x, y):
            return self.f(x, y), grad_all(self.f)(x, y)

        def bprop(self, x, y, out, dout):
            grads = grad_all(self.f)(x, y)
            return out[1][0], grads[1]

    class GradInBprop_3(nn.Cell):
        def __init__(self):
            super(GradInBprop_3, self).__init__()
            self.f = GradInBprop_2()

        def construct(self, x, y):
            return self.f(x, y)

    grad_in_bprop = GradInBprop_3()
    grads = grad_all(grad_in_bprop)(Tensor(np.ones([2, 2]).astype(np.float32)),
                                    Tensor(np.ones([2, 2]).astype(np.float32)))
    # x > 0, so the ReLU gradient is all ones; y never affects the output,
    # so its gradient is all zeros.
    assert (grads[0].asnumpy() == np.ones([2, 2]).astype(np.float32)).all()
    assert (grads[1].asnumpy() == np.zeros([2, 2]).astype(np.float32)).all()
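

# Same structure as above, but here the innermost cell also defines a bprop
# returning (x * y, y + x); for all-ones inputs grad_all(self.f) therefore
# yields (1, 2) instead of the AD result.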
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_grad_in_bprop_2():
    class GradInBprop_1(nn.Cell):
        def __init__(self):
            super(GradInBprop_1, self).__init__()
            self.relu = P.ReLU()

        def construct(self, x, y):
            return self.relu(x)

        def bprop(self, x, y, out, dout):
            return x * y, y + x

    class GradInBprop_2(nn.Cell):
        def __init__(self):
            super(GradInBprop_2, self).__init__()
            self.f = GradInBprop_1()

        def construct(self, x, y):
            return self.f(x, y), grad_all(self.f)(x, y)

        def bprop(self, x, y, out, dout):
            grads = grad_all(self.f)(x, y)
            return out[1][0], grads[1]

    class GradInBprop_3(nn.Cell):
        def __init__(self):
            super(GradInBprop_3, self).__init__()
            self.f = GradInBprop_2()

        def construct(self, x, y):
            return self.f(x, y)

    grad_in_bprop = GradInBprop_3()
    grads = grad_all(grad_in_bprop)(Tensor(np.ones([2, 2]).astype(np.float32)),
                                    Tensor(np.ones([2, 2]).astype(np.float32)))
    assert (grads[0].asnumpy() == np.ones([2, 2]).astype(np.float32)).all()
    assert (grads[1].asnumpy() == np.array([[2, 2], [2, 2]]).astype(np.float32)).all()
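

# Here the outermost cell defines a bprop mixing inputs, out and dout: with
# all-ones inputs, out[0] = relu(x) = 1 and dout[0] = 1, so the returned
# gradients are x + y + y + out[0] = 4 and x + x + y + y + dout[0] = 5.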
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_grad_in_bprop_3():
    class GradInBprop_1(nn.Cell):
        def __init__(self):
            super(GradInBprop_1, self).__init__()
            self.relu = P.ReLU()

        def construct(self, x, y):
            return self.relu(x)

    class GradInBprop_2(nn.Cell):
        def __init__(self):
            super(GradInBprop_2, self).__init__()
            self.f = GradInBprop_1()

        def construct(self, x, y):
            return self.f(x, y), grad_all(self.f)(x, y)

        def bprop(self, x, y, out, dout):
            grads = grad_all(self.f)(x, y)
            return out[1][0], grads[1]

    class GradInBprop_3(nn.Cell):
        def __init__(self):
            super(GradInBprop_3, self).__init__()
            self.f = GradInBprop_2()

        def construct(self, x, y):
            return self.f(x, y)

        def bprop(self, x, y, out, dout):
            return x + y + y + out[0], x + x + y + y + dout[0]

    grad_in_bprop = GradInBprop_3()
    grads = grad_all(grad_in_bprop)(Tensor(np.ones([2, 2]).astype(np.float32)),
                                    Tensor(np.ones([2, 2]).astype(np.float32)))
    assert (grads[0].asnumpy() == np.array([[4, 4], [4, 4]]).astype(np.float32)).all()
    assert (grads[1].asnumpy() == np.array([[5, 5], [5, 5]]).astype(np.float32)).all()
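

# A custom bprop must return one gradient per input as a tuple, so a
# single-input cell returns (5 * x,) rather than a bare value.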
class OneInputBprop(nn.Cell):
    def __init__(self):
        super().__init__()
        self.op = P.ReLU()

    def construct(self, x):
        return self.op(x)

    def bprop(self, x, out, dout):
        return (5 * x,)


@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_grad_one_input_bprop():
    net = OneInputBprop()
    input1 = Tensor(np.ones([2, 2]).astype(np.float32))
    grad = grad_all(net)(input1)
    assert (grad[0].asnumpy() == np.array([5, 5]).astype(np.float32)).all()


class TwoInput(nn.Cell):
    def construct(self, x, y):
        return x * y


class InlineBpropTwoInput(nn.Cell):
    def __init__(self):
        super().__init__()
        self.f = TwoInput()

    def construct(self, x, y):
        return self.f(x, y), grad_all(self.f)(x, y)

    def bprop(self, x, y, out, dout):
        # Double the AD gradients of x * y, which are (y, x).
        grads = grad_all(self.f)(x, y)
        return grads[0] * 2, grads[1] * 2


@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_grad_inline_bprop_two_input():
    net = InlineBpropTwoInput()
    input1 = Tensor(np.ones([2, 2]).astype(np.float32))
    input2 = Tensor(np.ones([2, 2]).astype(np.float32))
    grads = grad_all(net)(input1, input2)
    assert (grads[0].asnumpy() == np.array([2, 2]).astype(np.float32)).all()
    assert (grads[1].asnumpy() == np.array([2, 2]).astype(np.float32)).all()
    assert len(grads) == 2
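

# In test_grad_inline_bprop_multi_input below, the expected gradients are the
# sum over four inline cells with x = y = ones: f1 and f4 use their custom
# bprop, (5x, 8y) = (5, 8) each; f2 is differentiated automatically,
# d(x * y) = (y, x) = (1, 1); and f3 gives d((inputdata + x) * y) =
# (y, inputdata + x) = (1, 2).  Totals: (12, 19).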
class TwoInputBprop(nn.Cell):
    def __init__(self):
        super().__init__()
        self.op = P.Mul()

    def construct(self, x, y):
        return self.op(x, y)

    def bprop(self, x, y, out, dout):
        return 5 * x, 8 * y


class TwoInputWithParameter(nn.Cell):
    def __init__(self):
        super().__init__()
        self.op = P.Mul()
        self.inputdata = Parameter(initializer(1, (2, 2), mstype.float32), name="global_step")

    def construct(self, x, y):
        x = self.inputdata + x
        return self.op(x, y)


class TwoInputWithOnlyInitParameterBprop(nn.Cell):
    def __init__(self):
        super().__init__()
        self.op = P.Mul()
        self.inputdata = Parameter(initializer(1, (2, 2), mstype.float32), name="global_step")

    def construct(self, x, y):
        return self.op(x, y)

    def bprop(self, x, y, out, dout):
        return 5 * x, 8 * y


class InlineMutilTwoInputParameterCell(nn.Cell):
    def __init__(self):
        super().__init__()
        self.f1 = TwoInputBprop()
        self.f2 = TwoInput()
        self.f3 = TwoInputWithParameter()
        self.f4 = TwoInputWithOnlyInitParameterBprop()

    def construct(self, x, y):
        output = self.f1(x, y) + self.f2(x, y) + self.f3(x, y) + self.f4(x, y)
        return output


@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_grad_inline_bprop_multi_input():
    net = InlineMutilTwoInputParameterCell()
    input1 = Tensor(np.ones([2, 2]).astype(np.float32))
    input2 = Tensor(np.ones([2, 2]).astype(np.float32))
    net.init_parameters_data()
    grads = grad_all(net)(input1, input2)
    assert (grads[0].asnumpy() == np.array([[12, 12], [12, 12]]).astype(np.float32)).all()
    assert (grads[1].asnumpy() == np.array([[19, 19], [19, 19]]).astype(np.float32)).all()
    assert len(grads) == 2
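

# GradOperation(get_all=True, get_by_list=True) returns gradients for the
# inputs and for the listed weights; both go through MulAdd's custom bprop,
# which sees the parameter as its first argument x.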
class MulAddWithParam(nn.Cell):
    def __init__(self):
        super(MulAddWithParam, self).__init__()
        self.mul_add = MulAdd()
        self.param = Parameter(Tensor(np.array([[3, 2]], np.float32)), 'param')

    def construct(self, x):
        return self.mul_add(self.param, x)


@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_refkey_bprop():
    grad_by_list = C.GradOperation(get_all=True, get_by_list=True)

    class GradWrap(nn.Cell):
        def __init__(self, network):
            super(GradWrap, self).__init__()
            self.network = network
            self.weights = ParameterTuple(filter(lambda x: x.requires_grad, network.get_parameters()))

        def construct(self, x):
            weights = self.weights
            grads = grad_by_list(self.network, weights)(x)
            return grads

    network = GradWrap(MulAddWithParam())
    input_data = Tensor(np.array([2, 2], np.float32))
    grads = network(input_data)
    # MulAdd's bprop returns (2 * dout, 2 * y): the input gradient is
    # 2 * input_data = [4, 4] and the weight gradient is 2 * dout = [2, 2].
    assert (grads[0][0].asnumpy() == np.array([4, 4]).astype(np.float32)).all()
    assert (grads[1][0].asnumpy() == np.array([2, 2]).astype(np.float32)).all()
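

# With context.set_context(check_bprop=True), the outputs of a user-defined
# bprop are checked against the cell inputs; a mismatched number, type, or
# shape of gradients raises TypeError, as the three tests below verify.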
class MulAddWithWrongOutputNum(nn.Cell):
    def construct(self, x, y):
        return 2 * x + y

    def bprop(self, x, y, out, dout):
        # Only one gradient is returned for a two-input cell.
        return (2 * dout,)


@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_grad_mul_add_with_wrong_output_num():
    context.set_context(check_bprop=True)
    mul_add = MulAddWithWrongOutputNum()
    with pytest.raises(TypeError):
        grad_all(mul_add)(1, 2)


class MulAddWithWrongOutputType(nn.Cell):
    def construct(self, x, y):
        return 2 * x + y

    def bprop(self, x, y, out, dout):
        # The second gradient is a scalar int, not a Tensor like y.
        return 2 * dout, 2


@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_grad_mul_add_with_wrong_output_type():
    context.set_context(check_bprop=True)
    mul_add = MulAddWithWrongOutputType()
    with pytest.raises(TypeError):
        grad_all(mul_add)(1, Tensor(np.ones([2, 2])))


class MulAddWithWrongOutputShape(nn.Cell):
    def __init__(self):
        super(MulAddWithWrongOutputShape, self).__init__()
        self.ones = Tensor(np.ones([2,]))

    def construct(self, x, y):
        return 2 * x + y

    def bprop(self, x, y, out, dout):
        # The second gradient has shape (2,) while y has shape (2, 2).
        return 2, self.ones


@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_grad_mul_add_with_wrong_output_shape():
    context.set_context(check_bprop=True)
    mul_add = MulAddWithWrongOutputShape()
    with pytest.raises(TypeError):
        grad_all(mul_add)(1, Tensor(np.ones([2, 2])))