unify range ops

This commit is contained in:
LianLiguang 2021-03-16 10:47:42 +08:00
parent 60922a1d65
commit 17b9758543
4 changed files with 36 additions and 38 deletions

View File

@ -13,10 +13,8 @@
# limitations under the License.
# ============================================================================
"""math"""
import math
import numpy as np
from mindspore.ops import operations as P
from mindspore.ops.operations import _inner_ops as inner
from mindspore.common.tensor import Tensor
from mindspore.common._decorator import deprecated
from mindspore.ops.primitive import constexpr
@ -25,7 +23,6 @@ from ..cell import Cell
from ...common import dtype as mstype
from ..._checkparam import Validator as validator
__all__ = ['ReduceLogSumExp',
'Range',
'LGamma',
@ -140,37 +137,15 @@ class Range(Cell):
def __init__(self, start, limit=None, delta=1):
super(Range, self).__init__()
validator.check_value_type("start", start, [int, float], self.cls_name)
validator.check_value_type("delta", delta, [int, float], self.cls_name)
if delta == 0:
raise ValueError("The input of `delta` can not be equal to zero.")
if limit is not None:
validator.check_value_type("limit", limit, [int, float], self.cls_name)
if isinstance(start, int) and isinstance(limit, int) and isinstance(delta, int):
self.dtype = mstype.int32
else:
self.dtype = mstype.float32
data = np.arange(start, limit, delta)
if data.dtype == np.float:
self.ms_dtype = mstype.float32
else:
if isinstance(start, int) and isinstance(delta, int):
self.dtype = mstype.int32
else:
self.dtype = mstype.float32
if isinstance(start, int):
start = float(start)
if isinstance(limit, int):
limit = float(limit)
if isinstance(delta, int):
delta = float(delta)
self.range_x = inner.Range(start, limit, delta)
if limit is None:
length_input = math.ceil(start / delta)
else:
length_input = math.ceil((limit - start) / delta)
self.input_tensor = Tensor(list(range(length_input)), self.dtype)
self.ms_dtype = mstype.int32
self.result_tensor = Tensor(data, dtype=self.ms_dtype)
def construct(self):
range_out = self.range_x(self.input_tensor)
return range_out
return self.result_tensor
class LGamma(Cell):

View File

@ -16,6 +16,7 @@
"""Inner operators."""
import numpy as np
from mindspore.common import Tensor
from ..._checkparam import Rel
from ..._checkparam import Validator as validator
from ... import context
@ -25,6 +26,7 @@ from ..operations.math_ops import _infer_shape_reduce
from ...communication.management import GlobalComm
from .. import signature as sig
class ExtractImagePatches(PrimitiveWithInfer):
"""
Extracts patches from images.
@ -164,6 +166,9 @@ class Range(PrimitiveWithInfer):
validator.check_tensor_dtype_valid('x', x_dtype, [mstype.float32, mstype.int32], self.name)
return x_dtype
def infer_value(self, x_value):
return Tensor(np.arange(self.start, self.limit, self.delta), dtype=x_value.dtype)
class Quant(PrimitiveWithInfer):
r"""
@ -408,6 +413,7 @@ class Send(PrimitiveWithInfer):
>>> net = Net()
>>> output = net(input_)
"""
@prim_attr_register
def __init__(self, sr_tag, dest_rank, group=GlobalComm.WORLD_COMM_GROUP):
self.rank = dest_rank
@ -464,6 +470,7 @@ class Receive(PrimitiveWithInfer):
>>> net = Net()
>>> output = net()
"""
@prim_attr_register
def __init__(self, sr_tag, src_rank, shape, dtype, group=GlobalComm.WORLD_COMM_GROUP):
self.rank = src_rank

View File

@ -2391,6 +2391,7 @@ class Pack(PrimitiveWithInfer):
Same as operator Stack. Pack will be deprecated in the future.
Please use Stack instead.
"""
@deprecated("1.1", "Stack", True)
@prim_attr_register
def __init__(self, axis=0):
@ -2469,6 +2470,7 @@ class Unpack(PrimitiveWithInfer):
Same as operator Unstack. Unpack will be deprecated in the future.
Please use Unstack instead.
"""
@deprecated("1.1", "Unstack", True)
@prim_attr_register
def __init__(self, axis=0):
@ -3491,7 +3493,6 @@ class ScatterUpdate(_ScatterOp_Dynamic):
self.add_prim_attr('side_effect_mem', True)
class ScatterNdUpdate(_ScatterNdOp):
r"""
Updates tensor values by using input indices and value.
@ -5250,3 +5251,11 @@ class Range(PrimitiveWithCheck):
valid_dtypes = [mstype.int32, mstype.float32]
inputs = {"start": start_dtype, "limit": limit_dtype, "delta": delta_dtype}
validator.check_tensors_dtypes_same_and_valid(inputs, valid_dtypes, self.name)
def infer_value(self, start_value, limit_value, delat_value):
if start_value is not None and limit_value is not None and delat_value is not None:
start = np.asscalar(start_value.asnumpy())
limit = np.asscalar(limit_value.asnumpy())
delat = np.asscalar(delat_value.asnumpy())
return Tensor(np.arange(start, limit, delat), dtype=start_value.dtype)
return None

View File

@ -15,7 +15,7 @@
import numpy as np
import mindspore as ms
import mindspore.nn as nn
from mindspore.common import dtype as mstype
from mindspore import context, Tensor, Parameter
from mindspore.nn import Cell, Momentum
from mindspore.ops import operations as P
@ -48,18 +48,25 @@ class Net(Cell):
def __init__(self, weight, start, limit, delta, strategy1=None, strategy2=None, strategy3=None):
super().__init__()
self.mul = P.Mul().shard(strategy1)
self.range = nn.Range(start, limit, delta)
self.range.range_x.shard(strategy2)
if isinstance(start, float):
self.type = mstype.float32
else:
self.type = mstype.int32
self.start = Tensor(start, self.type)
self.limit = Tensor(limit, self.type)
self.delta = Tensor(delta, self.type)
self.range = P.Range()
self.range.shard(strategy2)
self.mul2 = P.Mul().shard(strategy3)
self.weight = Parameter(weight, "w")
def construct(self, x, b):
r_out = self.range()
r_out = self.range(self.start, self.limit, self.delta)
out = self.mul(x, self.weight)
out = self.mul2(out, r_out)
return out
dev_num = 4
_x = Tensor(np.ones([64 // dev_num, 8]), dtype=ms.float32)
_b = Tensor(np.ones([8]), dtype=ms.float32)
@ -98,5 +105,5 @@ def test_range2():
def test_range3():
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=dev_num, global_rank=2)
net = Net(_w1, 4.0, None, 0.5)
net = Net(_w1, 0.0, 4.0, 0.5)
compile_net(net)