intercept tuple and list for print

This commit is contained in:
buxue 2021-03-15 22:09:19 +08:00
parent c0f41deeae
commit 4212666399
1 changed files with 6 additions and 9 deletions

View File

@ -367,15 +367,15 @@ class Print(PrimitiveWithInfer):
Note:
In pynative mode, please use python print function.
In graph mode, the bool, int, float, tuple, and list would be converted into Tensor to print,
In graph mode, the bool, int and float would be converted into Tensor to print,
str remains unchanged.
Inputs:
- **input_x** (Union[Tensor, bool, int, float, str, tuple, list]) - The graph node to attach to.
- **input_x** (Union[Tensor, bool, int, float, str]) - The graph node to attach to.
Supports multiple inputs which are separated by ','.
Raises:
TypeError: If `input_x` is not one of the following: Tensor, bool, int, float, str, tuple, list.
TypeError: If `input_x` is not one of the following: Tensor, bool, int, float, str.
Supported Platforms:
``Ascend`` ``GPU``
@ -415,12 +415,9 @@ class Print(PrimitiveWithInfer):
def infer_dtype(self, *inputs):
# check argument types except the last one (io state).
for ele in inputs[:-1]:
if isinstance(ele, (tuple, list)):
self.infer_dtype(*ele)
else:
validator.check_subclass("input", ele,
[mstype.tensor, mstype.int_, mstype.float_, mstype.bool_, mstype.string],
self.name)
validator.check_subclass("input", ele,
[mstype.tensor, mstype.int_, mstype.float_, mstype.bool_, mstype.string],
self.name)
return mstype.int32