forked from mindspore-Ecosystem/mindspore
!13380 unify range
From: @lianliguang Reviewed-by: @zh_qh,@ginfung Signed-off-by: @zh_qh
This commit is contained in:
commit
f9c408934e
|
@ -13,10 +13,8 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
"""math"""
|
"""math"""
|
||||||
import math
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from mindspore.ops import operations as P
|
from mindspore.ops import operations as P
|
||||||
from mindspore.ops.operations import _inner_ops as inner
|
|
||||||
from mindspore.common.tensor import Tensor
|
from mindspore.common.tensor import Tensor
|
||||||
from mindspore.common._decorator import deprecated
|
from mindspore.common._decorator import deprecated
|
||||||
from mindspore.ops.primitive import constexpr
|
from mindspore.ops.primitive import constexpr
|
||||||
|
@ -25,7 +23,6 @@ from ..cell import Cell
|
||||||
from ...common import dtype as mstype
|
from ...common import dtype as mstype
|
||||||
from ..._checkparam import Validator as validator
|
from ..._checkparam import Validator as validator
|
||||||
|
|
||||||
|
|
||||||
__all__ = ['ReduceLogSumExp',
|
__all__ = ['ReduceLogSumExp',
|
||||||
'Range',
|
'Range',
|
||||||
'LGamma',
|
'LGamma',
|
||||||
|
@ -140,37 +137,15 @@ class Range(Cell):
|
||||||
|
|
||||||
def __init__(self, start, limit=None, delta=1):
|
def __init__(self, start, limit=None, delta=1):
|
||||||
super(Range, self).__init__()
|
super(Range, self).__init__()
|
||||||
validator.check_value_type("start", start, [int, float], self.cls_name)
|
data = np.arange(start, limit, delta)
|
||||||
validator.check_value_type("delta", delta, [int, float], self.cls_name)
|
if data.dtype == np.float:
|
||||||
if delta == 0:
|
self.ms_dtype = mstype.float32
|
||||||
raise ValueError("The input of `delta` can not be equal to zero.")
|
|
||||||
if limit is not None:
|
|
||||||
validator.check_value_type("limit", limit, [int, float], self.cls_name)
|
|
||||||
if isinstance(start, int) and isinstance(limit, int) and isinstance(delta, int):
|
|
||||||
self.dtype = mstype.int32
|
|
||||||
else:
|
|
||||||
self.dtype = mstype.float32
|
|
||||||
else:
|
else:
|
||||||
if isinstance(start, int) and isinstance(delta, int):
|
self.ms_dtype = mstype.int32
|
||||||
self.dtype = mstype.int32
|
self.result_tensor = Tensor(data, dtype=self.ms_dtype)
|
||||||
else:
|
|
||||||
self.dtype = mstype.float32
|
|
||||||
if isinstance(start, int):
|
|
||||||
start = float(start)
|
|
||||||
if isinstance(limit, int):
|
|
||||||
limit = float(limit)
|
|
||||||
if isinstance(delta, int):
|
|
||||||
delta = float(delta)
|
|
||||||
self.range_x = inner.Range(start, limit, delta)
|
|
||||||
if limit is None:
|
|
||||||
length_input = math.ceil(start / delta)
|
|
||||||
else:
|
|
||||||
length_input = math.ceil((limit - start) / delta)
|
|
||||||
self.input_tensor = Tensor(list(range(length_input)), self.dtype)
|
|
||||||
|
|
||||||
def construct(self):
|
def construct(self):
|
||||||
range_out = self.range_x(self.input_tensor)
|
return self.result_tensor
|
||||||
return range_out
|
|
||||||
|
|
||||||
|
|
||||||
class LGamma(Cell):
|
class LGamma(Cell):
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
"""Inner operators."""
|
"""Inner operators."""
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from mindspore.common import Tensor
|
||||||
from ..._checkparam import Rel
|
from ..._checkparam import Rel
|
||||||
from ..._checkparam import Validator as validator
|
from ..._checkparam import Validator as validator
|
||||||
from ... import context
|
from ... import context
|
||||||
|
@ -25,6 +26,7 @@ from ..operations.math_ops import _infer_shape_reduce
|
||||||
from ...communication.management import GlobalComm
|
from ...communication.management import GlobalComm
|
||||||
from .. import signature as sig
|
from .. import signature as sig
|
||||||
|
|
||||||
|
|
||||||
class ExtractImagePatches(PrimitiveWithInfer):
|
class ExtractImagePatches(PrimitiveWithInfer):
|
||||||
"""
|
"""
|
||||||
Extracts patches from images.
|
Extracts patches from images.
|
||||||
|
@ -164,6 +166,9 @@ class Range(PrimitiveWithInfer):
|
||||||
validator.check_tensor_dtype_valid('x', x_dtype, [mstype.float32, mstype.int32], self.name)
|
validator.check_tensor_dtype_valid('x', x_dtype, [mstype.float32, mstype.int32], self.name)
|
||||||
return x_dtype
|
return x_dtype
|
||||||
|
|
||||||
|
def infer_value(self, x_value):
|
||||||
|
return Tensor(np.arange(self.start, self.limit, self.delta), dtype=x_value.dtype)
|
||||||
|
|
||||||
|
|
||||||
class Quant(PrimitiveWithInfer):
|
class Quant(PrimitiveWithInfer):
|
||||||
r"""
|
r"""
|
||||||
|
@ -408,6 +413,7 @@ class Send(PrimitiveWithInfer):
|
||||||
>>> net = Net()
|
>>> net = Net()
|
||||||
>>> output = net(input_)
|
>>> output = net(input_)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@prim_attr_register
|
@prim_attr_register
|
||||||
def __init__(self, sr_tag, dest_rank, group=GlobalComm.WORLD_COMM_GROUP):
|
def __init__(self, sr_tag, dest_rank, group=GlobalComm.WORLD_COMM_GROUP):
|
||||||
self.rank = dest_rank
|
self.rank = dest_rank
|
||||||
|
@ -464,6 +470,7 @@ class Receive(PrimitiveWithInfer):
|
||||||
>>> net = Net()
|
>>> net = Net()
|
||||||
>>> output = net()
|
>>> output = net()
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@prim_attr_register
|
@prim_attr_register
|
||||||
def __init__(self, sr_tag, src_rank, shape, dtype, group=GlobalComm.WORLD_COMM_GROUP):
|
def __init__(self, sr_tag, src_rank, shape, dtype, group=GlobalComm.WORLD_COMM_GROUP):
|
||||||
self.rank = src_rank
|
self.rank = src_rank
|
||||||
|
|
|
@ -2391,6 +2391,7 @@ class Pack(PrimitiveWithInfer):
|
||||||
Same as operator Stack. Pack will be deprecated in the future.
|
Same as operator Stack. Pack will be deprecated in the future.
|
||||||
Please use Stack instead.
|
Please use Stack instead.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@deprecated("1.1", "Stack", True)
|
@deprecated("1.1", "Stack", True)
|
||||||
@prim_attr_register
|
@prim_attr_register
|
||||||
def __init__(self, axis=0):
|
def __init__(self, axis=0):
|
||||||
|
@ -2469,6 +2470,7 @@ class Unpack(PrimitiveWithInfer):
|
||||||
Same as operator Unstack. Unpack will be deprecated in the future.
|
Same as operator Unstack. Unpack will be deprecated in the future.
|
||||||
Please use Unstack instead.
|
Please use Unstack instead.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@deprecated("1.1", "Unstack", True)
|
@deprecated("1.1", "Unstack", True)
|
||||||
@prim_attr_register
|
@prim_attr_register
|
||||||
def __init__(self, axis=0):
|
def __init__(self, axis=0):
|
||||||
|
@ -3491,7 +3493,6 @@ class ScatterUpdate(_ScatterOp_Dynamic):
|
||||||
self.add_prim_attr('side_effect_mem', True)
|
self.add_prim_attr('side_effect_mem', True)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class ScatterNdUpdate(_ScatterNdOp):
|
class ScatterNdUpdate(_ScatterNdOp):
|
||||||
r"""
|
r"""
|
||||||
Updates tensor values by using input indices and value.
|
Updates tensor values by using input indices and value.
|
||||||
|
@ -5250,3 +5251,11 @@ class Range(PrimitiveWithCheck):
|
||||||
valid_dtypes = [mstype.int32, mstype.float32]
|
valid_dtypes = [mstype.int32, mstype.float32]
|
||||||
inputs = {"start": start_dtype, "limit": limit_dtype, "delta": delta_dtype}
|
inputs = {"start": start_dtype, "limit": limit_dtype, "delta": delta_dtype}
|
||||||
validator.check_tensors_dtypes_same_and_valid(inputs, valid_dtypes, self.name)
|
validator.check_tensors_dtypes_same_and_valid(inputs, valid_dtypes, self.name)
|
||||||
|
|
||||||
|
def infer_value(self, start_value, limit_value, delat_value):
|
||||||
|
if start_value is not None and limit_value is not None and delat_value is not None:
|
||||||
|
start = np.asscalar(start_value.asnumpy())
|
||||||
|
limit = np.asscalar(limit_value.asnumpy())
|
||||||
|
delat = np.asscalar(delat_value.asnumpy())
|
||||||
|
return Tensor(np.arange(start, limit, delat), dtype=start_value.dtype)
|
||||||
|
return None
|
||||||
|
|
|
@ -15,7 +15,7 @@
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
import mindspore as ms
|
import mindspore as ms
|
||||||
import mindspore.nn as nn
|
from mindspore.common import dtype as mstype
|
||||||
from mindspore import context, Tensor, Parameter
|
from mindspore import context, Tensor, Parameter
|
||||||
from mindspore.nn import Cell, Momentum
|
from mindspore.nn import Cell, Momentum
|
||||||
from mindspore.ops import operations as P
|
from mindspore.ops import operations as P
|
||||||
|
@ -48,18 +48,25 @@ class Net(Cell):
|
||||||
def __init__(self, weight, start, limit, delta, strategy1=None, strategy2=None, strategy3=None):
|
def __init__(self, weight, start, limit, delta, strategy1=None, strategy2=None, strategy3=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.mul = P.Mul().shard(strategy1)
|
self.mul = P.Mul().shard(strategy1)
|
||||||
self.range = nn.Range(start, limit, delta)
|
if isinstance(start, float):
|
||||||
self.range.range_x.shard(strategy2)
|
self.type = mstype.float32
|
||||||
|
else:
|
||||||
|
self.type = mstype.int32
|
||||||
|
self.start = Tensor(start, self.type)
|
||||||
|
self.limit = Tensor(limit, self.type)
|
||||||
|
self.delta = Tensor(delta, self.type)
|
||||||
|
self.range = P.Range()
|
||||||
|
self.range.shard(strategy2)
|
||||||
self.mul2 = P.Mul().shard(strategy3)
|
self.mul2 = P.Mul().shard(strategy3)
|
||||||
self.weight = Parameter(weight, "w")
|
self.weight = Parameter(weight, "w")
|
||||||
|
|
||||||
|
|
||||||
def construct(self, x, b):
|
def construct(self, x, b):
|
||||||
r_out = self.range()
|
r_out = self.range(self.start, self.limit, self.delta)
|
||||||
out = self.mul(x, self.weight)
|
out = self.mul(x, self.weight)
|
||||||
out = self.mul2(out, r_out)
|
out = self.mul2(out, r_out)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
dev_num = 4
|
dev_num = 4
|
||||||
_x = Tensor(np.ones([64 // dev_num, 8]), dtype=ms.float32)
|
_x = Tensor(np.ones([64 // dev_num, 8]), dtype=ms.float32)
|
||||||
_b = Tensor(np.ones([8]), dtype=ms.float32)
|
_b = Tensor(np.ones([8]), dtype=ms.float32)
|
||||||
|
@ -98,5 +105,5 @@ def test_range2():
|
||||||
|
|
||||||
def test_range3():
|
def test_range3():
|
||||||
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=dev_num, global_rank=2)
|
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=dev_num, global_rank=2)
|
||||||
net = Net(_w1, 4.0, None, 0.5)
|
net = Net(_w1, 0.0, 4.0, 0.5)
|
||||||
compile_net(net)
|
compile_net(net)
|
||||||
|
|
Loading…
Reference in New Issue