forked from mindspore-Ecosystem/mindspore
vm for lin_space
This commit is contained in:
parent
0327d7e79b
commit
e71599b5ca
|
@ -112,6 +112,7 @@ static std::map<string, string> tbe_func_adapter_map = {
|
||||||
{"square_sum_all", "square_sum_all"},
|
{"square_sum_all", "square_sum_all"},
|
||||||
{"cum_sum", "cumsum_d"},
|
{"cum_sum", "cumsum_d"},
|
||||||
{"range", "range_d"},
|
{"range", "range_d"},
|
||||||
|
{"lin_space", "lin_space_d"},
|
||||||
{"inv_grad", "inv_grad"},
|
{"inv_grad", "inv_grad"},
|
||||||
{"apply_rms_prop", "apply_rms_prop_d"},
|
{"apply_rms_prop", "apply_rms_prop_d"},
|
||||||
{"cum_prod", "cumprod_d"},
|
{"cum_prod", "cumprod_d"},
|
||||||
|
|
|
@ -20,8 +20,11 @@ from mindspore.common.tensor import Tensor
|
||||||
from ..cell import Cell
|
from ..cell import Cell
|
||||||
from ...common import dtype as mstype
|
from ...common import dtype as mstype
|
||||||
from ..._checkparam import Validator as validator
|
from ..._checkparam import Validator as validator
|
||||||
|
from ..._checkparam import Rel
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ['ReduceLogSumExp', 'Range', 'LinSpace']
|
||||||
|
|
||||||
__all__ = ['ReduceLogSumExp', 'Range']
|
|
||||||
|
|
||||||
class ReduceLogSumExp(Cell):
|
class ReduceLogSumExp(Cell):
|
||||||
r"""
|
r"""
|
||||||
|
@ -125,3 +128,48 @@ class Range(Cell):
|
||||||
def construct(self):
|
def construct(self):
|
||||||
range_out = self.range_x(self.input_tensor)
|
range_out = self.range_x(self.input_tensor)
|
||||||
return range_out
|
return range_out
|
||||||
|
|
||||||
|
|
||||||
|
class LinSpace(Cell):
|
||||||
|
r"""
|
||||||
|
Generates values in an interval. And return the corresponding interpolation accroding to assist.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
- **start** (Union[int, float]) - The start of interval, With shape of 0-D.
|
||||||
|
- **stop** (Union[int, float]) - The end of interval, With shape of 0-D.
|
||||||
|
- **num** (int) - ticks number in the interval, the ticks include start and stop value.
|
||||||
|
With shape of 0-D.
|
||||||
|
|
||||||
|
Outputs:
|
||||||
|
Tensor, With type same as `start`. The shape is 1-D with length of `num`.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> linspace = nn.LinSpace()
|
||||||
|
>>> start = Tensor(1, mindspore.float32)
|
||||||
|
>>> stop = Tensor(10, mindspore.float32)
|
||||||
|
>>> num = Tensor(5, mindspore.int32)
|
||||||
|
>>> output = linspace(start, stop, num)
|
||||||
|
[1, 3.25, 5.5, 7.75, 10]
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, start, stop, num):
|
||||||
|
super(LinSpace, self).__init__()
|
||||||
|
validator.check_value_type("start", start, [int, float], self.cls_name)
|
||||||
|
validator.check_value_type("stop", stop, [int, float], self.cls_name)
|
||||||
|
validator.check_value_type("num", num, [int], self.cls_name)
|
||||||
|
validator.check_integer("num", num, 0, Rel.GT, self.cls_name)
|
||||||
|
|
||||||
|
self.is_single = bool(num == 1)
|
||||||
|
self.lin_space = inner.LinSpace()
|
||||||
|
self.start = Tensor(start, mstype.float32)
|
||||||
|
self.stop = Tensor(stop, mstype.float32)
|
||||||
|
self.assist = Tensor(list(range(num)), mstype.float32)
|
||||||
|
self.num = Tensor(num, mstype.int32)
|
||||||
|
self.start_array = Tensor([start], mstype.float32)
|
||||||
|
|
||||||
|
def construct(self):
|
||||||
|
if self.is_single:
|
||||||
|
return self.start_array
|
||||||
|
|
||||||
|
lin_space_out = self.lin_space(self.assist, self.start, self.stop, self.num)
|
||||||
|
return lin_space_out
|
||||||
|
|
|
@ -21,6 +21,7 @@ from mindspore.ops import _selected_grad_ops as SG
|
||||||
from .. import functional as F
|
from .. import functional as F
|
||||||
from .. import operations as P
|
from .. import operations as P
|
||||||
from ..operations import _grad_ops as G
|
from ..operations import _grad_ops as G
|
||||||
|
from ..operations import _inner_ops as inner
|
||||||
from ..composite.multitype_ops.zeros_like_impl import zeros_like
|
from ..composite.multitype_ops.zeros_like_impl import zeros_like
|
||||||
from ..functional import broadcast_gradient_args, reduced_shape, tuple_div
|
from ..functional import broadcast_gradient_args, reduced_shape, tuple_div
|
||||||
from .grad_base import bprop_getters
|
from .grad_base import bprop_getters
|
||||||
|
@ -1049,3 +1050,13 @@ def get_bprop_inv(self):
|
||||||
dx = inv_grad(out, dout)
|
dx = inv_grad(out, dout)
|
||||||
return (dx,)
|
return (dx,)
|
||||||
return bprop
|
return bprop
|
||||||
|
|
||||||
|
|
||||||
|
@bprop_getters.register(inner.LinSpace)
|
||||||
|
def get_bprop_lin_space(self):
|
||||||
|
"""Grad definition for `LinSpace` operation."""
|
||||||
|
|
||||||
|
def bprop(assist, start, stop, num, out, dout):
|
||||||
|
return zeros_like(assist), zeros_like(start), zeros_like(stop), zeros_like(num)
|
||||||
|
|
||||||
|
return bprop
|
||||||
|
|
|
@ -262,3 +262,4 @@ from .tensor_scatter_update import _tensor_scatter_update_tbe
|
||||||
from .inplace_update import _inplace_update_tbe
|
from .inplace_update import _inplace_update_tbe
|
||||||
from .splitv import _split_v_tbe
|
from .splitv import _split_v_tbe
|
||||||
from .in_top_k import _in_top_k_tbe
|
from .in_top_k import _in_top_k_tbe
|
||||||
|
from .lin_space import _lin_space_tbe
|
||||||
|
|
|
@ -0,0 +1,40 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""LinSpace op"""
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||||
|
|
||||||
|
lin_space_op_info = TBERegOp("LinSpace") \
|
||||||
|
.fusion_type("OPAQUE") \
|
||||||
|
.async_flag(False) \
|
||||||
|
.binfile_name("lin_space.so") \
|
||||||
|
.compute_cost(10) \
|
||||||
|
.kernel_name("lin_space") \
|
||||||
|
.partial_flag(True) \
|
||||||
|
.op_pattern("broadcast") \
|
||||||
|
.input(0, "assist", False, "required", "all") \
|
||||||
|
.input(1, "start", False, "required", "all") \
|
||||||
|
.input(2, "stop", False, "required", "all") \
|
||||||
|
.input(3, "num", False, "required", "all") \
|
||||||
|
.output(0, "output", False, "required", "all") \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.I32_Default,
|
||||||
|
DataType.F32_Default,) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
|
||||||
|
@op_info_register(lin_space_op_info)
|
||||||
|
def _lin_space_tbe():
|
||||||
|
"""LinSpace TBE register"""
|
||||||
|
return
|
|
@ -328,3 +328,42 @@ class EmbeddingLookup(PrimitiveWithInfer):
|
||||||
'dtype': params['dtype'],
|
'dtype': params['dtype'],
|
||||||
'value': None}
|
'value': None}
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class LinSpace(PrimitiveWithInfer):
|
||||||
|
r"""
|
||||||
|
Generates values in an interval. And return the corresponding interpolation accroding to assist.
|
||||||
|
|
||||||
|
Inputs:
|
||||||
|
- **assist** (Tensor[float32]) - The assist value, With shape of 0-D or 1-D.
|
||||||
|
- **start** (Tensor[float32]) - The start of interval, With shape of 0-D.
|
||||||
|
- **stop** (Tensor[float32]) - The end of interval, With shape of 0-D.
|
||||||
|
- **num** (Tensor[int32]) - ticks number in the interval, the ticks include start and stop value.
|
||||||
|
With shape of 0-D.
|
||||||
|
|
||||||
|
Outputs:
|
||||||
|
Tensor, has the same shape as `assist`.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> linspace = P.LinSpace()
|
||||||
|
>>> assist = Tensor([5, 5.5], mindspore.float32)
|
||||||
|
>>> start = Tensor(1, mindspore.float32)
|
||||||
|
>>> stop = Tensor(10, mindspore.float32)
|
||||||
|
>>> num = Tensor(5, mindspore.int32)
|
||||||
|
>>> output = linspace(assist, start, stop, num)
|
||||||
|
[12.25, 13.375]
|
||||||
|
"""
|
||||||
|
|
||||||
|
@prim_attr_register
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def infer_shape(self, assist, start, stop, num):
|
||||||
|
return assist
|
||||||
|
|
||||||
|
def infer_dtype(self, assist, start, stop, num):
|
||||||
|
args = {"num": num}
|
||||||
|
validator.check_tensor_type_same(args, (mstype.int32,), self.name)
|
||||||
|
args = {"assist": assist, "start": start, "stop": stop}
|
||||||
|
validator.check_tensor_type_same(args, (mstype.float32,), self.name)
|
||||||
|
return assist
|
||||||
|
|
|
@ -1599,6 +1599,14 @@ test_case_array_ops = [
|
||||||
'desc_inputs': [Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]).astype(np.float32)),
|
'desc_inputs': [Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]).astype(np.float32)),
|
||||||
Tensor(np.array([1, 2, 3]).astype(np.int32))],
|
Tensor(np.array([1, 2, 3]).astype(np.int32))],
|
||||||
'desc_bprop': [[3, 3]]}),
|
'desc_bprop': [[3, 3]]}),
|
||||||
|
('LinSpace', {
|
||||||
|
'block': inner.LinSpace(),
|
||||||
|
'desc_inputs': [Tensor([5, 5.5], mstype.float32),
|
||||||
|
Tensor(1, mstype.float32),
|
||||||
|
Tensor(10, mstype.float32),
|
||||||
|
Tensor(5, mstype.int32)],
|
||||||
|
'skip': ['backward'],
|
||||||
|
}),
|
||||||
]
|
]
|
||||||
|
|
||||||
test_case_other_ops = [
|
test_case_other_ops = [
|
||||||
|
|
Loading…
Reference in New Issue