From: @lianliguang
Reviewed-by: @zh_qh,@ginfung
Signed-off-by: @zh_qh
This commit is contained in:
mindspore-ci-bot 2021-03-16 16:56:53 +08:00 committed by Gitee
commit f9c408934e
4 changed files with 36 additions and 38 deletions

View File

@ -13,10 +13,8 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
"""math""" """math"""
import math
import numpy as np import numpy as np
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.ops.operations import _inner_ops as inner
from mindspore.common.tensor import Tensor from mindspore.common.tensor import Tensor
from mindspore.common._decorator import deprecated from mindspore.common._decorator import deprecated
from mindspore.ops.primitive import constexpr from mindspore.ops.primitive import constexpr
@ -25,7 +23,6 @@ from ..cell import Cell
from ...common import dtype as mstype from ...common import dtype as mstype
from ..._checkparam import Validator as validator from ..._checkparam import Validator as validator
__all__ = ['ReduceLogSumExp', __all__ = ['ReduceLogSumExp',
'Range', 'Range',
'LGamma', 'LGamma',
@ -140,37 +137,15 @@ class Range(Cell):
def __init__(self, start, limit=None, delta=1): def __init__(self, start, limit=None, delta=1):
super(Range, self).__init__() super(Range, self).__init__()
validator.check_value_type("start", start, [int, float], self.cls_name) data = np.arange(start, limit, delta)
validator.check_value_type("delta", delta, [int, float], self.cls_name) if data.dtype == np.float:
if delta == 0: self.ms_dtype = mstype.float32
raise ValueError("The input of `delta` can not be equal to zero.")
if limit is not None:
validator.check_value_type("limit", limit, [int, float], self.cls_name)
if isinstance(start, int) and isinstance(limit, int) and isinstance(delta, int):
self.dtype = mstype.int32
else:
self.dtype = mstype.float32
else: else:
if isinstance(start, int) and isinstance(delta, int): self.ms_dtype = mstype.int32
self.dtype = mstype.int32 self.result_tensor = Tensor(data, dtype=self.ms_dtype)
else:
self.dtype = mstype.float32
if isinstance(start, int):
start = float(start)
if isinstance(limit, int):
limit = float(limit)
if isinstance(delta, int):
delta = float(delta)
self.range_x = inner.Range(start, limit, delta)
if limit is None:
length_input = math.ceil(start / delta)
else:
length_input = math.ceil((limit - start) / delta)
self.input_tensor = Tensor(list(range(length_input)), self.dtype)
def construct(self): def construct(self):
range_out = self.range_x(self.input_tensor) return self.result_tensor
return range_out
class LGamma(Cell): class LGamma(Cell):

View File

@ -16,6 +16,7 @@
"""Inner operators.""" """Inner operators."""
import numpy as np import numpy as np
from mindspore.common import Tensor
from ..._checkparam import Rel from ..._checkparam import Rel
from ..._checkparam import Validator as validator from ..._checkparam import Validator as validator
from ... import context from ... import context
@ -25,6 +26,7 @@ from ..operations.math_ops import _infer_shape_reduce
from ...communication.management import GlobalComm from ...communication.management import GlobalComm
from .. import signature as sig from .. import signature as sig
class ExtractImagePatches(PrimitiveWithInfer): class ExtractImagePatches(PrimitiveWithInfer):
""" """
Extracts patches from images. Extracts patches from images.
@ -164,6 +166,9 @@ class Range(PrimitiveWithInfer):
validator.check_tensor_dtype_valid('x', x_dtype, [mstype.float32, mstype.int32], self.name) validator.check_tensor_dtype_valid('x', x_dtype, [mstype.float32, mstype.int32], self.name)
return x_dtype return x_dtype
def infer_value(self, x_value):
return Tensor(np.arange(self.start, self.limit, self.delta), dtype=x_value.dtype)
class Quant(PrimitiveWithInfer): class Quant(PrimitiveWithInfer):
r""" r"""
@ -408,6 +413,7 @@ class Send(PrimitiveWithInfer):
>>> net = Net() >>> net = Net()
>>> output = net(input_) >>> output = net(input_)
""" """
@prim_attr_register @prim_attr_register
def __init__(self, sr_tag, dest_rank, group=GlobalComm.WORLD_COMM_GROUP): def __init__(self, sr_tag, dest_rank, group=GlobalComm.WORLD_COMM_GROUP):
self.rank = dest_rank self.rank = dest_rank
@ -464,6 +470,7 @@ class Receive(PrimitiveWithInfer):
>>> net = Net() >>> net = Net()
>>> output = net() >>> output = net()
""" """
@prim_attr_register @prim_attr_register
def __init__(self, sr_tag, src_rank, shape, dtype, group=GlobalComm.WORLD_COMM_GROUP): def __init__(self, sr_tag, src_rank, shape, dtype, group=GlobalComm.WORLD_COMM_GROUP):
self.rank = src_rank self.rank = src_rank

View File

@ -2391,6 +2391,7 @@ class Pack(PrimitiveWithInfer):
Same as operator Stack. Pack will be deprecated in the future. Same as operator Stack. Pack will be deprecated in the future.
Please use Stack instead. Please use Stack instead.
""" """
@deprecated("1.1", "Stack", True) @deprecated("1.1", "Stack", True)
@prim_attr_register @prim_attr_register
def __init__(self, axis=0): def __init__(self, axis=0):
@ -2469,6 +2470,7 @@ class Unpack(PrimitiveWithInfer):
Same as operator Unstack. Unpack will be deprecated in the future. Same as operator Unstack. Unpack will be deprecated in the future.
Please use Unstack instead. Please use Unstack instead.
""" """
@deprecated("1.1", "Unstack", True) @deprecated("1.1", "Unstack", True)
@prim_attr_register @prim_attr_register
def __init__(self, axis=0): def __init__(self, axis=0):
@ -3491,7 +3493,6 @@ class ScatterUpdate(_ScatterOp_Dynamic):
self.add_prim_attr('side_effect_mem', True) self.add_prim_attr('side_effect_mem', True)
class ScatterNdUpdate(_ScatterNdOp): class ScatterNdUpdate(_ScatterNdOp):
r""" r"""
Updates tensor values by using input indices and value. Updates tensor values by using input indices and value.
@ -5250,3 +5251,11 @@ class Range(PrimitiveWithCheck):
valid_dtypes = [mstype.int32, mstype.float32] valid_dtypes = [mstype.int32, mstype.float32]
inputs = {"start": start_dtype, "limit": limit_dtype, "delta": delta_dtype} inputs = {"start": start_dtype, "limit": limit_dtype, "delta": delta_dtype}
validator.check_tensors_dtypes_same_and_valid(inputs, valid_dtypes, self.name) validator.check_tensors_dtypes_same_and_valid(inputs, valid_dtypes, self.name)
def infer_value(self, start_value, limit_value, delat_value):
if start_value is not None and limit_value is not None and delat_value is not None:
start = np.asscalar(start_value.asnumpy())
limit = np.asscalar(limit_value.asnumpy())
delat = np.asscalar(delat_value.asnumpy())
return Tensor(np.arange(start, limit, delat), dtype=start_value.dtype)
return None

View File

@ -15,7 +15,7 @@
import numpy as np import numpy as np
import mindspore as ms import mindspore as ms
import mindspore.nn as nn from mindspore.common import dtype as mstype
from mindspore import context, Tensor, Parameter from mindspore import context, Tensor, Parameter
from mindspore.nn import Cell, Momentum from mindspore.nn import Cell, Momentum
from mindspore.ops import operations as P from mindspore.ops import operations as P
@ -48,18 +48,25 @@ class Net(Cell):
def __init__(self, weight, start, limit, delta, strategy1=None, strategy2=None, strategy3=None): def __init__(self, weight, start, limit, delta, strategy1=None, strategy2=None, strategy3=None):
super().__init__() super().__init__()
self.mul = P.Mul().shard(strategy1) self.mul = P.Mul().shard(strategy1)
self.range = nn.Range(start, limit, delta) if isinstance(start, float):
self.range.range_x.shard(strategy2) self.type = mstype.float32
else:
self.type = mstype.int32
self.start = Tensor(start, self.type)
self.limit = Tensor(limit, self.type)
self.delta = Tensor(delta, self.type)
self.range = P.Range()
self.range.shard(strategy2)
self.mul2 = P.Mul().shard(strategy3) self.mul2 = P.Mul().shard(strategy3)
self.weight = Parameter(weight, "w") self.weight = Parameter(weight, "w")
def construct(self, x, b): def construct(self, x, b):
r_out = self.range() r_out = self.range(self.start, self.limit, self.delta)
out = self.mul(x, self.weight) out = self.mul(x, self.weight)
out = self.mul2(out, r_out) out = self.mul2(out, r_out)
return out return out
dev_num = 4 dev_num = 4
_x = Tensor(np.ones([64 // dev_num, 8]), dtype=ms.float32) _x = Tensor(np.ones([64 // dev_num, 8]), dtype=ms.float32)
_b = Tensor(np.ones([8]), dtype=ms.float32) _b = Tensor(np.ones([8]), dtype=ms.float32)
@ -98,5 +105,5 @@ def test_range2():
def test_range3(): def test_range3():
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=dev_num, global_rank=2) context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=dev_num, global_rank=2)
net = Net(_w1, 4.0, None, 0.5) net = Net(_w1, 0.0, 4.0, 0.5)
compile_net(net) compile_net(net)