From e4b3b72ebfd2b02df35bdcfddae6e507632a0d2c Mon Sep 17 00:00:00 2001 From: geekun Date: Tue, 23 Jun 2020 16:29:22 +0800 Subject: [PATCH] fix infer value bug --- mindspore/common/tensor.py | 8 ++++++++ mindspore/ops/operations/math_ops.py | 6 ++++-- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/mindspore/common/tensor.py b/mindspore/common/tensor.py index 0a631b954fd..ddd7cbfabc0 100644 --- a/mindspore/common/tensor.py +++ b/mindspore/common/tensor.py @@ -22,6 +22,10 @@ from . import dtype as mstype from ._register_for_tensor import tensor_operator_registry __all__ = ['Tensor', 'MetaTensor'] +np_types = (np.int8, np.int16, np.int32, np.int64, + np.uint8, np.uint16, np.uint32, np.uint64, np.float16, + np.float32, np.float64, np.bool_) + class Tensor(Tensor_): @@ -54,6 +58,10 @@ class Tensor(Tensor_): """ def __init__(self, input_data, dtype=None): + # If input data is numpy number, convert it to np array + if isinstance(input_data, np_types): + input_data = np.array(input_data) + # If input_data is tuple/list/numpy.ndarray, it's support in check_type method. check_type('tensor input_data', input_data, (Tensor_, float, int)) if dtype is not None: diff --git a/mindspore/ops/operations/math_ops.py b/mindspore/ops/operations/math_ops.py index d992d078ace..d5a83f1208b 100644 --- a/mindspore/ops/operations/math_ops.py +++ b/mindspore/ops/operations/math_ops.py @@ -830,7 +830,8 @@ class Neg(PrimitiveWithInfer): def infer_value(self, input_x): if input_x is not None: input_x = input_x.asnumpy() - return Tensor(-input_x) + out = np.array(-input_x, input_x.dtype) + return Tensor(out) return None @@ -1609,7 +1610,8 @@ class Div(_MathBinaryOp): if x is not None and y is not None: x = x.asnumpy() y = y.asnumpy() - return Tensor(x / y) + out = np.array(x / y, x.dtype) + return Tensor(out) return None