fix jet input check error in graph mode

chenzhuo 2022-05-25 19:18:40 +08:00
parent 906f569a93
commit 4f7c9ffc80
2 changed files with 128 additions and 27 deletions


@@ -152,8 +152,7 @@ def _convert_grad_position_type(grad_position):
     if isinstance(grad_position, tuple):
         for gp in grad_position:
             if not isinstance(gp, int):
-                raise TypeError(f"For 'F.grad', the element in 'grad_position' must be int, "
-                                f"but got {type(gp).__name__}")
+                raise TypeError(f"For 'F.grad', the element in 'grad_position' must be int.")
             if gp < 0:
                 raise ValueError("The element in grad_position must be >= 0.")
     elif isinstance(grad_position, int):
@@ -161,8 +160,7 @@ def _convert_grad_position_type(grad_position):
             raise ValueError("grad_position must be >= 0.")
         grad_position = (grad_position,)
     else:
-        raise TypeError(f"For 'F.grad', the 'grad_position' must be int or tuple, "
-                        f"but got {type(grad_position).__name__}")
+        raise TypeError(f"For 'F.grad', the 'grad_position' must be int or tuple.")
     return grad_position
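
As an aside, the check above normalizes every accepted input to a tuple of non-negative ints. A minimal standalone sketch of that behavior (plain Python with shortened error text; the name convert_grad_position is mine, not the source's):

    def convert_grad_position(grad_position):
        # Tuple input: validate each element, pass the tuple through unchanged.
        if isinstance(grad_position, tuple):
            for gp in grad_position:
                if not isinstance(gp, int):
                    raise TypeError("the element in 'grad_position' must be int.")
                if gp < 0:
                    raise ValueError("the element in grad_position must be >= 0.")
        # Int input: validate, then wrap into a 1-tuple.
        elif isinstance(grad_position, int):
            if grad_position < 0:
                raise ValueError("grad_position must be >= 0.")
            grad_position = (grad_position,)
        else:
            raise TypeError("'grad_position' must be int or tuple.")
        return grad_position

    assert convert_grad_position(1) == (1,)
    assert convert_grad_position((0, 2)) == (0, 2)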
@@ -215,25 +213,22 @@ def grad(fn, grad_position=0, sens_param=False):
     return grad_by_position(fn, None, grad_position)

 @constexpr
 def _trans_jet_inputs(primals_item, series_item):
     """Trans inputs of jet"""
     value_type = [mstype.int32, mstype.int64, mstype.float32, mstype.float64]
     if not dtype(primals_item) in value_type or dtype(primals_item) != dtype(series_item):
         raise TypeError(f"For `F.jet`, the elements' types of primals and series must be the same and belong to "
-                        f"`mstype.int32, mstype.int64, mstype.float32, mstype.float64`, but got"
-                        f" {dtype(primals_item)} and {dtype(series_item)}.")
+                        f"`mstype.int32, mstype.int64, mstype.float32, mstype.float64`, but got other dtype.")
     if dtype(primals_item) in [mstype.int32, mstype.int64]:
         return cast(primals_item, mstype.float32), cast(series_item, mstype.float32)
     return primals_item, series_item

 @constexpr
 def _check_jet_inputs(primals, series):
     """Check inputs of jet"""
-    if not isinstance(primals, type(series)) or not isinstance(primals, (Tensor, tuple)):
-        raise TypeError(f"For 'F.jet', the 'primals' and `series` must be both Tensor or tuple, "
-                        f"but got {type(primals).__name__} and {type(series).__name__}.")
+    if not (isinstance(primals, Tensor) and isinstance(series, Tensor)) and \
+            not (isinstance(primals, tuple) and isinstance(series, tuple)):
+        raise TypeError(f"For 'F.jet', the 'primals' and `series` must be both Tensor or tuple.")
     if isinstance(primals, Tensor):
         if primals.shape == series.shape[1:]:
             return _trans_jet_inputs(primals, series)
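
The Tensor branch above encodes the shape contract for jet: for an order-k expansion, series stacks k coefficient tensors along a new leading axis, so primals.shape must equal series.shape[1:]. A sketch of the contract, with NumPy arrays standing in for Tensors (values taken from the new test in this commit):

    import numpy as np

    primals = np.array([2., 2.], dtype=np.float32)       # shape (2,)
    series = np.array([[1., 1.],                         # 1st-order coefficients
                       [0., 0.],                         # 2nd-order coefficients
                       [0., 0.]], dtype=np.float32)      # 3rd-order; shape (3, 2)
    assert primals.shape == series.shape[1:]             # passes the shape test above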
@@ -254,6 +249,11 @@ def _check_jet_inputs(primals, series):
 _taylor = _TaylorOperation()

+def _preprocess_jet(x, y):
+    concat_op = P.Concat()
+    return concat_op((expand_dims(x, 0), y))
+
 def jet(fn, primals, series):
     """
     This function is designed to calculate the higher order differentiation of given composite function. To figure out
@@ -316,30 +316,30 @@ def jet(fn, primals, series):
     """
     primals, series = _check_jet_inputs(primals, series)
     derivative_fn = _taylor(fn)
-    concat_op = P.Concat()
     if isinstance(primals, list) and list_len(primals) > 1:
-        inputs = list(map(lambda x, y: concat_op(((expand_dims(x, 0), y))), primals, series))
+        inputs = map(_preprocess_jet, primals, series)
         outputs = derivative_fn(*inputs)
     else:
-        inputs = concat_op((expand_dims(primals, 0), series))
+        inputs = _preprocess_jet(primals, series)
         outputs = derivative_fn(inputs)
-    if isinstance(outputs, list) and list_len(outputs) > 1:
-        out_primals = [element[0] for element in outputs]
-        out_series = [element[1:] for element in outputs]
+    if isinstance(outputs, tuple) and tuple_len(outputs) > 1:
+        out_primals = []
+        out_series = []
+        for element in outputs:
+            out_primals.append(element[0])
+            out_series.append(element[1:])
     else:
         out_primals = outputs[0]
         out_series = outputs[1:]
     return out_primals, out_series

 @constexpr
 def _trans_derivative_inputs(primals_item):
     """Trans inputs of derivative"""
     value_type = [mstype.int32, mstype.int64, mstype.float32, mstype.float64]
     if not dtype(primals_item) in value_type:
         raise TypeError(f"For `F.derivative`, the elements of primals must belong to "
-                        f"`mstype.int32, mstype.int64, mstype.float32, mstype.float64`, but got"
-                        f" {dtype(primals_item)}.")
+                        f"`mstype.int32, mstype.int64, mstype.float32, mstype.float64`, but got other dtype.")
     if dtype(primals_item) in [mstype.int32, mstype.int64]:
         return cast(primals_item, mstype.float32)
     return primals_item
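
For reference, a usage sketch of jet after this refactor, modeled on the graph-mode test added in this commit. The mindspore.ops.functional import path is an assumption based on where this diff lives; everything else mirrors the test code:

    import numpy as np
    import mindspore.nn as nn
    import mindspore.ops.operations as P
    from mindspore import Tensor, context
    from mindspore.ops.functional import jet  # assumed import path

    context.set_context(mode=context.GRAPH_MODE)

    class SinNet(nn.Cell):
        def __init__(self):
            super(SinNet, self).__init__()
            self.sin = P.Sin()

        def construct(self, x):
            return self.sin(x)

    primals = Tensor(np.array([1., 1.], dtype=np.float32))
    # Order-3 input: unit seed in the 1st-order row, zeros above.
    series = Tensor(np.array([[1., 1.], [0., 0.], [0., 0.]], dtype=np.float32))
    out_primals, out_series = jet(SinNet(), primals, series)
    # out_primals is sin(primals); out_series holds the 1st..3rd order terms.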
@@ -349,12 +349,22 @@ def _trans_derivative_inputs(primals_item):
 def _check_derivative_order(order):
     """check input order of derivative"""
     if not isinstance(order, int):
-        raise TypeError(f"For `F.derivative`, the type of order must be int, but got {type(order).__name__}.")
+        raise TypeError(f"For `F.derivative`, the type of order must be int.")
     if order < 1:
         raise ValueError(f"For `F.derivative`, value of order should not be less than 1, but got {order}.")
     return True

+def _preprocess_derivate_order_one(x):
+    concat_op = P.Concat()
+    return concat_op((expand_dims(x, 0), ones((1,) + x.shape, dtype(x))))
+
+def _preprocess_derivate_order_more(x, order):
+    concat_op = P.Concat()
+    return concat_op((x, zeros((order - 1,) + x[0].shape, dtype(x))))
+
 def derivative(fn, primals, order):
     """
     This function is designed to calculate the higher order differentiation of given composite function. To figure out
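
The two helpers added in this hunk build the augmented input for the Taylor transform: row 0 carries the primal point, row 1 a first-order seed of ones, and rows 2..order zeros, giving shape (order + 1,) + x.shape. A NumPy sketch of the same layout (illustrative only; build_taylor_input is my name, not the source's):

    import numpy as np

    def build_taylor_input(x, order):
        # Mirrors _preprocess_derivate_order_one: primal row plus a ones seed row.
        out = np.concatenate([x[None], np.ones((1,) + x.shape, x.dtype)])
        if order > 1:
            # Mirrors _preprocess_derivate_order_more: zero rows for the higher orders.
            out = np.concatenate([out, np.zeros((order - 1,) + x.shape, x.dtype)])
        return out

    x = np.array([1., 1.], dtype=np.float32)
    assert build_taylor_input(x, 3).shape == (4, 2)   # (order + 1,) + x.shape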
@@ -413,11 +423,15 @@ def derivative(fn, primals, order):
     series_one = 1
     _check_derivative_order(order)
     if isinstance(primals, tuple):
-        trans_primals = [_trans_derivative_inputs(item) for item in primals]
-        inputs = list(map(lambda x: concat_op((expand_dims(x, 0), ones((1,) + x.shape, dtype(x)))), trans_primals))
+        trans_primals = map(_trans_derivative_inputs, primals)
+        inputs = map(_preprocess_derivate_order_one, trans_primals)
         if order > 1:
-            inputs = list(map(lambda x: concat_op((x, zeros((order - 1,) + x[0].shape, dtype(x)))), inputs))
-        outputs = derivative_fn(*inputs)
+            processed_inputs = []
+            for element in inputs:
+                processed_inputs.append(_preprocess_derivate_order_more(element, order))
+            outputs = derivative_fn(*processed_inputs)
+        else:
+            outputs = derivative_fn(*inputs)
     else:
         primals = _trans_derivative_inputs(primals)
         series = zeros((order,) + primals.shape, dtype(primals))
@@ -425,8 +439,11 @@ def derivative(fn, primals, order):
         inputs = concat_op((expand_dims(primals, 0), series))
         outputs = derivative_fn(inputs)
     if isinstance(outputs, tuple) and tuple_len(outputs) > 1:
-        out_primals = [element[0] for element in outputs]
-        out_series = [element[-1] for element in outputs]
+        out_primals = []
+        out_series = []
+        for element in outputs:
+            out_primals.append(element[0])
+            out_series.append(element[-1])
     else:
         out_primals = outputs[0]
         out_series = outputs[-1]
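
And a matching usage sketch for derivative with multiple inputs and outputs, modeled on the new test added below (same assumed import path as above):

    import numpy as np
    import mindspore.nn as nn
    import mindspore.ops.operations as P
    from mindspore import Tensor, context
    from mindspore.ops.functional import derivative  # assumed import path

    context.set_context(mode=context.GRAPH_MODE)

    class SinCosNet(nn.Cell):
        def __init__(self):
            super(SinCosNet, self).__init__()
            self.sin = P.Sin()
            self.cos = P.Cos()

        def construct(self, x, y):
            return self.sin(x), self.cos(y)

    x = Tensor(np.array([1., 1.], dtype=np.float32))
    y = Tensor(np.array([1., 1.], dtype=np.float32))
    out_primals, out_series = derivative(SinCosNet(), (x, y), 3)
    # out_primals -> (sin(x), cos(y)); out_series -> the order-3 term of each output,
    # e.g. -cos(x) for the sine branch (about -0.5403), matching the test's expected values.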


@@ -39,6 +39,18 @@ class MultipleInputSingleOutputNet(nn.Cell):
         return out


+class MultipleInputMultipleOutputNet(nn.Cell):
+    def __init__(self):
+        super(MultipleInputMultipleOutputNet, self).__init__()
+        self.sin = P.Sin()
+        self.cos = P.Cos()
+
+    def construct(self, x, y):
+        out1 = self.sin(x)
+        out2 = self.cos(y)
+        return out1, out2
+
+
 class SingleInputSingleOutputNet(nn.Cell):
     def __init__(self):
         super(SingleInputSingleOutputNet, self).__init__()
@@ -174,3 +186,75 @@ def test_derivative_multiple_input_single_output_graph_mode():
     out_primals, out_series = derivative(net, primals, order)
     assert np.allclose(out_primals.asnumpy(), expected_primals, atol=1.e-4)
     assert np.allclose(out_series.asnumpy(), expected_series, atol=1.e-4)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_jet_construct_graph_mode():
+    """
+    Features: Function jet
+    Description: Test jet in construct with multiple inputs in graph mode.
+    Expectation: No exception.
+    """
+    class Net(nn.Cell):
+        def __init__(self, net):
+            super(Net, self).__init__()
+            self.net = net
+
+        def construct(self, x, y):
+            res_primals, res_series = jet(self.net, x, y)
+            return res_primals, res_series
+
+    primals = Tensor([2., 2.])
+    series = Tensor([[1., 1.], [0., 0.], [0., 0.]])
+    net = SingleInputSingleOutputWithScalarNet()
+    hod_net = Net(net)
+    expected_primals = np.array([10.328085, 10.328085]).astype(np.float32)
+    expected_series = np.array([[-3.1220534, -3.1220534], [6.0652323, 6.0652323],
+                                [-18.06463, -18.06463]]).astype(np.float32)
+    out_primals, out_series = hod_net(primals, series)
+    assert np.allclose(out_primals.asnumpy(), expected_primals, atol=1.e-4)
+    assert np.allclose(out_series.asnumpy(), expected_series, atol=1.e-4)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_derivative_construct_graph_mode():
+    """
+    Features: Function derivative
+    Description: Test derivative in construct with multiple inputs in graph mode.
+    Expectation: No exception.
+    """
+    class Net(nn.Cell):
+        def __init__(self, net, order):
+            super(Net, self).__init__()
+            self.net = net
+            self.order = order
+
+        def construct(self, x, y):
+            res_primals, res_series = derivative(self.net, (x, y), self.order)
+            return res_primals, res_series
+
+    primals_x = Tensor([1., 1.])
+    primals_y = Tensor([1., 1.])
+    net = MultipleInputMultipleOutputNet()
+    hod_net = Net(net, order=3)
+    expected_primals_x = np.array([0.841470957, 0.841470957]).astype(np.float32)
+    expected_primals_y = np.array([0.540302277, 0.540302277]).astype(np.float32)
+    expected_series_x = np.array([-0.540302277, -0.540302277]).astype(np.float32)
+    expected_series_y = np.array([0.841470957, 0.841470957]).astype(np.float32)
+    out_primals, out_series = hod_net(primals_x, primals_y)
+    assert np.allclose(out_primals[0].asnumpy(), expected_primals_x, atol=1.e-4)
+    assert np.allclose(out_primals[1].asnumpy(), expected_primals_y, atol=1.e-4)
+    assert np.allclose(out_series[0].asnumpy(), expected_series_x, atol=1.e-4)
+    assert np.allclose(out_series[1].asnumpy(), expected_series_y, atol=1.e-4)