forked from mindspore-Ecosystem/mindspore
fix jet input check error in graph mode
This commit is contained in:
parent
906f569a93
commit
4f7c9ffc80
|
@ -152,8 +152,7 @@ def _convert_grad_position_type(grad_position):
|
|||
if isinstance(grad_position, tuple):
|
||||
for gp in grad_position:
|
||||
if not isinstance(gp, int):
|
||||
raise TypeError(f"For 'F.grad', the element in 'grad_position' must be int, "
|
||||
f"but got {type(gp).__name__}")
|
||||
raise TypeError(f"For 'F.grad', the element in 'grad_position' must be int.")
|
||||
if gp < 0:
|
||||
raise ValueError("The element in grad_position must be >= 0.")
|
||||
elif isinstance(grad_position, int):
|
||||
|
@ -161,8 +160,7 @@ def _convert_grad_position_type(grad_position):
|
|||
raise ValueError("grad_position must be >= 0.")
|
||||
grad_position = (grad_position,)
|
||||
else:
|
||||
raise TypeError(f"For 'F.grad', the 'grad_position' must be int or tuple, "
|
||||
f"but got {type(grad_position).__name__}")
|
||||
raise TypeError(f"For 'F.grad', the 'grad_position' must be int or tuple.")
|
||||
return grad_position
|
||||
|
||||
|
||||
|
@ -215,25 +213,22 @@ def grad(fn, grad_position=0, sens_param=False):
|
|||
return grad_by_position(fn, None, grad_position)
|
||||
|
||||
|
||||
@constexpr
|
||||
def _trans_jet_inputs(primals_item, series_item):
|
||||
"""Trans inputs of jet"""
|
||||
value_type = [mstype.int32, mstype.int64, mstype.float32, mstype.float64]
|
||||
if not dtype(primals_item) in value_type or dtype(primals_item) != dtype(series_item):
|
||||
raise TypeError(f"For `F.jet`, the elements' types of primals and series must be the same and belong to "
|
||||
f"`mstype.int32, mstype.int64, mstype.float32, mstype.float64`, but got"
|
||||
f" {dtype(primals_item)} and {dtype(series_item)}.")
|
||||
f"`mstype.int32, mstype.int64, mstype.float32, mstype.float64`, but got other dtype.")
|
||||
if dtype(primals_item) in [mstype.int32, mstype.int64]:
|
||||
return cast(primals_item, mstype.float32), cast(series_item, mstype.float32)
|
||||
return primals_item, series_item
|
||||
|
||||
|
||||
@constexpr
|
||||
def _check_jet_inputs(primals, series):
|
||||
"""Check inputs of jet"""
|
||||
if not isinstance(primals, type(series)) or not isinstance(primals, (Tensor, tuple)):
|
||||
raise TypeError(f"For 'F.jet', the 'primals' and `series` must be both Tensor or tuple, "
|
||||
f"but got {type(primals).__name__} and {type(series).__name__}.")
|
||||
if not (isinstance(primals, Tensor) and isinstance(series, Tensor)) and \
|
||||
not (isinstance(primals, tuple) and isinstance(series, tuple)):
|
||||
raise TypeError(f"For 'F.jet', the 'primals' and `series` must be both Tensor or tuple.")
|
||||
if isinstance(primals, Tensor):
|
||||
if primals.shape == series.shape[1:]:
|
||||
return _trans_jet_inputs(primals, series)
|
||||
|
@ -254,6 +249,11 @@ def _check_jet_inputs(primals, series):
|
|||
_taylor = _TaylorOperation()
|
||||
|
||||
|
||||
def _preprocess_jet(x, y):
    """Build the jet input tensor: x gains a leading axis and is stacked above y."""
    # Concat along axis 0 so row 0 is the primal value and the rest is the series.
    return P.Concat()((expand_dims(x, 0), y))
|
||||
|
||||
|
||||
def jet(fn, primals, series):
|
||||
"""
|
||||
This function is designed to calculate the higher order differentiation of given composite function. To figure out
|
||||
|
@ -316,30 +316,30 @@ def jet(fn, primals, series):
|
|||
"""
|
||||
primals, series = _check_jet_inputs(primals, series)
|
||||
derivative_fn = _taylor(fn)
|
||||
concat_op = P.Concat()
|
||||
if isinstance(primals, list) and list_len(primals) > 1:
|
||||
inputs = list(map(lambda x, y: concat_op(((expand_dims(x, 0), y))), primals, series))
|
||||
inputs = map(_preprocess_jet, primals, series)
|
||||
outputs = derivative_fn(*inputs)
|
||||
else:
|
||||
inputs = concat_op((expand_dims(primals, 0), series))
|
||||
inputs = _preprocess_jet(primals, series)
|
||||
outputs = derivative_fn(inputs)
|
||||
if isinstance(outputs, list) and list_len(outputs) > 1:
|
||||
out_primals = [element[0] for element in outputs]
|
||||
out_series = [element[1:] for element in outputs]
|
||||
if isinstance(outputs, tuple) and tuple_len(outputs) > 1:
|
||||
out_primals = []
|
||||
out_series = []
|
||||
for element in outputs:
|
||||
out_primals.append(element[0])
|
||||
out_series.append(element[1:])
|
||||
else:
|
||||
out_primals = outputs[0]
|
||||
out_series = outputs[1:]
|
||||
return out_primals, out_series
|
||||
|
||||
|
||||
@constexpr
|
||||
def _trans_derivative_inputs(primals_item):
|
||||
"""Trans inputs of derivative"""
|
||||
value_type = [mstype.int32, mstype.int64, mstype.float32, mstype.float64]
|
||||
if not dtype(primals_item) in value_type:
|
||||
raise TypeError(f"For `F.derivative`, the elements of primals must belong to "
|
||||
f"`mstype.int32, mstype.int64, mstype.float32, mstype.float64`, but got"
|
||||
f" {dtype(primals_item)}.")
|
||||
f"`mstype.int32, mstype.int64, mstype.float32, mstype.float64`, but got other dtype.")
|
||||
if dtype(primals_item) in [mstype.int32, mstype.int64]:
|
||||
return cast(primals_item, mstype.float32)
|
||||
return primals_item
|
||||
|
@ -349,12 +349,22 @@ def _trans_derivative_inputs(primals_item):
|
|||
def _check_derivative_order(order):
|
||||
"""check input order of derivative"""
|
||||
if not isinstance(order, int):
|
||||
raise TypeError(f"For `F.derivative`, the type of order must be int, but got {type(order).__name__}.")
|
||||
raise TypeError(f"For `F.derivative`, the type of order must be int.")
|
||||
if order < 1:
|
||||
raise ValueError(f"For `F.derivative`, value of order should not be less than 1, but got {order}.")
|
||||
return True
|
||||
|
||||
|
||||
def _preprocess_derivate_order_one(x):
    """Seed an order-one derivative input: x stacked above a matching ones tensor."""
    seed = ones((1,) + x.shape, dtype(x))
    return P.Concat()((expand_dims(x, 0), seed))
|
||||
|
||||
|
||||
def _preprocess_derivate_order_more(x, order):
    """Extend a seeded input with zeros for the remaining (order - 1) series slots."""
    padding = zeros((order - 1,) + x[0].shape, dtype(x))
    return P.Concat()((x, padding))
|
||||
|
||||
|
||||
def derivative(fn, primals, order):
|
||||
"""
|
||||
This function is designed to calculate the higher order differentiation of given composite function. To figure out
|
||||
|
@ -413,11 +423,15 @@ def derivative(fn, primals, order):
|
|||
series_one = 1
|
||||
_check_derivative_order(order)
|
||||
if isinstance(primals, tuple):
|
||||
trans_primals = [_trans_derivative_inputs(item) for item in primals]
|
||||
inputs = list(map(lambda x: concat_op((expand_dims(x, 0), ones((1,) + x.shape, dtype(x)))), trans_primals))
|
||||
trans_primals = map(_trans_derivative_inputs, primals)
|
||||
inputs = map(_preprocess_derivate_order_one, trans_primals)
|
||||
if order > 1:
|
||||
inputs = list(map(lambda x: concat_op((x, zeros((order - 1,) + x[0].shape, dtype(x)))), inputs))
|
||||
outputs = derivative_fn(*inputs)
|
||||
processed_inputs = []
|
||||
for element in inputs:
|
||||
processed_inputs.append(_preprocess_derivate_order_more(element, order))
|
||||
outputs = derivative_fn(*processed_inputs)
|
||||
else:
|
||||
outputs = derivative_fn(*inputs)
|
||||
else:
|
||||
primals = _trans_derivative_inputs(primals)
|
||||
series = zeros((order,) + primals.shape, dtype(primals))
|
||||
|
@ -425,8 +439,11 @@ def derivative(fn, primals, order):
|
|||
inputs = concat_op((expand_dims(primals, 0), series))
|
||||
outputs = derivative_fn(inputs)
|
||||
if isinstance(outputs, tuple) and tuple_len(outputs) > 1:
|
||||
out_primals = [element[0] for element in outputs]
|
||||
out_series = [element[-1] for element in outputs]
|
||||
out_primals = []
|
||||
out_series = []
|
||||
for element in outputs:
|
||||
out_primals.append(element[0])
|
||||
out_series.append(element[-1])
|
||||
else:
|
||||
out_primals = outputs[0]
|
||||
out_series = outputs[-1]
|
||||
|
|
|
@ -39,6 +39,18 @@ class MultipleInputSingleOutputNet(nn.Cell):
|
|||
return out
|
||||
|
||||
|
||||
class MultipleInputMultipleOutputNet(nn.Cell):
    """Toy two-in/two-out cell: maps (x, y) to (sin(x), cos(y))."""

    def __init__(self):
        super(MultipleInputMultipleOutputNet, self).__init__()
        self.sin = P.Sin()
        self.cos = P.Cos()

    def construct(self, x, y):
        # Each output depends on exactly one input, which keeps the expected
        # higher-order derivatives in the tests easy to state in closed form.
        return self.sin(x), self.cos(y)
|
||||
|
||||
|
||||
class SingleInputSingleOutputNet(nn.Cell):
|
||||
def __init__(self):
|
||||
super(SingleInputSingleOutputNet, self).__init__()
|
||||
|
@ -174,3 +186,75 @@ def test_derivative_multiple_input_single_output_graph_mode():
|
|||
out_primals, out_series = derivative(net, primals, order)
|
||||
assert np.allclose(out_primals.asnumpy(), expected_primals, atol=1.e-4)
|
||||
assert np.allclose(out_series.asnumpy(), expected_series, atol=1.e-4)
|
||||
|
||||
|
||||
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jet_construct_graph_mode():
    """
    Features: Function jet
    Description: Test jet in construct with multiple inputs in graph mode.
    Expectation: No exception.
    """
    class JetNet(nn.Cell):
        """Wraps an inner cell and calls `jet` on it inside construct."""

        def __init__(self, inner):
            super(JetNet, self).__init__()
            self.net = inner

        def construct(self, x, y):
            return jet(self.net, x, y)

    inputs = Tensor([2., 2.])
    taylor_series = Tensor([[1., 1.], [0., 0.], [0., 0.]])
    hod_net = JetNet(SingleInputSingleOutputWithScalarNet())
    want_primals = np.array([10.328085, 10.328085]).astype(np.float32)
    want_series = np.array([[-3.1220534, -3.1220534], [6.0652323, 6.0652323],
                            [-18.06463, -18.06463]]).astype(np.float32)
    got_primals, got_series = hod_net(inputs, taylor_series)
    assert np.allclose(got_primals.asnumpy(), want_primals, atol=1.e-4)
    assert np.allclose(got_series.asnumpy(), want_series, atol=1.e-4)
|
||||
|
||||
|
||||
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_derivative_construct_graph_mode():
    """
    Features: Function derivative
    Description: Test derivative in construct with multiple inputs in graph mode.
    Expectation: No exception.
    """
    class DerivativeNet(nn.Cell):
        """Wraps an inner cell and applies `derivative` of a fixed order in construct."""

        def __init__(self, inner, order):
            super(DerivativeNet, self).__init__()
            self.net = inner
            self.order = order

        def construct(self, x, y):
            return derivative(self.net, (x, y), self.order)

    first_input = Tensor([1., 1.])
    second_input = Tensor([1., 1.])
    hod_net = DerivativeNet(MultipleInputMultipleOutputNet(), order=3)
    # sin/cos derivatives at 1.0: third-order terms of sin(x) and cos(y).
    want_primals_x = np.array([0.841470957, 0.841470957]).astype(np.float32)
    want_primals_y = np.array([0.540302277, 0.540302277]).astype(np.float32)
    want_series_x = np.array([-0.540302277, -0.540302277]).astype(np.float32)
    want_series_y = np.array([0.841470957, 0.841470957]).astype(np.float32)
    got_primals, got_series = hod_net(first_input, second_input)
    assert np.allclose(got_primals[0].asnumpy(), want_primals_x, atol=1.e-4)
    assert np.allclose(got_primals[1].asnumpy(), want_primals_y, atol=1.e-4)
    assert np.allclose(got_series[0].asnumpy(), want_series_x, atol=1.e-4)
    assert np.allclose(got_series[1].asnumpy(), want_series_y, atol=1.e-4)
|
||||
|
|
Loading…
Reference in New Issue