|
|
|
@ -1,4 +1,4 @@
|
|
|
|
|
# Copyright 2022 Huawei Technologies Co., Ltd
|
|
|
|
|
# Copyright 2020-2022 Huawei Technologies Co., Ltd
|
|
|
|
|
#
|
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
@ -17,18 +17,24 @@ import numpy as np
|
|
|
|
|
import pytest
|
|
|
|
|
|
|
|
|
|
import mindspore.nn as nn
|
|
|
|
|
from mindspore import Tensor, Parameter
|
|
|
|
|
from mindspore.common import mutable
|
|
|
|
|
from mindspore import Tensor, Parameter, ParameterTuple
|
|
|
|
|
from mindspore import context
|
|
|
|
|
from mindspore.ops import composite as C
|
|
|
|
|
import mindspore.ops as ops
|
|
|
|
|
|
|
|
|
|
context.set_context(mode=context.GRAPH_MODE)


@pytest.fixture(scope="module", autouse=True)
def setup_teardown():
    """Module-scoped fixture: restore GRAPH_MODE after the tests run.

    Individual cases switch the execution mode via parametrize, so the
    teardown puts the context back to the module default.
    """
    yield
    context.set_context(mode=context.GRAPH_MODE)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class FirstInputTupleNet(nn.Cell):
    """Forward net taking a tuple as first input, mixed with list/dict/scalar args.

    The mangled source carried both the pre- and post-rename bodies of
    construct(); only the post-rename (tensor_a/tensor_b) version is kept.
    """

    def construct(self, tuple_a, tensor_a, list_b, tensor_b, scalar, dict_c, flag):
        # Branch on the python bool `flag`; both branches index into the
        # nested tuple/list/dict arguments.
        if flag:
            return tensor_a - tuple_a[2] + list_b[1][1]["x"] - tensor_b + scalar - dict_c["x"]
        return tensor_a + tuple_a[2] - list_b[1][1]["y"] + tensor_b - scalar + dict_c["y"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class GradNet(nn.Cell):
|
|
|
|
@ -38,8 +44,8 @@ class GradNet(nn.Cell):
|
|
|
|
|
self.sens = Tensor(np.ones((2, 2), np.float32) * 5)
|
|
|
|
|
self.grad_all = C.GradOperation(get_all=get_all)
|
|
|
|
|
|
|
|
|
|
def construct(self, tuple_a, tensor_a, list_b, tensor_b, scalar, dict_c, flag):
    """Run grad of the wrapped forward net over all of its inputs."""
    # De-duplicated: keep only the post-rename (tensor_a/tensor_b) signature.
    return self.grad_all(self.forward_net)(tuple_a, tensor_a, list_b, tensor_b, scalar, dict_c, flag)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class GradNet1(nn.Cell):
|
|
|
|
@ -49,68 +55,312 @@ class GradNet1(nn.Cell):
|
|
|
|
|
self.sens = Tensor(np.ones((2, 2), np.float32) * 5)
|
|
|
|
|
self.grad_all = C.GradOperation(get_all=get_all)
|
|
|
|
|
|
|
|
|
|
def construct(self, tuple_a, tensor_a, list_b, tensor_b, tensor_c, dict_c):
    """Run grad of the wrapped forward net (6-arg variant, no flag/scalar)."""
    # De-duplicated: keep only the post-rename (tensor_a/tensor_b/tensor_c) signature.
    return self.grad_all(self.forward_net)(tuple_a, tensor_a, list_b, tensor_b, tensor_c, dict_c)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Shared inputs used by the test cases below. The mangled source carried both
# the pre-rename (x/y/z/w/sl/s/arg_*) and post-rename names; only the
# post-rename set is kept, since the tests reference these names.
tensor_x = Tensor(np.ones((2, 2), np.float32))
tensor_y = Tensor(np.ones((2, 2), np.float32) * 2)
tensor_z = Tensor(np.ones((2, 2), np.float32) * 3)
tensor_w = Tensor(np.ones((2, 2), np.float32) * 4)
SCALAR_NUM = 6
STRING_INPUT = "ok"
tuple_arg = (tensor_x, tensor_y, tensor_z, tensor_w)
list_arg = [[tensor_x, tensor_x], [[tensor_x, tensor_y], {"x": tensor_x, "y": tensor_y, "z": tensor_x, "p": tensor_y}]]
dict_arg = {"x": tensor_x, "y": tensor_y}
flag_0 = True
flag_1 = False

parameter_x = Parameter(tensor_x, name="weight")

# Shared forward/grad nets for the whole-module tests.
forward_net = FirstInputTupleNet()
forward_net.set_grad()
grad_all_inputs_net = GradNet(forward_net, get_all=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize('mode', [context.PYNATIVE_MODE, context.GRAPH_MODE])
def test_grad_first_input_net(mode):
    """
    Feature: Construct()/ms_function input type with back propagate.
    Description: Normal input type.
    Expectation: No exception.
    """
    class FirstInputTensorNet(nn.Cell):
        # Post-patch body: indexes tuple_a[0] (the pre-patch duplicate used [2]).
        def construct(self, tensor_a, tuple_a, list_b, tensor_b, tensor_c, dict_c):
            return tensor_a + tuple_a[0] - list_b[1][1]["y"] + tensor_b - tensor_c + dict_c["y"]

    context.set_context(mode=mode)
    grad_fist_input_tensor_net = GradNet1(FirstInputTensorNet(), get_all=False)
    res = grad_fist_input_tensor_net(tensor_z, tuple_arg, list_arg, tensor_w, tensor_y, dict_arg)
    print('res:', res)
    # d(out)/d(tensor_a) is identically 1 for the linear expression above.
    assert np.allclose(res.asnumpy(), np.ones((2, 2), np.float32))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# PyNative run error.
# Support context.PYNATIVE_MODE later.
@pytest.mark.parametrize('mode', [context.GRAPH_MODE])
def test_grad_first_input_net_pynative_error(mode):
    """
    Feature: Construct()/ms_function input type with back propagate.
    Description: Normal input type.
    Expectation: No exception.
    """
    class FirstInputTensorNet(nn.Cell):
        def construct(self, tensor_a, tuple_a, list_b, tensor_b, tensor_c, dict_c):
            return tensor_a + tuple_a[2] - list_b[1][1]["y"] + tensor_b - tensor_c + dict_c["y"]

    context.set_context(mode=mode)
    grad_fist_input_tensor_net = GradNet1(FirstInputTensorNet(), get_all=False)
    res = grad_fist_input_tensor_net(tensor_z, tuple_arg, list_arg, tensor_w, tensor_y, dict_arg)
    print('res:', res)
    # d(out)/d(tensor_a) is identically 1 for the linear expression above.
    assert np.allclose(res.asnumpy(), np.ones((2, 2), np.float32))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize('mode', [context.PYNATIVE_MODE, context.GRAPH_MODE])
def test_net_inputs_including_str(mode):
    """
    Feature: Construct()/ms_function input type with back propagate.
    Description: String input type.
    Expectation: No exception.
    """
    context.set_context(mode=mode)
    with pytest.raises(TypeError) as err:
        grad_all_inputs_net(tuple_arg, STRING_INPUT, list_arg, tensor_w, SCALAR_NUM, dict_arg, flag_0)
    print('err: ', str(err.value))
    # network is 'GradNet.construct' in GraphMode.
    # network is 'FirstInputTupleNet.construct' in PynativeMode.
    assert "The inputs types of the outermost network" in str(err.value)
    assert "support bool, int, float, None, Tensor, " \
           "Parameter, mstype.Number(mstype.bool, mstype.int, mstype.float, mstype.uint), " \
           "and tuple or list containing only these types, and dict whose values are these types, " \
           "but the 1th arg type is <class 'str'>, value is 'ok'" in str(err.value)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize('mode', [context.PYNATIVE_MODE, context.GRAPH_MODE])
def test_outermost_net_pass_parameter(mode):
    """
    Feature: Construct()/ms_function input type with back propagate.
    Description: Parameter input type.
    Expectation: No exception.
    """
    context.set_context(mode=mode)
    # Passing a Parameter where the net expects a tensor input must not raise.
    forward_net(tuple_arg, parameter_x, list_arg, tensor_w, SCALAR_NUM, dict_arg, flag_0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Support context.PYNATIVE_MODE UT later.
@pytest.mark.parametrize('mode', [context.GRAPH_MODE])
def test_outermost_net_pass_tuple_including_parameter(mode):
    """
    Feature: Construct()/ms_function input type with back propagate.
    Description: Tuple with Parameter as input type.
    Expectation: No exception.
    """
    context.set_context(mode=mode)
    # Wrap the tuple in mutable() so it is treated as a variable input.
    mutable_tuple = mutable((tensor_z, tensor_w, parameter_x))
    forward_net(tuple_arg, tensor_z, list_arg, SCALAR_NUM, mutable_tuple, dict_arg, flag_0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Support context.PYNATIVE_MODE UT later.
@pytest.mark.parametrize('mode', [context.GRAPH_MODE])
def test_outermost_net_pass_list_including_parameter(mode):
    """
    Feature: Construct()/ms_function input type with back propagate.
    Description: List with Parameter as input type.
    Expectation: No exception.
    """
    context.set_context(mode=mode)
    # Wrap the list in mutable() so it is treated as a variable input.
    mutable_list = mutable([tensor_z, tensor_w, parameter_x])
    forward_net(tuple_arg, tensor_z, list_arg, SCALAR_NUM, mutable_list, dict_arg, flag_0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Support context.PYNATIVE_MODE UT later.
@pytest.mark.parametrize('mode', [context.GRAPH_MODE])
def test_grad_net_pass_dict_including_parameter(mode):
    """
    Feature: Construct()/ms_function input type with back propagate.
    Description: Dict with Parameter as input type.
    Expectation: No exception.
    """
    context.set_context(mode=mode)
    # Wrap the dict in mutable() so it is treated as a variable input
    # (the pre-patch version expected a RuntimeError here; now supported).
    mutable_dict = mutable({"x": tensor_z, "y": tensor_w, "z": parameter_x})
    forward_net(tuple_arg, tensor_z, list_arg, SCALAR_NUM, SCALAR_NUM, mutable_dict, flag_0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestCell(nn.Cell):
    """Tiny net with one free Parameter; output = const * param * input."""

    def __init__(self, param):
        super().__init__()
        # Constant factor kept as an attribute tensor.
        self.a = Tensor(np.array([[1, 2], [3, 4]]))
        self.param = param

    def construct(self, x):
        return self.a * self.param * x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class GradCellWithParameter(nn.Cell):
    """Grad wrapper: differentiates `net` w.r.t. its inputs and one Parameter fv."""

    def __init__(self, net):
        super().__init__()
        self.net = net
        # Free variable taken straight from the wrapped net.
        self.param = self.net.param
        self.grad = ops.GradOperation(get_all=True, get_by_list=True)

    def construct(self, x):
        return self.grad(self.net, self.param)(x)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class GradCell(nn.Cell):
    """Grad wrapper: differentiates `net` w.r.t. all of its inputs."""

    def __init__(self, net):
        super().__init__()
        self.net = net
        self.grad_all = ops.GradOperation(get_all=True)

    def construct(self, x):
        return self.grad_all(self.net)(x)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize('mode', [context.PYNATIVE_MODE, context.GRAPH_MODE])
def test_grad_parameter_input(mode):
    """
    Feature: Construct()/ms_function input type with back propagate.
    Description: Grad with Parameter as input type.
    Expectation: No exception.
    """
    context.set_context(mode=mode)
    weight = Parameter(Tensor(np.array([[1, 2], [3, 4]])), name='input_x')
    param_in = Parameter(Tensor(np.array([[7, 8], [9, 0]])), name='input_y')
    tensor_in = Tensor(np.array([[7, 8], [9, 0]]))
    grad_from_param = GradCell(TestCell(weight))(param_in)
    grad_from_tensor = GradCell(TestCell(weight))(tensor_in)
    print(f'a: {grad_from_param}')
    print(f'b: {grad_from_tensor}')
    # Same values in, so the input gradient must not depend on whether the
    # argument was a Parameter or a plain Tensor.
    assert np.array_equal(grad_from_param[0].asnumpy(), grad_from_tensor[0].asnumpy())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# PyNative run error.
# Support context.PYNATIVE_MODE later.
@pytest.mark.parametrize('mode', [context.GRAPH_MODE])
def test_grad_parameter_as_input_and_fv(mode):
    """
    Feature: Construct()/ms_function input type with back propagate.
    Description: Grad with Parameters as input type and fv.
    Expectation: No exception.
    """
    context.set_context(mode=mode)
    weight = Parameter(Tensor(np.array([[1, 2], [3, 4]])), name='input_x')
    param_in = Parameter(Tensor(np.array([[7, 8], [9, 0]])), name='input_y')
    tensor_in = Tensor(np.array([[7, 8], [9, 0]]))
    grad_from_param = GradCellWithParameter(TestCell(weight))(param_in)
    grad_from_tensor = GradCellWithParameter(TestCell(weight))(tensor_in)
    print(f'a: {grad_from_param}')
    print(f'b: {grad_from_tensor}')
    # [0][0] is the gradient of the input, [1] the gradient of the fv Parameter;
    # both must agree for equal-valued Parameter vs Tensor inputs.
    assert np.array_equal(grad_from_param[0][0].asnumpy(), grad_from_tensor[0][0].asnumpy())
    assert np.array_equal(grad_from_param[1].asnumpy(), grad_from_tensor[1].asnumpy())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# PyNative run error.
# Support context.PYNATIVE_MODE later.
@pytest.mark.parametrize('mode', [context.GRAPH_MODE])
def test_grad_same_parameter_both_input_and_fv(mode):
    """
    Feature: Construct()/ms_function input type with back propagate.
    Description: Grad with the same Parameter used as input type and fv at the same time.
    Expectation: No exception.
    """
    context.set_context(mode=mode)
    shared_param = Parameter(Tensor(np.array([[1, 2], [3, 4]])), name='input_x')
    same_value_tensor = Tensor(np.array([[1, 2], [3, 4]]))
    # Feed the very Parameter that is also the net's fv, then a Tensor copy.
    grad_from_param = GradCellWithParameter(TestCell(shared_param))(shared_param)
    grad_from_tensor = GradCellWithParameter(TestCell(shared_param))(same_value_tensor)
    print(f'a: {grad_from_param}')
    print(f'b: {grad_from_tensor}')
    assert np.array_equal(grad_from_param[0][0].asnumpy(), grad_from_tensor[0][0].asnumpy())
    assert np.array_equal(grad_from_param[1].asnumpy(), grad_from_tensor[1].asnumpy())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestCell2(nn.Cell):
    """Tiny net with two free Parameters; output = const * param1 * param2 * input."""

    def __init__(self, param1, param2):
        super().__init__()
        self.a = Tensor(np.array([[1, 2], [3, 4]]))
        self.param1 = param1
        self.param2 = param2

    def construct(self, x):
        return self.a * self.param1 * self.param2 * x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class GradCellWithParameterTuple(nn.Cell):
    """Grad wrapper: differentiates `net` w.r.t. inputs and a ParameterTuple fv."""

    def __init__(self, net):
        super().__init__()
        self.net = net
        self.param1 = self.net.param1
        self.param2 = self.net.param2
        # Collect both fvs in a ParameterTuple for get_by_list.
        self.params = ParameterTuple([self.param1, self.param2])
        self.grad = ops.GradOperation(get_all=True, get_by_list=True)

    def construct(self, x):
        return self.grad(self.net, self.params)(x)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class GradCellWithListOfParameter(nn.Cell):
    """Grad wrapper passing a plain python list of Parameters as the fv set."""

    def __init__(self, net):
        super().__init__()
        self.net = net
        self.param1 = self.net.param1
        self.param2 = self.net.param2
        self.grad = ops.GradOperation(get_all=True, get_by_list=True)

    def construct(self, x):
        # Deliberately a list (not ParameterTuple); exercised by the skipped
        # list-or-tuple test below.
        return self.grad(self.net, [self.param1, self.param2])(x)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class GradCellWithTupleOfParameter(nn.Cell):
    """Grad wrapper passing a plain python tuple of Parameters as the fv set."""

    def __init__(self, net):
        super().__init__()
        self.net = net
        self.grad = ops.GradOperation(get_all=True, get_by_list=True)
        self.param1 = self.net.param1
        self.param2 = self.net.param2

    def construct(self, x):
        # Fixed: the original passed a list here, byte-identical to
        # GradCellWithListOfParameter — this class is meant to exercise the
        # tuple variant.
        return self.grad(self.net, (self.param1, self.param2))(x)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize('mode', [context.PYNATIVE_MODE, context.GRAPH_MODE])
def test_grad_parameter_as_input_and_fv2(mode):
    """
    Feature: Construct()/ms_function input type with back propagate.
    Description: Grad with Parameters as input type and fv. ParameterTuple as fv.
    Expectation: No exception.
    """
    context.set_context(mode=mode)
    weight1 = Parameter(Tensor(np.array([[1, 2], [3, 4]])), name='input_x1')
    weight2 = Parameter(Tensor(np.array([[1, 2], [3, 4]])), name='input_x2')
    param_in = Parameter(Tensor(np.array([[7, 8], [9, 0]])), name='input_y')
    tensor_in = Tensor(np.array([[7, 8], [9, 0]]))
    grad_from_param = GradCellWithParameterTuple(TestCell2(weight1, weight2))(param_in)
    grad_from_tensor = GradCellWithParameterTuple(TestCell2(weight1, weight2))(tensor_in)
    print(f'a: {grad_from_param}')
    print(f'b: {grad_from_tensor}')
    # Input grad and both fv grads must agree for Parameter vs Tensor input.
    assert np.array_equal(grad_from_param[0][0].asnumpy(), grad_from_tensor[0][0].asnumpy())
    assert np.array_equal(grad_from_param[1][0].asnumpy(), grad_from_tensor[1][0].asnumpy())
    assert np.array_equal(grad_from_param[1][1].asnumpy(), grad_from_tensor[1][1].asnumpy())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.skip(reason='Not support list or tuple of parameters as GradOperation inputs by now')
@pytest.mark.parametrize('mode', [context.PYNATIVE_MODE, context.GRAPH_MODE])
def test_grad_parameter_list_or_tuple(mode):
    """
    Feature: Construct()/ms_function input type with back propagate.
    Description: Grad with Parameters as input type and fv. list or tuple as fv of grad.
    Expectation: No exception.
    """
    context.set_context(mode=mode)
    weight1 = Parameter(Tensor(np.array([[1, 2], [3, 4]])), name='input_x1')
    weight2 = Parameter(Tensor(np.array([[1, 2], [3, 4]])), name='input_x2')
    tensor_in = Tensor(np.array([[7, 8], [9, 0]]))
    # Should not throw exception.
    GradCellWithListOfParameter(TestCell2(weight1, weight2))(tensor_in)
    GradCellWithTupleOfParameter(TestCell2(weight1, weight2))(tensor_in)
|
|
|
|
|