diff --git a/mindspore/nn/layer/math.py b/mindspore/nn/layer/math.py index 69b0e9dcad0..16af27f8555 100644 --- a/mindspore/nn/layer/math.py +++ b/mindspore/nn/layer/math.py @@ -13,10 +13,8 @@ # limitations under the License. # ============================================================================ """math""" -import math import numpy as np from mindspore.ops import operations as P -from mindspore.ops.operations import _inner_ops as inner from mindspore.common.tensor import Tensor from mindspore.common._decorator import deprecated from mindspore.ops.primitive import constexpr @@ -25,7 +23,6 @@ from ..cell import Cell from ...common import dtype as mstype from ..._checkparam import Validator as validator - __all__ = ['ReduceLogSumExp', 'Range', 'LGamma', @@ -140,37 +137,15 @@ class Range(Cell): def __init__(self, start, limit=None, delta=1): super(Range, self).__init__() - validator.check_value_type("start", start, [int, float], self.cls_name) - validator.check_value_type("delta", delta, [int, float], self.cls_name) - if delta == 0: - raise ValueError("The input of `delta` can not be equal to zero.") - if limit is not None: - validator.check_value_type("limit", limit, [int, float], self.cls_name) - if isinstance(start, int) and isinstance(limit, int) and isinstance(delta, int): - self.dtype = mstype.int32 - else: - self.dtype = mstype.float32 + data = np.arange(start, limit, delta) + if data.dtype == np.float: + self.ms_dtype = mstype.float32 else: - if isinstance(start, int) and isinstance(delta, int): - self.dtype = mstype.int32 - else: - self.dtype = mstype.float32 - if isinstance(start, int): - start = float(start) - if isinstance(limit, int): - limit = float(limit) - if isinstance(delta, int): - delta = float(delta) - self.range_x = inner.Range(start, limit, delta) - if limit is None: - length_input = math.ceil(start / delta) - else: - length_input = math.ceil((limit - start) / delta) - self.input_tensor = Tensor(list(range(length_input)), self.dtype) + self.ms_dtype = mstype.int32 + self.result_tensor = Tensor(data, dtype=self.ms_dtype) def construct(self): - range_out = self.range_x(self.input_tensor) - return range_out + return self.result_tensor class LGamma(Cell): diff --git a/mindspore/ops/operations/_inner_ops.py b/mindspore/ops/operations/_inner_ops.py index 68e18d7eeeb..c2c2ddd8a53 100644 --- a/mindspore/ops/operations/_inner_ops.py +++ b/mindspore/ops/operations/_inner_ops.py @@ -16,6 +16,7 @@ """Inner operators.""" import numpy as np +from mindspore.common import Tensor from ..._checkparam import Rel from ..._checkparam import Validator as validator from ... import context @@ -25,6 +26,7 @@ from ..operations.math_ops import _infer_shape_reduce from ...communication.management import GlobalComm from .. import signature as sig + class ExtractImagePatches(PrimitiveWithInfer): """ Extracts patches from images. @@ -164,6 +166,9 @@ class Range(PrimitiveWithInfer): validator.check_tensor_dtype_valid('x', x_dtype, [mstype.float32, mstype.int32], self.name) return x_dtype + def infer_value(self, x_value): + return Tensor(np.arange(self.start, self.limit, self.delta), dtype=x_value.dtype) + class Quant(PrimitiveWithInfer): r""" @@ -408,6 +413,7 @@ class Send(PrimitiveWithInfer): >>> net = Net() >>> output = net(input_) """ + @prim_attr_register def __init__(self, sr_tag, dest_rank, group=GlobalComm.WORLD_COMM_GROUP): self.rank = dest_rank @@ -464,6 +470,7 @@ class Receive(PrimitiveWithInfer): >>> net = Net() >>> output = net() """ + @prim_attr_register def __init__(self, sr_tag, src_rank, shape, dtype, group=GlobalComm.WORLD_COMM_GROUP): self.rank = src_rank diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py index 0d2cd3de439..911918c917a 100644 --- a/mindspore/ops/operations/array_ops.py +++ b/mindspore/ops/operations/array_ops.py @@ -2391,6 +2391,7 @@ class Pack(PrimitiveWithInfer): Same as operator Stack. Pack will be deprecated in the future. Please use Stack instead. """ + @deprecated("1.1", "Stack", True) @prim_attr_register def __init__(self, axis=0): @@ -2469,6 +2470,7 @@ class Unpack(PrimitiveWithInfer): Same as operator Unstack. Unpack will be deprecated in the future. Please use Unstack instead. """ + @deprecated("1.1", "Unstack", True) @prim_attr_register def __init__(self, axis=0): @@ -3491,7 +3493,6 @@ class ScatterUpdate(_ScatterOp_Dynamic): self.add_prim_attr('side_effect_mem', True) - class ScatterNdUpdate(_ScatterNdOp): r""" Updates tensor values by using input indices and value. @@ -5250,3 +5251,11 @@ class Range(PrimitiveWithCheck): valid_dtypes = [mstype.int32, mstype.float32] inputs = {"start": start_dtype, "limit": limit_dtype, "delta": delta_dtype} validator.check_tensors_dtypes_same_and_valid(inputs, valid_dtypes, self.name) + + def infer_value(self, start_value, limit_value, delat_value): + if start_value is not None and limit_value is not None and delat_value is not None: + start = np.asscalar(start_value.asnumpy()) + limit = np.asscalar(limit_value.asnumpy()) + delat = np.asscalar(delat_value.asnumpy()) + return Tensor(np.arange(start, limit, delat), dtype=start_value.dtype) + return None diff --git a/tests/ut/python/parallel/test_range.py b/tests/ut/python/parallel/test_range.py index 2e8780d1a4f..f565b438fd3 100644 --- a/tests/ut/python/parallel/test_range.py +++ b/tests/ut/python/parallel/test_range.py @@ -15,7 +15,7 @@ import numpy as np import mindspore as ms -import mindspore.nn as nn +from mindspore.common import dtype as mstype from mindspore import context, Tensor, Parameter from mindspore.nn import Cell, Momentum from mindspore.ops import operations as P @@ -48,18 +48,25 @@ class Net(Cell): def __init__(self, weight, start, limit, delta, strategy1=None, strategy2=None, strategy3=None): super().__init__() self.mul = P.Mul().shard(strategy1) - self.range = nn.Range(start, limit, delta) - self.range.range_x.shard(strategy2) + if isinstance(start, float): + self.type = mstype.float32 + else: + self.type = mstype.int32 + self.start = Tensor(start, self.type) + self.limit = Tensor(limit, self.type) + self.delta = Tensor(delta, self.type) + self.range = P.Range() + self.range.shard(strategy2) self.mul2 = P.Mul().shard(strategy3) self.weight = Parameter(weight, "w") - def construct(self, x, b): - r_out = self.range() + r_out = self.range(self.start, self.limit, self.delta) out = self.mul(x, self.weight) out = self.mul2(out, r_out) return out + dev_num = 4 _x = Tensor(np.ones([64 // dev_num, 8]), dtype=ms.float32) _b = Tensor(np.ones([8]), dtype=ms.float32) @@ -98,5 +105,5 @@ def test_range2(): def test_range3(): context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=dev_num, global_rank=2) - net = Net(_w1, 4.0, None, 0.5) + net = Net(_w1, 0.0, 4.0, 0.5) compile_net(net)