forked from mindspore-Ecosystem/mindspore
!49476 FIx Tensor Check
Merge pull request !49476 from jiaoy1224/fixbug
This commit is contained in:
commit
7290e12d69
|
@ -2267,7 +2267,7 @@ def bool_func(*data):
|
|||
@constexpr
|
||||
def cast_to_int(*data):
|
||||
target = data[0]
|
||||
if isinstance(target, Tensor_):
|
||||
if isinstance(target, (Tensor, Tensor_)):
|
||||
target = Tensor(target, internal=True)
|
||||
if len(data) == 1:
|
||||
return int(target)
|
||||
|
@ -2293,7 +2293,7 @@ def int_func(*data):
|
|||
|
||||
@constexpr
|
||||
def cast_to_float(data):
|
||||
if isinstance(data, Tensor_):
|
||||
if isinstance(data, (Tensor, Tensor_)):
|
||||
data = Tensor(data, internal=True)
|
||||
return float(data)
|
||||
|
||||
|
|
|
@ -51,7 +51,7 @@ def _check_element_type(value):
|
|||
if not _check_element_type(element):
|
||||
return False
|
||||
return True
|
||||
return isinstance(value, (Tensor_, int, float)) and not isinstance(value, bool)
|
||||
return isinstance(value, (Tensor, Tensor_, int, float)) and not isinstance(value, bool)
|
||||
|
||||
|
||||
def mutable(input_data, dynamic_len=False):
|
||||
|
|
|
@ -801,7 +801,7 @@ class CSRTensor(CSRTensor_):
|
|||
[[2.]
|
||||
[1.]]
|
||||
"""
|
||||
validator.check_value_type('dense_vector', dense_vector, (Tensor_,), 'CSRTensor.mv')
|
||||
validator.check_value_type('dense_vector', dense_vector, (Tensor, Tensor_,), 'CSRTensor.mv')
|
||||
return tensor_operator_registry.get("csr_mv")(self, dense_vector)
|
||||
|
||||
def mm(self, matrix: Union[Tensor, CSRTensor]) -> Union[Tensor, CSRTensor]:
|
||||
|
@ -839,7 +839,7 @@ class CSRTensor(CSRTensor_):
|
|||
"""
|
||||
if isinstance(matrix, CSRTensor):
|
||||
return tensor_operator_registry.get("csr_mm")(self, matrix)
|
||||
validator.check_value_type('matrix', matrix, (Tensor_,), 'CSRTensor.mm')
|
||||
validator.check_value_type('matrix', matrix, (Tensor, Tensor_,), 'CSRTensor.mm')
|
||||
return tensor_operator_registry.get("csr_mm_akg")()(self.indptr, self.indices, self.values,
|
||||
self.shape, matrix)
|
||||
|
||||
|
|
|
@ -2814,7 +2814,7 @@ class Tensor(Tensor_, metaclass=_TensorMeta):
|
|||
For details, please refer to :func:`mindspore.ops.gather_nd`.
|
||||
"""
|
||||
self._init_check()
|
||||
validator.check_value_type('indices', indices, (Tensor_,), 'Tensor.gather_nd')
|
||||
validator.check_value_type('indices', indices, (Tensor, Tensor_,), 'Tensor.gather_nd')
|
||||
return tensor_operator_registry.get('gather_nd')(self, indices)
|
||||
|
||||
def gather(self, input_indices, axis):
|
||||
|
@ -3125,7 +3125,7 @@ class Tensor(Tensor_, metaclass=_TensorMeta):
|
|||
For details, please refer to :func:`mindspore.ops.gather_elements`.
|
||||
"""
|
||||
self._init_check()
|
||||
validator.check_value_type('index', index, (Tensor_,), 'Tensor.gather_elements')
|
||||
validator.check_value_type('index', index, (Tensor, Tensor_,), 'Tensor.gather_elements')
|
||||
return tensor_operator_registry.get('gather_elements')(self, dim, index)
|
||||
|
||||
def nonzero(self):
|
||||
|
|
|
@ -1569,7 +1569,7 @@ class NdGrid:
|
|||
if self.sparse:
|
||||
return grids
|
||||
|
||||
if isinstance(grids, Tensor_):
|
||||
if isinstance(grids, (Tensor, Tensor_)):
|
||||
return grids
|
||||
expanded = []
|
||||
for grid in grids:
|
||||
|
|
|
@ -637,7 +637,7 @@ class Reshape(PrimitiveWithCheck):
|
|||
# for shape is not constant
|
||||
if shape is None or x is None:
|
||||
return None
|
||||
if isinstance(shape, Tensor_):
|
||||
if isinstance(shape, (Tensor, Tensor_)):
|
||||
validator.check_tensor_dtype_valid("shape", mstype.tensor_type(shape.dtype),
|
||||
[mstype.int32, mstype.int64], self.name)
|
||||
shape = shape.asnumpy().tolist()
|
||||
|
@ -1573,9 +1573,9 @@ class FillV2(PrimitiveWithCheck):
|
|||
self.init_prim_io_names(inputs=['shape', 'value'], outputs=['y'])
|
||||
|
||||
def infer_value(self, dims, x):
|
||||
if isinstance(dims, Tensor_):
|
||||
if isinstance(dims, (Tensor, Tensor_)):
|
||||
dims = dims.asnumpy()
|
||||
if isinstance(x, Tensor_):
|
||||
if isinstance(x, (Tensor, Tensor_)):
|
||||
x = x.asnumpy()
|
||||
if dims is not None and None not in dims and x is not None:
|
||||
ret = np.full(dims, x)
|
||||
|
@ -3521,7 +3521,7 @@ class StridedSlice(PrimitiveWithInfer):
|
|||
}
|
||||
return slices, slice_shape[0]
|
||||
|
||||
if isinstance(slice_value, Tensor_):
|
||||
if isinstance(slice_value, (Tensor, Tensor_)):
|
||||
validator.check_tensor_dtype_valid(name, slice_input['dtype'], [mstype.int64], self.name)
|
||||
slice_value = slice_value.asnumpy().tolist()
|
||||
elif not isinstance(slice_value, tuple):
|
||||
|
|
|
@ -22,6 +22,7 @@ from mindspore._checkparam import Validator as validator
|
|||
from mindspore._checkparam import Rel
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.common.parameter import Parameter
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.ops.primitive import prim_attr_register, Primitive, PrimitiveWithInfer
|
||||
|
||||
|
||||
|
@ -494,7 +495,7 @@ class Print(Primitive):
|
|||
for arg in args:
|
||||
if isinstance(arg, Parameter):
|
||||
print(Tensor_.__repr__(arg))
|
||||
elif isinstance(arg, Tensor_):
|
||||
elif isinstance(arg, (Tensor, Tensor_)):
|
||||
print(arg.__repr__())
|
||||
else:
|
||||
print(arg)
|
||||
|
|
|
@ -25,7 +25,9 @@ from collections import defaultdict
|
|||
from mindspore import log as logger
|
||||
from mindspore.nn import Cell
|
||||
from mindspore import context
|
||||
from mindspore._c_expression import Tensor, security
|
||||
from mindspore._c_expression import security
|
||||
from mindspore._c_expression import Tensor as Tensor_
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore._checkparam import Validator
|
||||
from mindspore.common.api import _cell_graph_executor
|
||||
from mindspore.train._utils import _check_lineage_value, _check_to_numpy, _make_directory, check_value_type
|
||||
|
@ -75,7 +77,7 @@ def _record_summary_tensor_data():
|
|||
"""Checks the tag is valid for summary."""
|
||||
if not isinstance(tag, str) or not tag:
|
||||
raise ValueError(f'For "{summary_name}", the name must be valid string, but got "{tag}".')
|
||||
if not isinstance(tensor, Tensor):
|
||||
if not isinstance(tensor, (Tensor, Tensor_)):
|
||||
raise TypeError(f'For "{summary_name}", the parameter "value" expect to be Tensor, '
|
||||
f'but got {type(tensor).__name__}')
|
||||
|
||||
|
@ -318,7 +320,7 @@ class SummaryRecord:
|
|||
if not name or not isinstance(name, str):
|
||||
raise ValueError(f'For "{self.__class__.__name__}", the parameter "name" type should be str, '
|
||||
f'but got {type(name)}.')
|
||||
if not isinstance(value, Tensor):
|
||||
if not isinstance(value, (Tensor, Tensor_)):
|
||||
raise TypeError(f'For "{self.__class__.__name__}", the parameter "value" expect to be Tensor, '
|
||||
f'but got {type(value).__name__}')
|
||||
np_value = _check_to_numpy(plugin, value)
|
||||
|
|
Loading…
Reference in New Issue