forked from mindspore-Ecosystem/mindspore
unify range ops
This commit is contained in:
parent
60922a1d65
commit
17b9758543
|
@ -13,10 +13,8 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""math"""
|
||||
import math
|
||||
import numpy as np
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops.operations import _inner_ops as inner
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.common._decorator import deprecated
|
||||
from mindspore.ops.primitive import constexpr
|
||||
|
@ -25,7 +23,6 @@ from ..cell import Cell
|
|||
from ...common import dtype as mstype
|
||||
from ..._checkparam import Validator as validator
|
||||
|
||||
|
||||
__all__ = ['ReduceLogSumExp',
|
||||
'Range',
|
||||
'LGamma',
|
||||
|
@ -140,37 +137,15 @@ class Range(Cell):
|
|||
|
||||
def __init__(self, start, limit=None, delta=1):
|
||||
super(Range, self).__init__()
|
||||
validator.check_value_type("start", start, [int, float], self.cls_name)
|
||||
validator.check_value_type("delta", delta, [int, float], self.cls_name)
|
||||
if delta == 0:
|
||||
raise ValueError("The input of `delta` can not be equal to zero.")
|
||||
if limit is not None:
|
||||
validator.check_value_type("limit", limit, [int, float], self.cls_name)
|
||||
if isinstance(start, int) and isinstance(limit, int) and isinstance(delta, int):
|
||||
self.dtype = mstype.int32
|
||||
else:
|
||||
self.dtype = mstype.float32
|
||||
data = np.arange(start, limit, delta)
|
||||
if data.dtype == np.float:
|
||||
self.ms_dtype = mstype.float32
|
||||
else:
|
||||
if isinstance(start, int) and isinstance(delta, int):
|
||||
self.dtype = mstype.int32
|
||||
else:
|
||||
self.dtype = mstype.float32
|
||||
if isinstance(start, int):
|
||||
start = float(start)
|
||||
if isinstance(limit, int):
|
||||
limit = float(limit)
|
||||
if isinstance(delta, int):
|
||||
delta = float(delta)
|
||||
self.range_x = inner.Range(start, limit, delta)
|
||||
if limit is None:
|
||||
length_input = math.ceil(start / delta)
|
||||
else:
|
||||
length_input = math.ceil((limit - start) / delta)
|
||||
self.input_tensor = Tensor(list(range(length_input)), self.dtype)
|
||||
self.ms_dtype = mstype.int32
|
||||
self.result_tensor = Tensor(data, dtype=self.ms_dtype)
|
||||
|
||||
def construct(self):
|
||||
range_out = self.range_x(self.input_tensor)
|
||||
return range_out
|
||||
return self.result_tensor
|
||||
|
||||
|
||||
class LGamma(Cell):
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
"""Inner operators."""
|
||||
|
||||
import numpy as np
|
||||
from mindspore.common import Tensor
|
||||
from ..._checkparam import Rel
|
||||
from ..._checkparam import Validator as validator
|
||||
from ... import context
|
||||
|
@ -25,6 +26,7 @@ from ..operations.math_ops import _infer_shape_reduce
|
|||
from ...communication.management import GlobalComm
|
||||
from .. import signature as sig
|
||||
|
||||
|
||||
class ExtractImagePatches(PrimitiveWithInfer):
|
||||
"""
|
||||
Extracts patches from images.
|
||||
|
@ -164,6 +166,9 @@ class Range(PrimitiveWithInfer):
|
|||
validator.check_tensor_dtype_valid('x', x_dtype, [mstype.float32, mstype.int32], self.name)
|
||||
return x_dtype
|
||||
|
||||
def infer_value(self, x_value):
|
||||
return Tensor(np.arange(self.start, self.limit, self.delta), dtype=x_value.dtype)
|
||||
|
||||
|
||||
class Quant(PrimitiveWithInfer):
|
||||
r"""
|
||||
|
@ -408,6 +413,7 @@ class Send(PrimitiveWithInfer):
|
|||
>>> net = Net()
|
||||
>>> output = net(input_)
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self, sr_tag, dest_rank, group=GlobalComm.WORLD_COMM_GROUP):
|
||||
self.rank = dest_rank
|
||||
|
@ -464,6 +470,7 @@ class Receive(PrimitiveWithInfer):
|
|||
>>> net = Net()
|
||||
>>> output = net()
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self, sr_tag, src_rank, shape, dtype, group=GlobalComm.WORLD_COMM_GROUP):
|
||||
self.rank = src_rank
|
||||
|
|
|
@ -2391,6 +2391,7 @@ class Pack(PrimitiveWithInfer):
|
|||
Same as operator Stack. Pack will be deprecated in the future.
|
||||
Please use Stack instead.
|
||||
"""
|
||||
|
||||
@deprecated("1.1", "Stack", True)
|
||||
@prim_attr_register
|
||||
def __init__(self, axis=0):
|
||||
|
@ -2469,6 +2470,7 @@ class Unpack(PrimitiveWithInfer):
|
|||
Same as operator Unstack. Unpack will be deprecated in the future.
|
||||
Please use Unstack instead.
|
||||
"""
|
||||
|
||||
@deprecated("1.1", "Unstack", True)
|
||||
@prim_attr_register
|
||||
def __init__(self, axis=0):
|
||||
|
@ -3491,7 +3493,6 @@ class ScatterUpdate(_ScatterOp_Dynamic):
|
|||
self.add_prim_attr('side_effect_mem', True)
|
||||
|
||||
|
||||
|
||||
class ScatterNdUpdate(_ScatterNdOp):
|
||||
r"""
|
||||
Updates tensor values by using input indices and value.
|
||||
|
@ -5250,3 +5251,11 @@ class Range(PrimitiveWithCheck):
|
|||
valid_dtypes = [mstype.int32, mstype.float32]
|
||||
inputs = {"start": start_dtype, "limit": limit_dtype, "delta": delta_dtype}
|
||||
validator.check_tensors_dtypes_same_and_valid(inputs, valid_dtypes, self.name)
|
||||
|
||||
def infer_value(self, start_value, limit_value, delat_value):
|
||||
if start_value is not None and limit_value is not None and delat_value is not None:
|
||||
start = np.asscalar(start_value.asnumpy())
|
||||
limit = np.asscalar(limit_value.asnumpy())
|
||||
delat = np.asscalar(delat_value.asnumpy())
|
||||
return Tensor(np.arange(start, limit, delat), dtype=start_value.dtype)
|
||||
return None
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
import numpy as np
|
||||
|
||||
import mindspore as ms
|
||||
import mindspore.nn as nn
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore import context, Tensor, Parameter
|
||||
from mindspore.nn import Cell, Momentum
|
||||
from mindspore.ops import operations as P
|
||||
|
@ -48,18 +48,25 @@ class Net(Cell):
|
|||
def __init__(self, weight, start, limit, delta, strategy1=None, strategy2=None, strategy3=None):
|
||||
super().__init__()
|
||||
self.mul = P.Mul().shard(strategy1)
|
||||
self.range = nn.Range(start, limit, delta)
|
||||
self.range.range_x.shard(strategy2)
|
||||
if isinstance(start, float):
|
||||
self.type = mstype.float32
|
||||
else:
|
||||
self.type = mstype.int32
|
||||
self.start = Tensor(start, self.type)
|
||||
self.limit = Tensor(limit, self.type)
|
||||
self.delta = Tensor(delta, self.type)
|
||||
self.range = P.Range()
|
||||
self.range.shard(strategy2)
|
||||
self.mul2 = P.Mul().shard(strategy3)
|
||||
self.weight = Parameter(weight, "w")
|
||||
|
||||
|
||||
def construct(self, x, b):
|
||||
r_out = self.range()
|
||||
r_out = self.range(self.start, self.limit, self.delta)
|
||||
out = self.mul(x, self.weight)
|
||||
out = self.mul2(out, r_out)
|
||||
return out
|
||||
|
||||
|
||||
dev_num = 4
|
||||
_x = Tensor(np.ones([64 // dev_num, 8]), dtype=ms.float32)
|
||||
_b = Tensor(np.ones([8]), dtype=ms.float32)
|
||||
|
@ -98,5 +105,5 @@ def test_range2():
|
|||
|
||||
def test_range3():
|
||||
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=dev_num, global_rank=2)
|
||||
net = Net(_w1, 4.0, None, 0.5)
|
||||
net = Net(_w1, 0.0, 4.0, 0.5)
|
||||
compile_net(net)
|
||||
|
|
Loading…
Reference in New Issue