!2787 Add op ScatterMul ScatterDiv vm

Merge pull request !2787 from zhaozhenlong/op/scatter-div-mul-vm
This commit is contained in:
mindspore-ci-bot 2020-07-01 20:41:03 +08:00 committed by Gitee
commit 622b97f3b6
7 changed files with 301 additions and 84 deletions

View File

@ -369,7 +369,7 @@ class CosineEmbeddingLoss(_Loss):
>>> x2 = Tensor(np.array([[0.4, 1.2], [-0.4, -0.9]]), mindspore.float32)
>>> y = Tensor(np.array([1,-1]), mindspore.int32)
>>> cosine_embedding_loss = P.CosineEmbeddingLoss()
>>> cosine_embedding_loss(x1, x2, target)
>>> cosine_embedding_loss(x1, x2, y)
[0.0003426671]
"""
def __init__(self, margin=0.0, reduction="mean"):

View File

@ -276,4 +276,6 @@ from .lrn_grad import _lrn_grad_tbe
from .scatter_max import _scatter_max_tbe
from .scatter_min import _scatter_min_tbe
from .scatter_sub import _scatter_sub_tbe
from .scatter_mul import _scatter_mul_tbe
from .scatter_div import _scatter_div_tbe
from .mod import _mod_tbe

View File

@ -0,0 +1,42 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ScatterDiv op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
scatter_div_op_info = TBERegOp("ScatterDiv") \
.fusion_type("ELEMWISE") \
.async_flag(False) \
.binfile_name("scatter_div.so") \
.compute_cost(10) \
.kernel_name("scatter_div") \
.partial_flag(True) \
.attr("use_locking", "optional", "bool", "all") \
.input(0, "var", False, "required", "all") \
.input(1, "indices", False, "required", "all") \
.input(2, "updates", False, "required", "all") \
.output(0, "var", False, "required", "all") \
.dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I8_Default, DataType.I8_Default) \
.dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default, DataType.U8_Default) \
.get_op_info()
@op_info_register(scatter_div_op_info)
def _scatter_div_tbe():
"""ScatterDiv TBE register"""
return

View File

@ -0,0 +1,42 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ScatterMul op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
scatter_mul_op_info = TBERegOp("ScatterMul") \
.fusion_type("ELEMWISE") \
.async_flag(False) \
.binfile_name("scatter_mul.so") \
.compute_cost(10) \
.kernel_name("scatter_mul") \
.partial_flag(True) \
.attr("use_locking", "optional", "bool", "all") \
.input(0, "var", False, "required", "all") \
.input(1, "indices", False, "required", "all") \
.input(2, "updates", False, "required", "all") \
.output(0, "var", False, "required", "all") \
.dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I8_Default, DataType.I8_Default) \
.dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default, DataType.U8_Default) \
.get_op_info()
@op_info_register(scatter_mul_op_info)
def _scatter_mul_tbe():
"""ScatterMul TBE register"""
return

View File

@ -25,8 +25,8 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
Fill, GatherNd, GatherV2, SparseGatherV2, InvertPermutation,
IsInstance, IsSubClass, ArgMaxWithValue, OnesLike, ZerosLike,
Rank, Reshape, ResizeNearestNeighbor, ArgMinWithValue,
SameTypeShape, ScatterAdd, ScatterSub, ScatterMax, ScatterMin, ScatterUpdate,
ScalarToArray, ScalarToTensor, ScatterNd, ScatterNdUpdate, Select,
SameTypeShape, ScatterAdd, ScatterSub, ScatterMul, ScatterDiv, ScatterMax, ScatterMin,
ScatterUpdate, ScalarToArray, ScalarToTensor, ScatterNd, ScatterNdUpdate, Select,
Shape, Size, Slice, Split,
Squeeze, StridedSlice, Tile, TensorScatterUpdate,
Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin,
@ -215,6 +215,8 @@ __all__ = [
'L2Normalize',
'ScatterAdd',
'ScatterSub',
'ScatterMul',
'ScatterDiv',
'ScatterNd',
'ScatterMax',
'ScatterMin',

View File

@ -38,6 +38,39 @@ from ..._c_expression import signature_dtype as sig_dtype
from ..._c_expression import typing
class _ScatterOp(PrimitiveWithInfer):
"""
Define Scatter operators
"""
__mindspore_signature__ = (
('x', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
('indices', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1),
('updates', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T)
)
@staticmethod
def _check_scatter_shape(x_shape, indices_shape, updates_shape, prim_name):
if updates_shape and updates_shape != indices_shape + x_shape[1:]:
raise ValueError(f"For '{prim_name}', the shape of updates should be [] or "
f"updates_shape = indices_shape + x_shape[1:], but got x_shape: {x_shape}, "
f"indices_shape: {indices_shape}, updates_shape: {updates_shape}.")
@prim_attr_register
def __init__(self, use_locking=False):
"""Init _ScatterOp"""
validator.check_value_type('use_locking', use_locking, [bool], self.name)
self.init_prim_io_names(inputs=['x', 'indices', 'updates'], outputs=['y'])
def infer_shape(self, x_shape, indices_shape, updates_shape):
_ScatterOp._check_scatter_shape(x_shape, indices_shape, updates_shape, self.name)
return x_shape
def infer_dtype(self, x_dtype, indices_dtype, updates_dtype):
validator.check_tensor_type_same({'indices': indices_dtype}, [mstype.int32], self.name)
args = {"x": x_dtype, "updates": updates_dtype}
validator.check_tensor_type_same(args, mstype.number_type, self.name)
return x_dtype
def _check_infer_attr_reduce(axis, keep_dims, prim_name):
validator.check_value_type('keep_dims', keep_dims, [bool], prim_name)
validator.check_value_type('axis', axis, [int, tuple], prim_name)
@ -2221,7 +2254,7 @@ class TensorScatterUpdate(PrimitiveWithInfer):
return x_dtype
class ScatterUpdate(PrimitiveWithInfer):
class ScatterUpdate(_ScatterOp):
"""
Update tensor value by using input indices and value.
@ -2233,8 +2266,8 @@ class ScatterUpdate(PrimitiveWithInfer):
Inputs:
- **input_x** (Parameter) - The target tensor, with data type of Parameter.
- **indices** (Tensor) - The index of input tensor. With int32 data type.
- **update** (Tensor) - The tensor to update the input tensor, has the same type as input,
and update.shape = indices.shape + input_x.shape[1:].
- **updates** (Tensor) - The tensor to update the input tensor, has the same type as input,
and updates.shape = indices.shape + input_x.shape[1:].
Outputs:
Tensor, has the same shape and type as `input_x`.
@ -2243,27 +2276,17 @@ class ScatterUpdate(PrimitiveWithInfer):
>>> np_x = np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]])
>>> input_x = mindspore.Parameter(Tensor(np_x, mindspore.float32), name="x")
>>> indices = Tensor(np.array([[0, 0], [1, 1]]), mindspore.int32)
>>> np_update = np.array([[[1.0, 2.2, 1.0], [2.0, 1.2, 1.0]], [[2.0, 2.2, 1.0], [3.0, 1.2, 1.0]]])
>>> update = Tensor(np_update, mindspore.float32)
>>> np_updates = np.array([[[1.0, 2.2, 1.0], [2.0, 1.2, 1.0]], [[2.0, 2.2, 1.0], [3.0, 1.2, 1.0]]])
>>> updates = Tensor(np_updates, mindspore.float32)
>>> op = P.ScatterUpdate()
>>> output = op(input_x, indices, update)
>>> output = op(input_x, indices, updates)
"""
__mindspore_signature__ = (
('x', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
('indices', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1),
('value', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T)
)
@prim_attr_register
def __init__(self, use_locking=True):
"""Init ScatterUpdate"""
validator.check_value_type('use_locking', use_locking, [bool], self.name)
self.init_prim_io_names(inputs=['x', 'indices', 'value'], outputs=['y'])
def infer_shape(self, x_shape, indices_shape, value_shape):
if indices_shape + x_shape[1:] != value_shape:
raise ValueError("For 'ScatterUpdate', input value are not match with input indices.")
return x_shape
self.init_prim_io_names(inputs=['x', 'indices', 'updates'], outputs=['y'])
def infer_dtype(self, x_dtype, indices_dtype, value_dtype):
validator.check_tensor_type_same({'indices': indices_dtype}, [mstype.int32], self.name)
@ -2323,14 +2346,7 @@ class ScatterNdUpdate(PrimitiveWithInfer):
return x_dtype
def _check_scatter_shape(x_shape, indices_shape, updates_shape, prim_name):
if updates_shape and updates_shape != indices_shape + x_shape[1:]:
raise ValueError(f"For '{prim_name}', the shape of updates should be [] or "
f"updates_shape = indices_shape + x_shape[1:], but got x_shape: {x_shape}, "
f"indices_shape: {indices_shape}, updates_shape: {updates_shape}.")
class ScatterMax(PrimitiveWithInfer):
class ScatterMax(_ScatterOp):
"""
Update the value of the input tensor through the max operation.
@ -2364,18 +2380,8 @@ class ScatterMax(PrimitiveWithInfer):
self.init_prim_io_names(inputs=['x', 'indices', 'updates'], outputs=['y'])
validator.check_value_type('use_locking', use_locking, (bool,), self.name)
def infer_shape(self, x_shape, indices_shape, updates_shape):
_check_scatter_shape(x_shape, indices_shape, updates_shape, self.name)
return x_shape
def infer_dtype(self, x_dtype, indices_dtype, updates_dtype):
validator.check_tensor_type_same({'indices': indices_dtype}, (mstype.int32,), self.name)
args = {"x": x_dtype, "updates": updates_dtype}
validator.check_tensor_type_same(args, mstype.number_type, self.name)
return x_dtype
class ScatterMin(PrimitiveWithInfer):
class ScatterMin(_ScatterOp):
"""
Update the value of the input tensor through the min operation.
@ -2403,24 +2409,8 @@ class ScatterMin(PrimitiveWithInfer):
[[0.0, 1.0, 1.0], [0.0, 0.0, 0.0]]
"""
@prim_attr_register
def __init__(self, use_locking=False):
"""Init ScatterMin"""
self.init_prim_io_names(inputs=['x', 'indices', 'updates'], outputs=['y'])
validator.check_value_type('use_locking', use_locking, (bool,), self.name)
def infer_shape(self, x_shape, indices_shape, updates_shape):
_check_scatter_shape(x_shape, indices_shape, updates_shape, self.name)
return x_shape
def infer_dtype(self, x_dtype, indices_dtype, updates_dtype):
validator.check_tensor_type_same({'indices': indices_dtype}, (mstype.int32,), self.name)
args = {"x": x_dtype, "updates": updates_dtype}
validator.check_tensor_type_same(args, mstype.number_type, self.name)
return x_dtype
class ScatterAdd(PrimitiveWithInfer):
class ScatterAdd(_ScatterOp):
"""
Update the value of the input tensor through the add operation.
@ -2448,23 +2438,8 @@ class ScatterAdd(PrimitiveWithInfer):
[[1.0, 1.0, 1.0], [3.0, 3.0, 3.0]]
"""
@prim_attr_register
def __init__(self, use_locking=False):
"""Init ScatterAdd"""
validator.check_value_type('use_locking', use_locking, (bool,), self.name)
def infer_shape(self, x_shape, indices_shape, updates_shape):
_check_scatter_shape(x_shape, indices_shape, updates_shape, self.name)
return x_shape
def infer_dtype(self, x_dtype, indices_dtype, updates_dtype):
validator.check_tensor_type_same({'indices': indices_dtype}, (mstype.int32,), self.name)
args = {'x': x_dtype, 'updates': updates_dtype}
validator.check_tensor_type_same(args, mstype.number_type, self.name)
return x_dtype
class ScatterSub(PrimitiveWithInfer):
class ScatterSub(_ScatterOp):
"""
Update the value of the input tensor through the sub operation.
@ -2492,20 +2467,63 @@ class ScatterSub(PrimitiveWithInfer):
[[-1.0, -1.0, -1.0], [-1.0, -1.0, -1.0]]
"""
@prim_attr_register
def __init__(self, use_locking=False):
"""Init ScatterSub"""
validator.check_value_type('use_locking', use_locking, (bool,), self.name)
def infer_shape(self, x_shape, indices_shape, updates_shape):
_check_scatter_shape(x_shape, indices_shape, updates_shape, self.name)
return x_shape
class ScatterMul(_ScatterOp):
"""
Update the value of the input tensor through the mul operation.
def infer_dtype(self, x_dtype, indices_dtype, updates_dtype):
validator.check_tensor_type_same({'indices': indices_dtype}, (mstype.int32,), self.name)
args = {'x': x_dtype, 'updates': updates_dtype}
validator.check_tensor_type_same(args, mstype.number_type, self.name)
return x_dtype
Using given values to update tensor value through the mul operation, along with the input indices.
This operation outputs the `input_x` after the update is done, which makes it convenient to use the updated value.
Args:
use_locking (bool): Whether protect the assignment by a lock. Default: False.
Inputs:
- **input_x** (Parameter) - The target parameter.
- **indices** (Tensor) - The index to do mul operation whose data type should be mindspore.int32.
- **updates** (Tensor) - The tensor doing the mul operation with `input_x`,
the data type is same as `input_x`, the shape is `indices_shape + x_shape[1:]`.
Outputs:
Parameter, the updated `input_x`.
Examples:
>>> input_x = Parameter(Tensor(np.array([[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]]), mindspore.float32), name="x")
>>> indices = Tensor(np.array([0, 1]), mindspore.int32)
>>> updates = Tensor(np.ones([[2.0, 2.0, 2.0], [2.0, 2.0, 2.0]]), mindspore.float32)
>>> scatter_mul = P.ScatterMul()
>>> output = scatter_mul(input_x, indices, updates)
[[2.0, 2.0, 2.0], [4.0, 4.0, 4.0]]
"""
class ScatterDiv(_ScatterOp):
"""
Update the value of the input tensor through the div operation.
Using given values to update tensor value through the div operation, along with the input indices.
This operation outputs the `input_x` after the update is done, which makes it convenient to use the updated value.
Args:
use_locking (bool): Whether protect the assignment by a lock. Default: False.
Inputs:
- **input_x** (Parameter) - The target parameter.
- **indices** (Tensor) - The index to do div operation whose data type should be mindspore.int32.
- **updates** (Tensor) - The tensor doing the div operation with `input_x`,
the data type is same as `input_x`, the shape is `indices_shape + x_shape[1:]`.
Outputs:
Parameter, the updated `input_x`.
Examples:
>>> input_x = Parameter(Tensor(np.array([[6.0, 6.0, 6.0], [2.0, 2.0, 2.0]]), mindspore.float32), name="x")
>>> indices = Tensor(np.array([0, 1]), mindspore.int32)
>>> updates = Tensor(np.ones([[2.0, 2.0, 2.0], [2.0, 2.0, 2.0]]), mindspore.float32)
>>> scatter_div = P.ScatterDiv()
>>> output = scatter_div(input_x, indices, updates)
[[3.0, 3.0, 3.0], [1.0, 1.0, 1.0]]
"""
class SpaceToDepth(PrimitiveWithInfer):

View File

@ -185,6 +185,19 @@ class HistogramSummaryNet(nn.Cell):
return out
class ScatterUpdate(nn.Cell):
"""ScatterUpdate net definition"""
def __init__(self, ref_shape, dtype=np.float32, use_locking=False):
super(ScatterUpdate, self).__init__()
self.scatter_update = P.ScatterUpdate(use_locking)
self.ref = Parameter(Tensor(np.ones(ref_shape, dtype)), name="ref")
def construct(self, indices, updates):
out = self.scatter_update(self.ref, indices, updates)
return out
class ScatterMax(nn.Cell):
"""ScatterMax net definition"""
@ -237,6 +250,32 @@ class ScatterSub(nn.Cell):
return out
class ScatterMul(nn.Cell):
"""ScatterMul net definition"""
def __init__(self, ref_shape, dtype=np.float32, use_locking=False):
super(ScatterMul, self).__init__()
self.scatter_mul = P.ScatterMul(use_locking)
self.ref = Parameter(Tensor(np.ones(ref_shape, dtype)), name="ref")
def construct(self, indices, updates):
out = self.scatter_mul(self.ref, indices, updates)
return out
class ScatterDiv(nn.Cell):
"""ScatterDiv net definition"""
def __init__(self, ref_shape, dtype=np.float32, use_locking=False):
super(ScatterDiv, self).__init__()
self.scatter_div = P.ScatterDiv(use_locking)
self.ref = Parameter(Tensor(np.ones(ref_shape, dtype)*10), name="ref")
def construct(self, indices, updates):
out = self.scatter_div(self.ref, indices, updates)
return out
class ApplyFtrlNet(nn.Cell):
def __init__(self):
super(ApplyFtrlNet, self).__init__()
@ -1861,6 +1900,11 @@ test_case_other_ops = [
'desc_inputs': (Tensor(np.array([[0, 0], [1, 1]], np.int32)),
Tensor(np.ones([2, 2, 3], np.int32))),
'skip': ['backward']}),
('ScatterUpdate', {
'block': ScatterUpdate((6,)),
'desc_inputs': (Tensor(np.array([2, 0, 5], np.int32)),
Tensor(np.array([2.0, 3.0, 4.0], np.float32))),
'skip': ['backward']}),
('ScatterAddUseLocking', {
'block': ScatterAdd((6,), use_locking=True),
'desc_inputs': (Tensor(np.array([2, 0, 5], np.int32)),
@ -1902,6 +1946,73 @@ test_case_other_ops = [
'desc_inputs': (Tensor(np.array([2, 0, 5], np.int32)),
Tensor(np.array([2, 3, 4], np.uint8))),
'skip': ['backward']}),
('ScatterMulUseLocking', {
'block': ScatterMul((6,), use_locking=True),
'desc_inputs': (Tensor(np.array([2], np.int32)),
Tensor(np.array([2.0], np.float32))),
'skip': ['backward']}),
('ScatterMulScalar', {
'block': ScatterMul((6,)),
'desc_inputs': (Tensor(np.array([2], np.int32)),
Tensor(np.array([2.0], np.float32))),
'skip': ['backward']}),
('ScatterMul2d', {
'block': ScatterMul((3, 4)),
'desc_inputs': (Tensor(np.array([[0, 1], [1, 2]], np.int32)),
Tensor(np.array([[[1, 1, 1, 1], [2, 2, 2, 2]],
[[3, 3, 3, 3], [4, 4, 4, 4]]], np.float32))),
'skip': ['backward']}),
('ScatterMulF16', {
'block': ScatterMul((6,), np.float16),
'desc_inputs': (Tensor(np.array([2, 0, 5], np.int32)),
Tensor(np.array([2.0, 3.0, 4.0], np.float16))),
'skip': ['backward']}),
('ScatterMulI8', {
'block': ScatterMul((6,), np.int8),
'desc_inputs': (Tensor(np.array([2, 0, 5], np.int32)),
Tensor(np.array([2, 3, 4], np.int8))),
'skip': ['backward']}),
('ScatterMulI32', {
'block': ScatterMul((6,), np.int32),
'desc_inputs': (Tensor(np.array([2, 0, 5], np.int32)),
Tensor(np.array([2, 3, 4], np.int32))),
'skip': ['backward']}),
('ScatterMulU8', {
'block': ScatterMul((6,), np.uint8),
'desc_inputs': (Tensor(np.array([2, 0, 5], np.int32)),
Tensor(np.array([2, 3, 4], np.uint8))),
'skip': ['backward']}),
('ScatterDivUseLocking', {
'block': ScatterDiv((6,), use_locking=True),
'desc_inputs': (Tensor(np.array([2], np.int32)),
Tensor(np.array([2.0], np.float32))),
'skip': ['backward']}),
('ScatterDivScalar', {
'block': ScatterDiv((6,)),
'desc_inputs': (Tensor(np.array([2], np.int32)),
Tensor(np.array([2.0], np.float32))),
'skip': ['backward']}),
('ScatterDiv2d', {
'block': ScatterDiv((3, 4)),
'desc_inputs': (Tensor(np.array([[0, 1], [1, 2]], np.int32)),
Tensor(np.array([[[1, 1, 1, 1], [2, 2, 2, 2]],
[[3, 3, 3, 3], [4, 4, 4, 4]]], np.float32))),
'skip': ['backward']}),
('ScatterDivF16', {
'block': ScatterDiv((6,), np.float16),
'desc_inputs': (Tensor(np.array([2, 0, 5], np.int32)),
Tensor(np.array([2.0, 3.0, 4.0], np.float16))),
'skip': ['backward']}),
('ScatterDivI8', {
'block': ScatterDiv((6,), np.int8),
'desc_inputs': (Tensor(np.array([2, 0, 5], np.int32)),
Tensor(np.array([2, 3, 4], np.int8))),
'skip': ['backward']}),
('ScatterDivU8', {
'block': ScatterDiv((6,), np.uint8),
'desc_inputs': (Tensor(np.array([2, 0, 5], np.int32)),
Tensor(np.array([2, 3, 4], np.uint8))),
'skip': ['backward']}),
('ScatterSubUseLocking', {
'block': ScatterSub((6,), use_locking=True),
'desc_inputs': (Tensor(np.array([2], np.int32)),