Add RangeD for GE

This commit is contained in:
liuxiao 2020-05-19 16:37:50 +08:00
parent c8f69f5db2
commit 627724a205
8 changed files with 132 additions and 2 deletions

View File

@ -199,6 +199,7 @@ const char kNameApplyRMSProp[] = "ApplyRMSProp";
const char kNameApplyCenteredRMSProp[] = "ApplyCenteredRMSProp";
const char kNameL2Loss[] = "L2Loss";
const char kNameCTCLoss[] = "CTCLoss";
const char kNameRange[] = "Range";
const char kNameSquareSumAll[] = "SquareSumAll";
// -----------------OpAdapter initialization--------------
@ -400,6 +401,7 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma
{string(kNameApplyCenteredRMSProp), ADPT_DESC(ApplyCenteredRMSProp)},
{string(kNameL2Loss), ADPT_DESC(L2Loss)},
{string(kNameCTCLoss), ADPT_DESC(CTCLoss)},
{string(kNameRange), ADPT_DESC(RangeD)},
{string(kNameSquareSumAll), ADPT_DESC(SquareSumAll)}};
#ifdef ENABLE_GE
adpt_map[string(kNamePrint)] = ADPT_DESC(Print);

View File

@ -842,6 +842,13 @@ ATTR_MAP(SplitD) = {{"axis", ATTR_DESC(split_dim, AnyTraits<int>())},
{"output_num", ATTR_DESC(num_split, AnyTraits<int>())}};
DYN_OUTPUT_MAP(SplitD) = {{0, DYN_OUTPUT_DESC(y)}};
// Range
INPUT_MAP(RangeD) = {{1, INPUT_DESC(x)}};
ATTR_MAP(RangeD) = {{"start", ATTR_DESC(start, AnyTraits<float>())},
{"limit", ATTR_DESC(limit, AnyTraits<float>())},
{"delta", ATTR_DESC(delta, AnyTraits<float>())}};
OUTPUT_MAP(RangeD) = {{0, OUTPUT_DESC(y)}};
// Neg
INPUT_MAP(Neg) = {{1, INPUT_DESC(x)}};
ATTR_MAP(Neg) = EMPTY_ATTR_MAP;

View File

@ -376,6 +376,8 @@ DECLARE_OP_USE_OUTPUT(OneHot)
DECLARE_OP_ADAPTER(GatherV2D)
DECLARE_OP_USE_INPUT_ATTR(GatherV2D)
DECLARE_OP_USE_OUTPUT(GatherV2D)
DECLARE_OP_ADAPTER(RangeD)
DECLARE_OP_USE_OUTPUT(RangeD)
DECLARE_OP_ADAPTER(Data)
DECLARE_OP_ADAPTER(BiasAdd)

View File

@ -13,11 +13,14 @@
# limitations under the License.
# ============================================================================
"""math"""
import math
from mindspore.ops import operations as P
from mindspore.common.tensor import Tensor
from ..cell import Cell
from ...common import dtype as mstype
from ..._checkparam import Validator as validator
__all__ = ['ReduceLogSumExp']
__all__ = ['ReduceLogSumExp', 'Range']
class ReduceLogSumExp(Cell):
r"""
@ -66,3 +69,56 @@ class ReduceLogSumExp(Cell):
sumexp = self.sum(exp, self.axis)
logsumexp = self.log(sumexp)
return logsumexp
class Range(Cell):
r"""
Creates a sequence of numbers.
Args:
start (Union[int, float]): If `limit` is `None`, the value acts as limit in the range and first entry
defaults to `0`. Otherwise, it acts as first entry in the range.
limit (Union[int, float]): Acts as upper limit of sequence. If `None`, defaults to the value of `start`
while set the first entry of the range to `0`.
delta (Union[int, float]): Increment of the range. Default: 1.
Outputs:
Tensor, the dtype is int if the dtype of `start`, `limit` and `delta` all are int. Otherwise, dtype is float.
Examples:
>>> net = nn.Range(1, 8, 2)
>>> out = net()
[1, 3, 5, 7]
"""
def __init__(self, start, limit=None, delta=1):
super(Range, self).__init__()
validator.check_value_type("start", start, [int, float], None)
validator.check_value_type("delta", delta, [int, float], None)
if limit is not None:
validator.check_value_type("limit", limit, [int, float], None)
if isinstance(start, int) and isinstance(limit, int) and isinstance(delta, int):
self.dtype = mstype.int32
else:
self.dtype = mstype.float32
else:
if isinstance(start, int) and isinstance(delta, int):
self.dtype = mstype.int32
else:
self.dtype = mstype.float32
if isinstance(start, int):
start = float(start)
if isinstance(limit, int):
limit = float(limit)
if isinstance(delta, int):
delta = float(delta)
self.range_x = P.Range(start, limit, delta)
if limit is None:
length_input = math.ceil(start / delta)
else:
length_input = math.ceil((limit - start) / delta)
self.input_tensor = Tensor(list(range(length_input)), self.dtype)
def construct(self):
range_out = self.range_x(self.input_tensor)
return range_out

View File

@ -268,6 +268,15 @@ def get_bprop_gather_v2(self):
return bprop
@bprop_getters.register(P.Range)
def get_bprop_range(self):
"""Generate bprop for Range"""
def bprop(x, out, dout):
return (zeros_like(x),)
return bprop
@bprop_getters.register(P.Pack)
def get_bprop_pack(self):
"""Generate bprop for Pack"""

View File

@ -23,7 +23,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
Diag, DiagPart, DType, ExpandDims, Eye,
Fill, GatherNd, GatherV2, InvertPermutation,
IsInstance, IsSubClass, ArgMaxWithValue, OnesLike, ZerosLike,
Rank, Reshape, ResizeNearestNeighbor, ArgMinWithValue,
Rank, Reshape, ResizeNearestNeighbor, ArgMinWithValue, Range,
SameTypeShape, ScatterMax, ScatterUpdate,
ScalarToArray, ScalarToTensor, ScatterNd, ScatterNdUpdate, Select,
Shape, Size, Slice, Split,
@ -125,6 +125,7 @@ __all__ = [
'StridedSlice',
'ReduceSum',
'ReduceMean',
'Range',
'LayerNorm',
'Rank',
'Less',

View File

@ -528,6 +528,55 @@ class GatherV2(PrimitiveWithInfer):
return out
class Range(PrimitiveWithInfer):
r"""
Creates a sequence of numbers.
Set `input_x` as :math:`x_i` for each element, `output` as follows:
.. math::
\text{output}(x_i) = x_i * \text{delta} + \text{start}
Args:
start (float): If `limit` is `None`, the value acts as limit in the range and first entry
defaults to `0`. Otherwise, it acts as first entry in the range.
limit (float): Acts as upper limit of sequence. If `None`, defaults to the value of `start`
while set the first entry of the range to `0`.
delta (float): Increment of the range. Default: 1.0.
Inputs:
- **input_x** (Tensor) - The assistant data. A `1-D` tensor of type float32 or int32.
Outputs:
Tensor, has the same shape and dtype as `input_x`.
Examples:
>>> range = P.Range(1.0, 8.0, 2.0)
>>> x = Tensor(np.array([1, 2, 3, 2]), mindspore.int32)
>>> range(x)
[3, 5, 7, 5]
"""
@prim_attr_register
def __init__(self, start, limit=None, delta=1.0):
self.init_prim_io_names(inputs=['x'], outputs=['y'])
self.delta = validator.check_value_type("delta", delta, [float], self.name)
validator.check_value_type("start", start, [float], self.name)
if limit is None:
self.start = 0.0
self.limit = start
self.add_prim_attr("start", self.start)
self.add_prim_attr("limit", self.limit)
else:
validator.check_value_type("limit", limit, [float], self.name)
def infer_shape(self, x_shape):
return x_shape
def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'x_dtype': x_dtype}, [mstype.float32, mstype.int32], self.name)
return x_dtype
class Split(PrimitiveWithInfer):
"""
Splits input tensor into output_num of tensors along the given axis and output numbers.

View File

@ -784,6 +784,10 @@ test_case_nn_ops = [
'desc_const': [0],
'desc_inputs': [[1152], Tensor(np.array(10).astype(np.int32))],
'desc_bprop': [Tensor(np.array(10).astype(np.float32))]}),
('Range', {
'block': P.Range(1.0, 5.0),
'desc_inputs': [Tensor(np.ones([10]).astype(np.float32))],
'desc_bprop': [[10]]}),
('UnsortedSegmentSum', {
'block': P.UnsortedSegmentSum(),
'desc_const': [1280],