modify _check_compile_dynamic_shape

This commit is contained in:
changzherui 2022-09-08 01:42:21 +08:00
parent 3b68297661
commit 637eb20bde
3 changed files with 23 additions and 41 deletions

View File

@ -1,5 +1,5 @@
mindspore.conver_model
======================
mindspore.convert_model
=======================
.. py:function:: mindspore.convert_model(mindir_file, convert_file, file_format)

View File

@ -228,7 +228,7 @@ GeTensor ConvertAnyUtil(const ValuePtr &value, const AnyTraits<AnyValue>) {
(void)str_tensor.SetData(v);
return str_tensor;
} else {
MS_LOG(WARNING) << "Unsupported value type: " << value->type_name()
MS_LOG(INFO) << "Unsupported value type: " << value->type_name()
<< " to convert to tensor. Value: " << value->ToString();
}
return GeTensor();

View File

@ -2175,44 +2175,26 @@ class Cell(Cell_):
Args:
inputs (tuple): Inputs of the Cell object.
"""
len_inputs = len(inputs)
len_dynamic_shape_inputs = len(self._dynamic_shape_inputs)
if len_dynamic_shape_inputs != len_inputs:
raise ValueError(
f"For 'set_inputs', the Length of Tensor must be {len_inputs}, but got {len_dynamic_shape_inputs}."
)
for tensor_index in range(len_dynamic_shape_inputs):
i_dynamic_shape_inputs = self._dynamic_shape_inputs[tensor_index]
i_inputs = inputs[tensor_index]
if isinstance(i_dynamic_shape_inputs, Tensor):
if not isinstance(i_inputs, Tensor):
if len(self._dynamic_shape_inputs) != len(inputs):
raise ValueError("The number of 'set_input' Tensor must be equal to network's inputs.")
for net_input, set_input in zip(inputs, self._dynamic_shape_inputs):
if isinstance(set_input, Tensor):
if not isinstance(net_input, Tensor):
raise TypeError(
f"When using 'set_inputs', the type of network inputs and set_inputs must be the same. But "
f"got {type(i_inputs)} and {type(i_dynamic_shape_inputs)}."
)
if i_dynamic_shape_inputs.dtype is not i_inputs.dtype:
raise TypeError(
f"For 'set_inputs', the DataType of Tensor must be {i_inputs.dtype}, but got "
f"{i_dynamic_shape_inputs.dtype}."
)
set_inputs_shape = list(i_dynamic_shape_inputs.shape)
if i_inputs.shape == ():
inputs_shape = i_inputs
else:
inputs_shape = list(i_inputs.shape)
if len(inputs_shape) != len(set_inputs_shape):
f"The 'set_inputs' type must be the same as network's input, "
f"but got {type(set_input)} and {type(net_input)}.")
if set_input.dtype is not net_input.dtype:
raise ValueError(
f"For 'set_inputs' the Dimension of Tensor shape must be {len(inputs_shape)}, but got "
f"{len(set_inputs_shape)}."
)
for shape_index in i_dynamic_shape_inputs.shape:
if shape_index != -1:
dynamic_index = i_dynamic_shape_inputs.shape.index(shape_index)
if set_inputs_shape[dynamic_index] != inputs_shape[dynamic_index]:
f"For 'set_inputs' the type of Tensor must be the same as network's input, "
f"but got {set_input.dtype()} and {net_input.dtype()}.")
if net_input.dim() != 0 and set_input.dim() != net_input.dim():
raise ValueError(
f"For 'Length of Tensor shape', the value must be the same with that of inputs, but"
f" got {i_dynamic_shape_inputs.shape}."
)
f"For 'set_inputs' the dims of Tensor must be the same as network's input, "
f"but got {set_input.dim()} and {net_input.dim()}.")
if not all([ele1 in (-1, ele2) for ele1, ele2 in zip(set_input.shape, net_input.shape)]):
raise ValueError(
f"For 'set_inputs' the shape of Tensor must be the same as network's input, "
f"but got {set_input.shape} and {net_input.shape}.")
class GraphCell(Cell):