!49476 FIx Tensor Check

Merge pull request !49476 from jiaoy1224/fixbug
This commit is contained in:
i-robot 2023-03-02 10:32:27 +00:00 committed by Gitee
commit 7290e12d69
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
8 changed files with 19 additions and 16 deletions

View File

@ -2267,7 +2267,7 @@ def bool_func(*data):
@constexpr
def cast_to_int(*data):
target = data[0]
if isinstance(target, Tensor_):
if isinstance(target, (Tensor, Tensor_)):
target = Tensor(target, internal=True)
if len(data) == 1:
return int(target)
@ -2293,7 +2293,7 @@ def int_func(*data):
@constexpr
def cast_to_float(data):
if isinstance(data, Tensor_):
if isinstance(data, (Tensor, Tensor_)):
data = Tensor(data, internal=True)
return float(data)

View File

@ -51,7 +51,7 @@ def _check_element_type(value):
if not _check_element_type(element):
return False
return True
return isinstance(value, (Tensor_, int, float)) and not isinstance(value, bool)
return isinstance(value, (Tensor, Tensor_, int, float)) and not isinstance(value, bool)
def mutable(input_data, dynamic_len=False):

View File

@ -801,7 +801,7 @@ class CSRTensor(CSRTensor_):
[[2.]
[1.]]
"""
validator.check_value_type('dense_vector', dense_vector, (Tensor_,), 'CSRTensor.mv')
validator.check_value_type('dense_vector', dense_vector, (Tensor, Tensor_,), 'CSRTensor.mv')
return tensor_operator_registry.get("csr_mv")(self, dense_vector)
def mm(self, matrix: Union[Tensor, CSRTensor]) -> Union[Tensor, CSRTensor]:
@ -839,7 +839,7 @@ class CSRTensor(CSRTensor_):
"""
if isinstance(matrix, CSRTensor):
return tensor_operator_registry.get("csr_mm")(self, matrix)
validator.check_value_type('matrix', matrix, (Tensor_,), 'CSRTensor.mm')
validator.check_value_type('matrix', matrix, (Tensor, Tensor_,), 'CSRTensor.mm')
return tensor_operator_registry.get("csr_mm_akg")()(self.indptr, self.indices, self.values,
self.shape, matrix)

View File

@ -2814,7 +2814,7 @@ class Tensor(Tensor_, metaclass=_TensorMeta):
For details, please refer to :func:`mindspore.ops.gather_nd`.
"""
self._init_check()
validator.check_value_type('indices', indices, (Tensor_,), 'Tensor.gather_nd')
validator.check_value_type('indices', indices, (Tensor, Tensor_,), 'Tensor.gather_nd')
return tensor_operator_registry.get('gather_nd')(self, indices)
def gather(self, input_indices, axis):
@ -3125,7 +3125,7 @@ class Tensor(Tensor_, metaclass=_TensorMeta):
For details, please refer to :func:`mindspore.ops.gather_elements`.
"""
self._init_check()
validator.check_value_type('index', index, (Tensor_,), 'Tensor.gather_elements')
validator.check_value_type('index', index, (Tensor, Tensor_,), 'Tensor.gather_elements')
return tensor_operator_registry.get('gather_elements')(self, dim, index)
def nonzero(self):

View File

@ -1569,7 +1569,7 @@ class NdGrid:
if self.sparse:
return grids
if isinstance(grids, Tensor_):
if isinstance(grids, (Tensor, Tensor_)):
return grids
expanded = []
for grid in grids:

View File

@ -637,7 +637,7 @@ class Reshape(PrimitiveWithCheck):
# for shape is not constant
if shape is None or x is None:
return None
if isinstance(shape, Tensor_):
if isinstance(shape, (Tensor, Tensor_)):
validator.check_tensor_dtype_valid("shape", mstype.tensor_type(shape.dtype),
[mstype.int32, mstype.int64], self.name)
shape = shape.asnumpy().tolist()
@ -1573,9 +1573,9 @@ class FillV2(PrimitiveWithCheck):
self.init_prim_io_names(inputs=['shape', 'value'], outputs=['y'])
def infer_value(self, dims, x):
if isinstance(dims, Tensor_):
if isinstance(dims, (Tensor, Tensor_)):
dims = dims.asnumpy()
if isinstance(x, Tensor_):
if isinstance(x, (Tensor, Tensor_)):
x = x.asnumpy()
if dims is not None and None not in dims and x is not None:
ret = np.full(dims, x)
@ -3521,7 +3521,7 @@ class StridedSlice(PrimitiveWithInfer):
}
return slices, slice_shape[0]
if isinstance(slice_value, Tensor_):
if isinstance(slice_value, (Tensor, Tensor_)):
validator.check_tensor_dtype_valid(name, slice_input['dtype'], [mstype.int64], self.name)
slice_value = slice_value.asnumpy().tolist()
elif not isinstance(slice_value, tuple):

View File

@ -22,6 +22,7 @@ from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel
from mindspore.common import dtype as mstype
from mindspore.common.parameter import Parameter
from mindspore.common.tensor import Tensor
from mindspore.ops.primitive import prim_attr_register, Primitive, PrimitiveWithInfer
@ -494,7 +495,7 @@ class Print(Primitive):
for arg in args:
if isinstance(arg, Parameter):
print(Tensor_.__repr__(arg))
elif isinstance(arg, Tensor_):
elif isinstance(arg, (Tensor, Tensor_)):
print(arg.__repr__())
else:
print(arg)

View File

@ -25,7 +25,9 @@ from collections import defaultdict
from mindspore import log as logger
from mindspore.nn import Cell
from mindspore import context
from mindspore._c_expression import Tensor, security
from mindspore._c_expression import security
from mindspore._c_expression import Tensor as Tensor_
from mindspore.common.tensor import Tensor
from mindspore._checkparam import Validator
from mindspore.common.api import _cell_graph_executor
from mindspore.train._utils import _check_lineage_value, _check_to_numpy, _make_directory, check_value_type
@ -75,7 +77,7 @@ def _record_summary_tensor_data():
"""Checks the tag is valid for summary."""
if not isinstance(tag, str) or not tag:
raise ValueError(f'For "{summary_name}", the name must be valid string, but got "{tag}".')
if not isinstance(tensor, Tensor):
if not isinstance(tensor, (Tensor, Tensor_)):
raise TypeError(f'For "{summary_name}", the parameter "value" expect to be Tensor, '
f'but got {type(tensor).__name__}')
@ -318,7 +320,7 @@ class SummaryRecord:
if not name or not isinstance(name, str):
raise ValueError(f'For "{self.__class__.__name__}", the parameter "name" type should be str, '
f'but got {type(name)}.')
if not isinstance(value, Tensor):
if not isinstance(value, (Tensor, Tensor_)):
raise TypeError(f'For "{self.__class__.__name__}", the parameter "value" expect to be Tensor, '
f'but got {type(value).__name__}')
np_value = _check_to_numpy(plugin, value)