!3846 Add TBE ops SquaredDifference/Xdivy/Xlogy for VM.
Merge pull request !3846 from liuxiao93/Add-ops-SeluSquaredDifference
commit aa65cbf733
@@ -252,6 +252,21 @@ def get_bprop_div_no_nan(self):
     return bprop
 
 
+@bprop_getters.register(P.Xdivy)
+def get_bprop_xdivy(self):
+    """Grad definition for `Xdivy` operation."""
+    div_op = P.Xdivy()
+
+    def bprop(x, y, out, dout):
+        x_dtype = F.dtype(x)
+        not_zero_x = F.cast(F.not_equal(x, F.cast(0.0, x_dtype)), x_dtype)
+        bc_x = div_op(not_zero_x, y) * dout
+        bc_y = div_op(-x, F.square(y)) * dout
+        return binop_grad_common(x, y, bc_x, bc_y)
+
+    return bprop
+
+
 @bprop_getters.register(P.Floor)
 def get_bprop_floor(self):
     """Grad definition for `floor` operation."""
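For reference, the backward rule added above matches the analytic derivative of xdivy; a sketch in our own notation, with $\mathbf{1}_{x \neq 0}$ standing for the `not_zero_x` mask:

$$\frac{\partial}{\partial x}\,\mathrm{xdivy}(x, y) = \frac{\mathbf{1}_{x \neq 0}}{y}, \qquad \frac{\partial}{\partial y}\,\mathrm{xdivy}(x, y) = -\frac{x}{y^{2}},$$

with both terms multiplied by `dout` and reduced over broadcast axes by `binop_grad_common`.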
@@ -353,6 +368,36 @@ def get_bprop_square(self):
     return bprop
 
 
+@bprop_getters.register(P.SquaredDifference)
+def get_bprop_squared_difference(self):
+    """Grad definition for `SquaredDifference` operation."""
+    neg = P.Neg()
+
+    def bprop(x, y, out, dout):
+        x_grad = 2 * dout * (x - y)
+        bc_x = x_grad
+        bc_y = neg(x_grad)
+        return binop_grad_common(x, y, bc_x, bc_y)
+
+    return bprop
+
+
+@bprop_getters.register(P.Xlogy)
+def get_bprop_xlogy(self):
+    """Grad definition for `Xlogy` operation."""
+    log_op = P.Xlogy()
+    div_op = P.Xdivy()
+
+    def bprop(x, y, out, dout):
+        x_dtype = F.dtype(x)
+        not_zero_x = F.cast(F.not_equal(x, F.cast(0.0, x_dtype)), x_dtype)
+        bc_x = log_op(not_zero_x, y) * dout
+        bc_y = div_op(x, y) * dout
+        return binop_grad_common(x, y, bc_x, bc_y)
+
+    return bprop
+
+
 @bprop_getters.register(P.Sqrt)
 def get_bprop_sqrt(self):
     """Grad definition for `Sqrt` operation."""
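Sketches of the derivatives these two rules implement (notation ours, not from the diff):

$$\frac{\partial}{\partial x}(x - y)^{2} = 2(x - y), \qquad \frac{\partial}{\partial y}(x - y)^{2} = -2(x - y),$$

$$\frac{\partial}{\partial x}\,\mathrm{xlogy}(x, y) = \mathbf{1}_{x \neq 0}\,\ln y, \qquad \frac{\partial}{\partial y}\,\mathrm{xlogy}(x, y) = \frac{x}{y},$$

where the $\ln y$ and $x / y$ factors are evaluated with the Xlogy and Xdivy primitives themselves so that the $x = 0$ branch stays exactly zero.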
@@ -108,6 +108,8 @@ from .elu import _elu_tbe
 from .elu_grad import _elu_grad_tbe
 from .div import _div_tbe
 from .log import _log_tbe
+from .xdivy import _xdivy_tbe
+from .xlogy import _xlogy_tbe
 from .floor_div import _floor_div_tbe
 from .zeros_like import _zeros_like_tbe
 from .neg import _neg_tbe
@@ -133,6 +135,7 @@ from .softplus import _softplus_tbe
 from .softplus_grad import _softplus_grad_tbe
 from .softmax_grad_ext import _softmax_grad_ext_tbe
 from .square import _square_tbe
+from .squared_difference import _squared_difference_tbe
 from .sqrt import _sqrt_tbe
 from .sparse_apply_ftrl_d import _sparse_apply_ftrl_d
 from .sparse_apply_proximal_adagrad import _sparse_apply_proximal_adagrad
@@ -0,0 +1,39 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""SquaredDifference op"""
+from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
+
+squared_difference_op_info = TBERegOp("SquaredDifference") \
+    .fusion_type("ELEMWISE") \
+    .async_flag(False) \
+    .binfile_name("squared_difference.so") \
+    .compute_cost(10) \
+    .kernel_name("squared_difference") \
+    .partial_flag(True) \
+    .op_pattern("broadcast") \
+    .input(0, "x1", False, "required", "all") \
+    .input(1, "x2", False, "required", "all") \
+    .output(0, "y", False, "required", "all") \
+    .dtype_format(DataType.I32_None, DataType.I32_None, DataType.I32_None) \
+    .dtype_format(DataType.F16_None, DataType.F16_None, DataType.F16_None) \
+    .dtype_format(DataType.F32_None, DataType.F32_None, DataType.F32_None) \
+    .get_op_info()
+
+
+@op_info_register(squared_difference_op_info)
+def _squared_difference_tbe():
+    """SquaredDifference TBE register"""
+    return
@@ -0,0 +1,38 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Xdivy op"""
+from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
+
+xdivy_op_info = TBERegOp("Xdivy") \
+    .fusion_type("ELEMWISE") \
+    .async_flag(False) \
+    .binfile_name("xdivy.so") \
+    .compute_cost(10) \
+    .kernel_name("xdivy") \
+    .partial_flag(True) \
+    .input(0, "x1", False, "required", "all") \
+    .input(1, "x2", False, "required", "all") \
+    .output(0, "y", False, "required", "all") \
+    .op_pattern("broadcast") \
+    .dtype_format(DataType.F16_None, DataType.F16_None, DataType.F16_None) \
+    .dtype_format(DataType.F32_None, DataType.F32_None, DataType.F32_None) \
+    .get_op_info()
+
+
+@op_info_register(xdivy_op_info)
+def _xdivy_tbe():
+    """Xdivy TBE register"""
+    return
@@ -0,0 +1,38 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Xlogy op"""
+from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
+
+xlogy_op_info = TBERegOp("Xlogy") \
+    .fusion_type("ELEMWISE") \
+    .async_flag(False) \
+    .binfile_name("xlogy.so") \
+    .compute_cost(10) \
+    .kernel_name("xlogy") \
+    .partial_flag(True) \
+    .input(0, "x1", False, "required", "all") \
+    .input(1, "x2", False, "required", "all") \
+    .output(0, "y", False, "required", "all") \
+    .op_pattern("broadcast") \
+    .dtype_format(DataType.F16_None, DataType.F16_None, DataType.F16_None) \
+    .dtype_format(DataType.F32_None, DataType.F32_None, DataType.F32_None) \
+    .get_op_info()
+
+
+@op_info_register(xlogy_op_info)
+def _xlogy_tbe():
+    """Xlogy TBE register"""
+    return
@@ -51,7 +51,7 @@ from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, A
                        Minimum, Mul, Neg, NMSWithMask, NotEqual,
                        NPUAllocFloatStatus, NPUClearFloatStatus,
                        NPUGetFloatStatus, Pow, RealDiv, IsNan, IsInf, IsFinite, FloatStatus,
-                       Reciprocal, CumSum, HistogramFixedWidth,
+                       Reciprocal, CumSum, HistogramFixedWidth, SquaredDifference, Xdivy, Xlogy,
                        Sin, Sqrt, Rsqrt, BesselI0e, BesselI1e, TruncateDiv, TruncateMod,
                        Square, Sub, TensorAdd, Sign, Round, SquareSumAll, Atan, Atanh, Cosh, Sinh, Eps, Tan)
@@ -107,6 +107,9 @@ __all__ = [
     'Rsqrt',
     'Sqrt',
     'Square',
+    'SquaredDifference',
+    'Xdivy',
+    'Xlogy',
     'Conv2D',
     'Flatten',
     'MaxPoolWithArgmax',
@@ -1121,6 +1121,40 @@ class Mul(_MathBinaryOp):
         return None
 
 
+class SquaredDifference(_MathBinaryOp):
+    """
+    Subtracts the second input tensor from the first input tensor element-wise and returns square of it.
+
+    The inputs must be two tensors or one tensor and one scalar.
+    When the inputs are two tensors,
+    both dtypes cannot be bool, and the shapes of them could be broadcast.
+    When the inputs are one tensor and one scalar,
+    the scalar only could be a constant.
+
+    Inputs:
+        - **input_x** (Union[Tensor, Number, bool]) - The first input is a number or
+          a bool or a tensor whose data type is float16, float32, int32 or bool.
+        - **input_y** (Union[Tensor, Number, bool]) - The second input is a number or
+          a bool when the first input is a tensor or a tensor whose data type is
+          float16, float32, int32 or bool.
+
+    Outputs:
+        Tensor, the shape is same as the shape after broadcasting,
+        and the data type is the one with high precision or high digits among the two inputs.
+
+    Examples:
+        >>> input_x = Tensor(np.array([1.0, 2.0, 3.0]), mindspore.float32)
+        >>> input_y = Tensor(np.array([2.0, 4.0, 6.0]), mindspore.float32)
+        >>> squared_difference = P.SquaredDifference()
+        >>> squared_difference(input_x, input_y)
+        [1.0, 4.0, 9.0]
+    """
+
+    def infer_dtype(self, x_dtype, y_dtype):
+        valid_type = [mstype.float16, mstype.float32, mstype.int32]
+        return _MathBinaryOp.do_infer_dtype(x_dtype, y_dtype, valid_type, self.name)
+
+
 class Square(PrimitiveWithInfer):
     """
     Returns square of a tensor element-wise.
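The documented example can be checked against plain NumPy; a minimal sketch (NumPy only, not part of the diff):

import numpy as np

# SquaredDifference computes (x - y) ** 2 element-wise with broadcasting.
x = np.array([1.0, 2.0, 3.0], dtype=np.float32)
y = np.array([2.0, 4.0, 6.0], dtype=np.float32)
print((x - y) ** 2)  # -> [1. 4. 9.], matching the docstring example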
@@ -1962,6 +1996,72 @@ class Ceil(PrimitiveWithInfer):
         return x_dtype
 
 
+class Xdivy(_MathBinaryOp):
+    """
+    Divide the first input tensor by the second input tensor element-wise. Returns zero when `x` is zero.
+
+    The inputs must be two tensors or one tensor and one scalar.
+    When the inputs are two tensors,
+    both dtypes cannot be bool, and the shapes of them could be broadcast.
+    When the inputs are one tensor and one scalar,
+    the scalar only could be a constant.
+
+    Inputs:
+        - **input_x** (Union[Tensor, Number, bool]) - The first input is a number or
+          a bool or a tensor whose data type is float16, float32 or bool.
+        - **input_y** (Union[Tensor, Number, bool]) - The second input is a number or
+          a bool when the first input is a tensor or a tensor whose data type is float16, float32 or bool.
+
+    Outputs:
+        Tensor, the shape is same as the shape after broadcasting,
+        and the data type is the one with high precision or high digits among the two inputs.
+
+    Examples:
+        >>> input_x = Tensor(np.array([2, 4, -1]), mindspore.float32)
+        >>> input_y = Tensor(np.array([2, 2, 2]), mindspore.float32)
+        >>> xdivy = P.Xdivy()
+        >>> xdivy(input_x, input_y)
+        [1.0, 2.0, -0.5]
+    """
+
+    def infer_dtype(self, x_dtype, y_dtype):
+        return _MathBinaryOp.do_infer_dtype(x_dtype, y_dtype, [mstype.float16, mstype.float32], self.name)
+
+
+class Xlogy(_MathBinaryOp):
+    """
+    Computes first input tensor multiplied by the logarithm of second input tensor element-wise.
+    Returns zero when `x` is zero.
+
+    The inputs must be two tensors or one tensor and one scalar.
+    When the inputs are two tensors,
+    both dtypes cannot be bool, and the shapes of them could be broadcast.
+    When the inputs are one tensor and one scalar,
+    the scalar only could be a constant.
+
+    Inputs:
+        - **input_x** (Union[Tensor, Number, bool]) - The first input is a number or
+          a bool or a tensor whose data type is float16, float32 or bool.
+        - **input_y** (Union[Tensor, Number, bool]) - The second input is a number or
+          a bool when the first input is a tensor or a tensor whose data type is float16, float32 or bool.
+          The value must be positive.
+
+    Outputs:
+        Tensor, the shape is same as the shape after broadcasting,
+        and the data type is the one with high precision or high digits among the two inputs.
+
+    Examples:
+        >>> input_x = Tensor(np.array([-5, 0, 4]), mindspore.float32)
+        >>> input_y = Tensor(np.array([2, 2, 2]), mindspore.float32)
+        >>> xlogy = P.Xlogy()
+        >>> xlogy(input_x, input_y)
+        [-3.465736, 0.0, 2.7725887]
+    """
+
+    def infer_dtype(self, x_dtype, y_dtype):
+        return _MathBinaryOp.do_infer_dtype(x_dtype, y_dtype, [mstype.float16, mstype.float32], self.name)
+
+
 class Acosh(PrimitiveWithInfer):
     """
     Compute inverse hyperbolic cosine of x element-wise.
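Both new ops force the output to zero wherever `x` is zero; a minimal NumPy sketch of that behaviour, mirroring the two docstring examples (not taken from the diff):

import numpy as np

# Xdivy: x / y, but exactly 0 where x == 0.
x1 = np.array([2.0, 4.0, -1.0], dtype=np.float32)
y1 = np.array([2.0, 2.0, 2.0], dtype=np.float32)
print(np.where(x1 == 0.0, 0.0, x1 / y1))          # -> [1.0, 2.0, -0.5]

# Xlogy: x * log(y), but exactly 0 where x == 0.
x2 = np.array([-5.0, 0.0, 4.0], dtype=np.float32)
y2 = np.array([2.0, 2.0, 2.0], dtype=np.float32)
print(np.where(x2 == 0.0, 0.0, x2 * np.log(y2)))  # -> [-3.465736, 0.0, 2.7725887]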
@@ -3205,11 +3205,11 @@ class FusedSparseFtrl(PrimitiveWithInfer):
         use_locking (bool): Use locks for update operation if True . Default: False.
 
     Inputs:
-        - **var** (Parameter): The variable to be updated. The data type must be float32.
-        - **accum** (Parameter): The accum to be updated, must be same type and shape as `var`.
-        - **linear** (Parameter): The linear to be updated, must be same type and shape as `var`.
-        - **grad** (Tensor): A tensor of the same type as `var`, for the gradient.
-        - **indices** (Tensor): A vector of indices into the first dimension of `var` and `accum`. The shape
+        - **var** (Parameter) - The variable to be updated. The data type must be float32.
+        - **accum** (Parameter) - The accum to be updated, must be same type and shape as `var`.
+        - **linear** (Parameter) - The linear to be updated, must be same type and shape as `var`.
+        - **grad** (Tensor) - A tensor of the same type as `var`, for the gradient.
+        - **indices** (Tensor) - A vector of indices into the first dimension of `var` and `accum`. The shape
           of `indices` must be the same as `grad` in first dimension. The type must be int32.
 
     Outputs:
@@ -3300,9 +3300,9 @@ class FusedSparseProximalAdagrad(PrimitiveWithInfer):
     Inputs:
         - **var** (Parameter) - Variable tensor to be updated. The data type must be float32.
         - **accum** (Parameter) - Variable tensor to be updated. Has the same dtype as `var`.
-        - **lr** (Tensor): The learning rate value. The data type must be float32.
-        - **l1** (Tensor): l1 regularization strength. The data type must be float32.
-        - **l2** (Tensor): l2 regularization strength. The data type must be float32.
+        - **lr** (Tensor) - The learning rate value. The data type must be float32.
+        - **l1** (Tensor) - l1 regularization strength. The data type must be float32.
+        - **l2** (Tensor) - l2 regularization strength. The data type must be float32.
         - **grad** (Tensor) - A tensor of the same type as `var`, for the gradient. The data type must be float32.
         - **indices** (Tensor) - A vector of indices into the first dimension of `var` and `accum`. The data type
           must be int32.
@@ -4670,16 +4670,16 @@ class ApplyFtrl(PrimitiveWithInfer):
         use_locking (bool): Use locks for update operation if True . Default: False.
 
     Inputs:
-        - **var** (Tensor): The variable to be updated.
-        - **accum** (Tensor): The accum to be updated, must be same type and shape as `var`.
-        - **linear** (Tensor): The linear to be updated, must be same type and shape as `var`.
-        - **grad** (Tensor): Gradient.
-        - **lr** (Union[Number, Tensor]): The learning rate value, must be positive. Default: 0.001.
-        - **l1** (Union[Number, Tensor]): l1 regularization strength, must be greater than or equal to zero.
+        - **var** (Tensor) - The variable to be updated.
+        - **accum** (Tensor) - The accum to be updated, must be same type and shape as `var`.
+        - **linear** (Tensor) - The linear to be updated, must be same type and shape as `var`.
+        - **grad** (Tensor) - Gradient.
+        - **lr** (Union[Number, Tensor]) - The learning rate value, must be positive. Default: 0.001.
+        - **l1** (Union[Number, Tensor]) - l1 regularization strength, must be greater than or equal to zero.
           Default: 0.0.
-        - **l2** (Union[Number, Tensor]): l2 regularization strength, must be greater than or equal to zero.
+        - **l2** (Union[Number, Tensor]) - l2 regularization strength, must be greater than or equal to zero.
           Default: 0.0.
-        - **lr_power** (Union[Number, Tensor]): Learning rate power controls how the learning rate decreases
+        - **lr_power** (Union[Number, Tensor]) - Learning rate power controls how the learning rate decreases
           during training, must be less than or equal to zero. Use fixed learning rate if lr_power is zero.
           Default: -0.5.
 
@@ -4760,17 +4760,17 @@ class SparseApplyFtrl(PrimitiveWithInfer):
         use_locking (bool): Use locks for update operation if True . Default: False.
 
     Inputs:
-        - **var** (Parameter): The variable to be updated. The data type must be float32.
-        - **accum** (Parameter): The accum to be updated, must be same type and shape as `var`.
-        - **linear** (Parameter): The linear to be updated, must be same type and shape as `var`.
-        - **grad** (Tensor): A tensor of the same type as `var`, for the gradient.
-        - **indices** (Tensor): A vector of indices into the first dimension of `var` and `accum`.
+        - **var** (Parameter) - The variable to be updated. The data type must be float32.
+        - **accum** (Parameter) - The accum to be updated, must be same type and shape as `var`.
+        - **linear** (Parameter) - The linear to be updated, must be same type and shape as `var`.
+        - **grad** (Tensor) - A tensor of the same type as `var`, for the gradient.
+        - **indices** (Tensor) - A vector of indices into the first dimension of `var` and `accum`.
           The shape of `indices` must be the same as `grad` in first dimension. The type must be int32.
 
     Outputs:
-        - **var** (Tensor): Tensor, has the same shape and type as `var`.
-        - **accum** (Tensor): Tensor, has the same shape and type as `accum`.
-        - **linear** (Tensor): Tensor, has the same shape and type as `linear`.
+        - **var** (Tensor) - Tensor, has the same shape and type as `var`.
+        - **accum** (Tensor) - Tensor, has the same shape and type as `accum`.
+        - **linear** (Tensor) - Tensor, has the same shape and type as `linear`.
 
     Examples:
         >>> import mindspore
@@ -4858,9 +4858,9 @@ class SparseApplyFtrlV2(PrimitiveWithInfer):
     Outputs:
         Tuple of 3 Tensor, the updated parameters.
 
-        - **var** (Tensor): Tensor, has the same shape and type as `var`.
-        - **accum** (Tensor): Tensor, has the same shape and type as `accum`.
-        - **linear** (Tensor): Tensor, has the same shape and type as `linear`.
+        - **var** (Tensor) - Tensor, has the same shape and type as `var`.
+        - **accum** (Tensor) - Tensor, has the same shape and type as `accum`.
+        - **linear** (Tensor) - Tensor, has the same shape and type as `linear`.
 
     Examples:
         >>> import mindspore
@@ -1013,6 +1013,18 @@ test_case_math_ops = [
         'desc_const': [(0, 3, 1, 2)],
         'desc_inputs': [],
         'skip': ['backward']}),
+    ('Xdivy', {
+        'block': P.Xdivy(),
+        'desc_inputs': [[4, 5], [2, 3, 4, 5]],
+        'desc_bprop': [[2, 3, 4, 5]]}),
+    ('Xlogy', {
+        'block': P.Xlogy(),
+        'desc_inputs': [[4, 5], [2, 3, 4, 5]],
+        'desc_bprop': [[2, 3, 4, 5]]}),
+    ('SquaredDifference', {
+        'block': P.SquaredDifference(),
+        'desc_inputs': [[4, 5], [2, 3, 4, 5]],
+        'desc_bprop': [[2, 3, 4, 5]]}),
     ('Square', {
         'block': P.Square(),
         'desc_inputs': [[4]],