forked from mindspore-Ecosystem/mindspore
fix infer value bug
This commit is contained in:
parent
3f8a7920d5
commit
e4b3b72ebf
|
@ -22,6 +22,10 @@ from . import dtype as mstype
|
|||
from ._register_for_tensor import tensor_operator_registry
|
||||
|
||||
__all__ = ['Tensor', 'MetaTensor']
|
||||
np_types = (np.int8, np.int16, np.int32, np.int64,
|
||||
np.uint8, np.uint16, np.uint32, np.uint64, np.float16,
|
||||
np.float32, np.float64, np.bool_)
|
||||
|
||||
|
||||
|
||||
class Tensor(Tensor_):
|
||||
|
@ -54,6 +58,10 @@ class Tensor(Tensor_):
|
|||
"""
|
||||
|
||||
def __init__(self, input_data, dtype=None):
|
||||
# If input data is numpy number, convert it to np array
|
||||
if isinstance(input_data, np_types):
|
||||
input_data = np.array(input_data)
|
||||
|
||||
# If input_data is tuple/list/numpy.ndarray, it's support in check_type method.
|
||||
check_type('tensor input_data', input_data, (Tensor_, float, int))
|
||||
if dtype is not None:
|
||||
|
|
|
@ -830,7 +830,8 @@ class Neg(PrimitiveWithInfer):
|
|||
def infer_value(self, input_x):
|
||||
if input_x is not None:
|
||||
input_x = input_x.asnumpy()
|
||||
return Tensor(-input_x)
|
||||
out = np.array(-input_x, input_x.dtype)
|
||||
return Tensor(out)
|
||||
|
||||
return None
|
||||
|
||||
|
@ -1609,7 +1610,8 @@ class Div(_MathBinaryOp):
|
|||
if x is not None and y is not None:
|
||||
x = x.asnumpy()
|
||||
y = y.asnumpy()
|
||||
return Tensor(x / y)
|
||||
out = np.array(x / y, x.dtype)
|
||||
return Tensor(out)
|
||||
return None
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue