forked from mindspore-Ecosystem/mindspore
vm for lin_space
This commit is contained in:
parent
0327d7e79b
commit
e71599b5ca
|
@ -112,6 +112,7 @@ static std::map<string, string> tbe_func_adapter_map = {
|
|||
{"square_sum_all", "square_sum_all"},
|
||||
{"cum_sum", "cumsum_d"},
|
||||
{"range", "range_d"},
|
||||
{"lin_space", "lin_space_d"},
|
||||
{"inv_grad", "inv_grad"},
|
||||
{"apply_rms_prop", "apply_rms_prop_d"},
|
||||
{"cum_prod", "cumprod_d"},
|
||||
|
|
|
@ -20,8 +20,11 @@ from mindspore.common.tensor import Tensor
|
|||
from ..cell import Cell
|
||||
from ...common import dtype as mstype
|
||||
from ..._checkparam import Validator as validator
|
||||
from ..._checkparam import Rel
|
||||
|
||||
|
||||
__all__ = ['ReduceLogSumExp', 'Range', 'LinSpace']
|
||||
|
||||
__all__ = ['ReduceLogSumExp', 'Range']
|
||||
|
||||
class ReduceLogSumExp(Cell):
|
||||
r"""
|
||||
|
@ -125,3 +128,48 @@ class Range(Cell):
|
|||
def construct(self):
|
||||
range_out = self.range_x(self.input_tensor)
|
||||
return range_out
|
||||
|
||||
|
||||
class LinSpace(Cell):
|
||||
r"""
|
||||
Generates values in an interval. And return the corresponding interpolation accroding to assist.
|
||||
|
||||
Args:
|
||||
- **start** (Union[int, float]) - The start of interval, With shape of 0-D.
|
||||
- **stop** (Union[int, float]) - The end of interval, With shape of 0-D.
|
||||
- **num** (int) - ticks number in the interval, the ticks include start and stop value.
|
||||
With shape of 0-D.
|
||||
|
||||
Outputs:
|
||||
Tensor, With type same as `start`. The shape is 1-D with length of `num`.
|
||||
|
||||
Examples:
|
||||
>>> linspace = nn.LinSpace()
|
||||
>>> start = Tensor(1, mindspore.float32)
|
||||
>>> stop = Tensor(10, mindspore.float32)
|
||||
>>> num = Tensor(5, mindspore.int32)
|
||||
>>> output = linspace(start, stop, num)
|
||||
[1, 3.25, 5.5, 7.75, 10]
|
||||
"""
|
||||
|
||||
def __init__(self, start, stop, num):
|
||||
super(LinSpace, self).__init__()
|
||||
validator.check_value_type("start", start, [int, float], self.cls_name)
|
||||
validator.check_value_type("stop", stop, [int, float], self.cls_name)
|
||||
validator.check_value_type("num", num, [int], self.cls_name)
|
||||
validator.check_integer("num", num, 0, Rel.GT, self.cls_name)
|
||||
|
||||
self.is_single = bool(num == 1)
|
||||
self.lin_space = inner.LinSpace()
|
||||
self.start = Tensor(start, mstype.float32)
|
||||
self.stop = Tensor(stop, mstype.float32)
|
||||
self.assist = Tensor(list(range(num)), mstype.float32)
|
||||
self.num = Tensor(num, mstype.int32)
|
||||
self.start_array = Tensor([start], mstype.float32)
|
||||
|
||||
def construct(self):
|
||||
if self.is_single:
|
||||
return self.start_array
|
||||
|
||||
lin_space_out = self.lin_space(self.assist, self.start, self.stop, self.num)
|
||||
return lin_space_out
|
||||
|
|
|
@ -21,6 +21,7 @@ from mindspore.ops import _selected_grad_ops as SG
|
|||
from .. import functional as F
|
||||
from .. import operations as P
|
||||
from ..operations import _grad_ops as G
|
||||
from ..operations import _inner_ops as inner
|
||||
from ..composite.multitype_ops.zeros_like_impl import zeros_like
|
||||
from ..functional import broadcast_gradient_args, reduced_shape, tuple_div
|
||||
from .grad_base import bprop_getters
|
||||
|
@ -1049,3 +1050,13 @@ def get_bprop_inv(self):
|
|||
dx = inv_grad(out, dout)
|
||||
return (dx,)
|
||||
return bprop
|
||||
|
||||
|
||||
@bprop_getters.register(inner.LinSpace)
|
||||
def get_bprop_lin_space(self):
|
||||
"""Grad definition for `LinSpace` operation."""
|
||||
|
||||
def bprop(assist, start, stop, num, out, dout):
|
||||
return zeros_like(assist), zeros_like(start), zeros_like(stop), zeros_like(num)
|
||||
|
||||
return bprop
|
||||
|
|
|
@ -262,3 +262,4 @@ from .tensor_scatter_update import _tensor_scatter_update_tbe
|
|||
from .inplace_update import _inplace_update_tbe
|
||||
from .splitv import _split_v_tbe
|
||||
from .in_top_k import _in_top_k_tbe
|
||||
from .lin_space import _lin_space_tbe
|
||||
|
|
|
@ -0,0 +1,40 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""LinSpace op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
lin_space_op_info = TBERegOp("LinSpace") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("lin_space.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("lin_space") \
|
||||
.partial_flag(True) \
|
||||
.op_pattern("broadcast") \
|
||||
.input(0, "assist", False, "required", "all") \
|
||||
.input(1, "start", False, "required", "all") \
|
||||
.input(2, "stop", False, "required", "all") \
|
||||
.input(3, "num", False, "required", "all") \
|
||||
.output(0, "output", False, "required", "all") \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.I32_Default,
|
||||
DataType.F32_Default,) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(lin_space_op_info)
|
||||
def _lin_space_tbe():
|
||||
"""LinSpace TBE register"""
|
||||
return
|
|
@ -328,3 +328,42 @@ class EmbeddingLookup(PrimitiveWithInfer):
|
|||
'dtype': params['dtype'],
|
||||
'value': None}
|
||||
return out
|
||||
|
||||
|
||||
class LinSpace(PrimitiveWithInfer):
|
||||
r"""
|
||||
Generates values in an interval. And return the corresponding interpolation accroding to assist.
|
||||
|
||||
Inputs:
|
||||
- **assist** (Tensor[float32]) - The assist value, With shape of 0-D or 1-D.
|
||||
- **start** (Tensor[float32]) - The start of interval, With shape of 0-D.
|
||||
- **stop** (Tensor[float32]) - The end of interval, With shape of 0-D.
|
||||
- **num** (Tensor[int32]) - ticks number in the interval, the ticks include start and stop value.
|
||||
With shape of 0-D.
|
||||
|
||||
Outputs:
|
||||
Tensor, has the same shape as `assist`.
|
||||
|
||||
Examples:
|
||||
>>> linspace = P.LinSpace()
|
||||
>>> assist = Tensor([5, 5.5], mindspore.float32)
|
||||
>>> start = Tensor(1, mindspore.float32)
|
||||
>>> stop = Tensor(10, mindspore.float32)
|
||||
>>> num = Tensor(5, mindspore.int32)
|
||||
>>> output = linspace(assist, start, stop, num)
|
||||
[12.25, 13.375]
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def infer_shape(self, assist, start, stop, num):
|
||||
return assist
|
||||
|
||||
def infer_dtype(self, assist, start, stop, num):
|
||||
args = {"num": num}
|
||||
validator.check_tensor_type_same(args, (mstype.int32,), self.name)
|
||||
args = {"assist": assist, "start": start, "stop": stop}
|
||||
validator.check_tensor_type_same(args, (mstype.float32,), self.name)
|
||||
return assist
|
||||
|
|
|
@ -1599,6 +1599,14 @@ test_case_array_ops = [
|
|||
'desc_inputs': [Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]).astype(np.float32)),
|
||||
Tensor(np.array([1, 2, 3]).astype(np.int32))],
|
||||
'desc_bprop': [[3, 3]]}),
|
||||
('LinSpace', {
|
||||
'block': inner.LinSpace(),
|
||||
'desc_inputs': [Tensor([5, 5.5], mstype.float32),
|
||||
Tensor(1, mstype.float32),
|
||||
Tensor(10, mstype.float32),
|
||||
Tensor(5, mstype.int32)],
|
||||
'skip': ['backward'],
|
||||
}),
|
||||
]
|
||||
|
||||
test_case_other_ops = [
|
||||
|
|
Loading…
Reference in New Issue