vm for lin_space

This commit is contained in:
jiangjinsheng 2020-06-16 16:44:34 +08:00
parent 0327d7e79b
commit e71599b5ca
7 changed files with 149 additions and 1 deletions

View File

@ -112,6 +112,7 @@ static std::map<string, string> tbe_func_adapter_map = {
{"square_sum_all", "square_sum_all"},
{"cum_sum", "cumsum_d"},
{"range", "range_d"},
{"lin_space", "lin_space_d"},
{"inv_grad", "inv_grad"},
{"apply_rms_prop", "apply_rms_prop_d"},
{"cum_prod", "cumprod_d"},

View File

@ -20,8 +20,11 @@ from mindspore.common.tensor import Tensor
from ..cell import Cell
from ...common import dtype as mstype
from ..._checkparam import Validator as validator
from ..._checkparam import Rel
__all__ = ['ReduceLogSumExp', 'Range', 'LinSpace']
__all__ = ['ReduceLogSumExp', 'Range']
class ReduceLogSumExp(Cell):
r"""
@ -125,3 +128,48 @@ class Range(Cell):
def construct(self):
range_out = self.range_x(self.input_tensor)
return range_out
class LinSpace(Cell):
r"""
Generates values in an interval. And return the corresponding interpolation accroding to assist.
Args:
- **start** (Union[int, float]) - The start of interval, With shape of 0-D.
- **stop** (Union[int, float]) - The end of interval, With shape of 0-D.
- **num** (int) - ticks number in the interval, the ticks include start and stop value.
With shape of 0-D.
Outputs:
Tensor, With type same as `start`. The shape is 1-D with length of `num`.
Examples:
>>> linspace = nn.LinSpace()
>>> start = Tensor(1, mindspore.float32)
>>> stop = Tensor(10, mindspore.float32)
>>> num = Tensor(5, mindspore.int32)
>>> output = linspace(start, stop, num)
[1, 3.25, 5.5, 7.75, 10]
"""
def __init__(self, start, stop, num):
super(LinSpace, self).__init__()
validator.check_value_type("start", start, [int, float], self.cls_name)
validator.check_value_type("stop", stop, [int, float], self.cls_name)
validator.check_value_type("num", num, [int], self.cls_name)
validator.check_integer("num", num, 0, Rel.GT, self.cls_name)
self.is_single = bool(num == 1)
self.lin_space = inner.LinSpace()
self.start = Tensor(start, mstype.float32)
self.stop = Tensor(stop, mstype.float32)
self.assist = Tensor(list(range(num)), mstype.float32)
self.num = Tensor(num, mstype.int32)
self.start_array = Tensor([start], mstype.float32)
def construct(self):
if self.is_single:
return self.start_array
lin_space_out = self.lin_space(self.assist, self.start, self.stop, self.num)
return lin_space_out

View File

@ -21,6 +21,7 @@ from mindspore.ops import _selected_grad_ops as SG
from .. import functional as F
from .. import operations as P
from ..operations import _grad_ops as G
from ..operations import _inner_ops as inner
from ..composite.multitype_ops.zeros_like_impl import zeros_like
from ..functional import broadcast_gradient_args, reduced_shape, tuple_div
from .grad_base import bprop_getters
@ -1049,3 +1050,13 @@ def get_bprop_inv(self):
dx = inv_grad(out, dout)
return (dx,)
return bprop
@bprop_getters.register(inner.LinSpace)
def get_bprop_lin_space(self):
"""Grad definition for `LinSpace` operation."""
def bprop(assist, start, stop, num, out, dout):
return zeros_like(assist), zeros_like(start), zeros_like(stop), zeros_like(num)
return bprop

View File

@ -262,3 +262,4 @@ from .tensor_scatter_update import _tensor_scatter_update_tbe
from .inplace_update import _inplace_update_tbe
from .splitv import _split_v_tbe
from .in_top_k import _in_top_k_tbe
from .lin_space import _lin_space_tbe

View File

@ -0,0 +1,40 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""LinSpace op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
lin_space_op_info = TBERegOp("LinSpace") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("lin_space.so") \
.compute_cost(10) \
.kernel_name("lin_space") \
.partial_flag(True) \
.op_pattern("broadcast") \
.input(0, "assist", False, "required", "all") \
.input(1, "start", False, "required", "all") \
.input(2, "stop", False, "required", "all") \
.input(3, "num", False, "required", "all") \
.output(0, "output", False, "required", "all") \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.I32_Default,
DataType.F32_Default,) \
.get_op_info()
@op_info_register(lin_space_op_info)
def _lin_space_tbe():
"""LinSpace TBE register"""
return

View File

@ -328,3 +328,42 @@ class EmbeddingLookup(PrimitiveWithInfer):
'dtype': params['dtype'],
'value': None}
return out
class LinSpace(PrimitiveWithInfer):
r"""
Generates values in an interval. And return the corresponding interpolation accroding to assist.
Inputs:
- **assist** (Tensor[float32]) - The assist value, With shape of 0-D or 1-D.
- **start** (Tensor[float32]) - The start of interval, With shape of 0-D.
- **stop** (Tensor[float32]) - The end of interval, With shape of 0-D.
- **num** (Tensor[int32]) - ticks number in the interval, the ticks include start and stop value.
With shape of 0-D.
Outputs:
Tensor, has the same shape as `assist`.
Examples:
>>> linspace = P.LinSpace()
>>> assist = Tensor([5, 5.5], mindspore.float32)
>>> start = Tensor(1, mindspore.float32)
>>> stop = Tensor(10, mindspore.float32)
>>> num = Tensor(5, mindspore.int32)
>>> output = linspace(assist, start, stop, num)
[12.25, 13.375]
"""
@prim_attr_register
def __init__(self):
pass
def infer_shape(self, assist, start, stop, num):
return assist
def infer_dtype(self, assist, start, stop, num):
args = {"num": num}
validator.check_tensor_type_same(args, (mstype.int32,), self.name)
args = {"assist": assist, "start": start, "stop": stop}
validator.check_tensor_type_same(args, (mstype.float32,), self.name)
return assist

View File

@ -1599,6 +1599,14 @@ test_case_array_ops = [
'desc_inputs': [Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]).astype(np.float32)),
Tensor(np.array([1, 2, 3]).astype(np.int32))],
'desc_bprop': [[3, 3]]}),
('LinSpace', {
'block': inner.LinSpace(),
'desc_inputs': [Tensor([5, 5.5], mstype.float32),
Tensor(1, mstype.float32),
Tensor(10, mstype.float32),
Tensor(5, mstype.int32)],
'skip': ['backward'],
}),
]
test_case_other_ops = [