forked from mindspore-Ecosystem/mindspore
modify _check_compile_dynamic_shape
This commit is contained in:
parent
3b68297661
commit
637eb20bde
|
@ -1,5 +1,5 @@
|
|||
mindspore.conver_model
|
||||
======================
|
||||
mindspore.convert_model
|
||||
=======================
|
||||
|
||||
.. py:function:: mindspore.convert_model(mindir_file, convert_file, file_format)
|
||||
|
||||
|
|
|
@ -228,7 +228,7 @@ GeTensor ConvertAnyUtil(const ValuePtr &value, const AnyTraits<AnyValue>) {
|
|||
(void)str_tensor.SetData(v);
|
||||
return str_tensor;
|
||||
} else {
|
||||
MS_LOG(WARNING) << "Unsupported value type: " << value->type_name()
|
||||
MS_LOG(INFO) << "Unsupported value type: " << value->type_name()
|
||||
<< " to convert to tensor. Value: " << value->ToString();
|
||||
}
|
||||
return GeTensor();
|
||||
|
|
|
@ -2175,44 +2175,26 @@ class Cell(Cell_):
|
|||
Args:
|
||||
inputs (tuple): Inputs of the Cell object.
|
||||
"""
|
||||
len_inputs = len(inputs)
|
||||
len_dynamic_shape_inputs = len(self._dynamic_shape_inputs)
|
||||
if len_dynamic_shape_inputs != len_inputs:
|
||||
raise ValueError(
|
||||
f"For 'set_inputs', the Length of Tensor must be {len_inputs}, but got {len_dynamic_shape_inputs}."
|
||||
)
|
||||
for tensor_index in range(len_dynamic_shape_inputs):
|
||||
i_dynamic_shape_inputs = self._dynamic_shape_inputs[tensor_index]
|
||||
i_inputs = inputs[tensor_index]
|
||||
if isinstance(i_dynamic_shape_inputs, Tensor):
|
||||
if not isinstance(i_inputs, Tensor):
|
||||
if len(self._dynamic_shape_inputs) != len(inputs):
|
||||
raise ValueError("The number of 'set_input' Tensor must be equal to network's inputs.")
|
||||
for net_input, set_input in zip(inputs, self._dynamic_shape_inputs):
|
||||
if isinstance(set_input, Tensor):
|
||||
if not isinstance(net_input, Tensor):
|
||||
raise TypeError(
|
||||
f"When using 'set_inputs', the type of network inputs and set_inputs must be the same. But "
|
||||
f"got {type(i_inputs)} and {type(i_dynamic_shape_inputs)}."
|
||||
)
|
||||
if i_dynamic_shape_inputs.dtype is not i_inputs.dtype:
|
||||
raise TypeError(
|
||||
f"For 'set_inputs', the DataType of Tensor must be {i_inputs.dtype}, but got "
|
||||
f"{i_dynamic_shape_inputs.dtype}."
|
||||
)
|
||||
set_inputs_shape = list(i_dynamic_shape_inputs.shape)
|
||||
if i_inputs.shape == ():
|
||||
inputs_shape = i_inputs
|
||||
else:
|
||||
inputs_shape = list(i_inputs.shape)
|
||||
if len(inputs_shape) != len(set_inputs_shape):
|
||||
f"The 'set_inputs' type must be the same as network's input, "
|
||||
f"but got {type(set_input)} and {type(net_input)}.")
|
||||
if set_input.dtype is not net_input.dtype:
|
||||
raise ValueError(
|
||||
f"For 'set_inputs' the Dimension of Tensor shape must be {len(inputs_shape)}, but got "
|
||||
f"{len(set_inputs_shape)}."
|
||||
)
|
||||
for shape_index in i_dynamic_shape_inputs.shape:
|
||||
if shape_index != -1:
|
||||
dynamic_index = i_dynamic_shape_inputs.shape.index(shape_index)
|
||||
if set_inputs_shape[dynamic_index] != inputs_shape[dynamic_index]:
|
||||
f"For 'set_inputs' the type of Tensor must be the same as network's input, "
|
||||
f"but got {set_input.dtype()} and {net_input.dtype()}.")
|
||||
if net_input.dim() != 0 and set_input.dim() != net_input.dim():
|
||||
raise ValueError(
|
||||
f"For 'Length of Tensor shape', the value must be the same with that of inputs, but"
|
||||
f" got {i_dynamic_shape_inputs.shape}."
|
||||
)
|
||||
f"For 'set_inputs' the dims of Tensor must be the same as network's input, "
|
||||
f"but got {set_input.dim()} and {net_input.dim()}.")
|
||||
if not all([ele1 in (-1, ele2) for ele1, ele2 in zip(set_input.shape, net_input.shape)]):
|
||||
raise ValueError(
|
||||
f"For 'set_inputs' the shape of Tensor must be the same as network's input, "
|
||||
f"but got {set_input.shape} and {net_input.shape}.")
|
||||
|
||||
|
||||
class GraphCell(Cell):
|
||||
|
|
Loading…
Reference in New Issue