develop TensorScatterUpdate op and access ge and vm

This commit is contained in:
buxue 2020-06-10 17:26:30 +08:00
parent ac5878b4b7
commit 8aee6b07d6
11 changed files with 137 additions and 5 deletions

View File

@ -103,6 +103,7 @@ const char kNameReLU6[] = "ReLU6";
const char kNameReLU6Grad[] = "ReLU6Grad";
const char kNameElu[] = "Elu";
const char kNameEluGrad[] = "EluGrad";
const char kNameTensorScatterUpdate[] = "TensorScatterUpdate";
const char kNameScatterUpdate[] = "ScatterUpdate";
const char kNameScatterNdUpdate[] = "ScatterNdUpdate";
const char kNameScatterMax[] = "ScatterMax";
@ -261,6 +262,7 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma
{string(kNameResizeBilinear), ADPT_DESC(ResizeBilinearV2D)},
{string(kNameZerosLike), ADPT_DESC(ZerosLike)},
{string(kNameOnesLike), ADPT_DESC(OnesLike)},
{string(kNameTensorScatterUpdate), ADPT_DESC(TensorScatterUpdate)},
{string(kNameScatterUpdate), ADPT_DESC(ScatterUpdate)},
{string(kNameScatterNdUpdate), ADPT_DESC(ScatterNdUpdate)},
{string(kNameScatterMax), ADPT_DESC(ScatterMax)},

View File

@ -525,6 +525,11 @@ INPUT_MAP(Unpack) = {{1, INPUT_DESC(x)}};
ATTR_MAP(Unpack) = {{"axis", ATTR_DESC(axis, AnyTraits<int>())}, {"num", ATTR_DESC(num, AnyTraits<int>())}};
DYN_OUTPUT_MAP(Unpack) = {{0, DYN_OUTPUT_DESC(y)}};
// TensorScatterUpdate
INPUT_MAP(TensorScatterUpdate) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(indices)}, {3, INPUT_DESC(updates)}};
ATTR_MAP(TensorScatterUpdate) = EMPTY_ATTR_MAP;
OUTPUT_MAP(TensorScatterUpdate) = {{0, OUTPUT_DESC(y)}};
// ScatterUpdate
INPUT_MAP(ScatterUpdate) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(indices)}, {3, INPUT_DESC(updates)}};
ATTR_MAP(ScatterUpdate) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};

View File

@ -134,6 +134,8 @@ DECLARE_OP_ADAPTER(ZerosLike)
DECLARE_OP_USE_OUTPUT(ZerosLike)
DECLARE_OP_ADAPTER(OnesLike)
DECLARE_OP_USE_OUTPUT(OnesLike)
DECLARE_OP_ADAPTER(TensorScatterUpdate)
DECLARE_OP_USE_OUTPUT(TensorScatterUpdate)
DECLARE_OP_ADAPTER(ScatterUpdate)
DECLARE_OP_USE_OUTPUT(ScatterUpdate)
DECLARE_OP_ADAPTER(ScatterNdUpdate)

View File

@ -456,6 +456,20 @@ def get_bprop_scatter_nd_update(self):
return bprop
@bprop_getters.register(P.TensorScatterUpdate)
def get_bprop_tensor_scatter_update(self):
"""Generate bprop for TensorScatterUpdate"""
gather_nd = P.GatherNd()
tensor_scatter_update = P.TensorScatterUpdate()
def bprop(x, indices, update, out, dout):
x_grad = tensor_scatter_update(dout, indices, zeros_like(update))
update_grad = gather_nd(dout, indices)
return x_grad, zeros_like(indices), update_grad
return bprop
@bprop_getters.register(P.Argmax)
def get_bprop_argmax(self):
"""Generate bprop for Argmax"""

View File

@ -251,3 +251,4 @@ from .lamb_next_right import _lamb_next_right_tbe
from .sparse_gather_v2 import _sparse_gather_v2_tbe
from .data_format_dim_map import _data_format_dim_map_tbe
from .histogram_fixed_width import _histogram_fixed_width_tbe
from .tensor_scatter_update import _tensor_scatter_update_tbe

View File

@ -31,7 +31,7 @@ scatter_nd_update_op_info = TBERegOp("ScatterNdUpdate") \
.dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I8_Default, DataType.I8_Default) \
.dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default, DataType.U8_Default,) \
.dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default, DataType.U8_Default) \
.dtype_format(DataType.BOOL_Default, DataType.I32_Default, DataType.BOOL_Default, DataType.BOOL_Default) \
.get_op_info()

View File

@ -31,7 +31,7 @@ scatter_update_op_info = TBERegOp("ScatterUpdate") \
.dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I8_Default, DataType.I8_Default) \
.dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default, DataType.U8_Default,) \
.dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default, DataType.U8_Default) \
.dtype_format(DataType.BOOL_Default, DataType.I32_Default, DataType.BOOL_Default, DataType.BOOL_Default) \
.get_op_info()

View File

@ -0,0 +1,41 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""TensorScatterUpdate op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
tensor_scatter_update_op_info = TBERegOp("TensorScatterUpdate") \
.fusion_type("ELEMWISE") \
.async_flag(False) \
.binfile_name("tensor_scatter_update.so") \
.compute_cost(10) \
.kernel_name("tensor_scatter_update") \
.partial_flag(True) \
.input(0, "x", False, "required", "all") \
.input(1, "indices", False, "required", "all") \
.input(1, "updates", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I8_Default, DataType.I8_Default) \
.dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default, DataType.U8_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
.get_op_info()
@op_info_register(tensor_scatter_update_op_info)
def _tensor_scatter_update_tbe():
"""TensorScatterUpdate TBE register"""
return

View File

@ -27,7 +27,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
SameTypeShape, ScatterAdd, ScatterMax, ScatterUpdate,
ScalarToArray, ScalarToTensor, ScatterNd, ScatterNdUpdate, Select,
Shape, Size, Slice, Split, EmbeddingLookup,
Squeeze, StridedSlice, Tile,
Squeeze, StridedSlice, Tile, TensorScatterUpdate,
Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin,
UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch, BatchToSpace,
SpaceToBatchND, BatchToSpaceND, BroadcastTo)
@ -211,6 +211,7 @@ __all__ = [
'Pad',
'MirrorPad',
'GatherNd',
'TensorScatterUpdate',
'ScatterUpdate',
'ScatterNdUpdate',
'Floor',

View File

@ -2187,6 +2187,47 @@ class GatherNd(PrimitiveWithInfer):
return x_dtype
class TensorScatterUpdate(PrimitiveWithInfer):
"""
Update tensor value by using input indices and value.
Using given values to update tensor value, along with the input indices.
Inputs:
- **input_x** (Tensor) - The target tensor.
- **indices** (Tensor) - The index of input tensor whose data type is int32.
- **update** (Tensor) - The tensor to update the input tensor, has the same type as input,
and update.shape = indices.shape + input_x.shape[1:].
Outputs:
Tensor, has the same shape and type as `input_x`.
Examples:
>>> input_x = Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mindspore.float32)
>>> indices = Tensor(np.array([[0, 0], [1, 1]]), mindspore.int32)
>>> update = Tensor(np.array([1.0, 2.2]), mindspore.float32)
>>> op = P.TensorScatterUpdate()
>>> output = op(input_x, indices, update)
"""
@prim_attr_register
def __init__(self):
"""Init TensorScatterUpdate"""
self.init_prim_io_names(inputs=['x', 'indices', 'value'], outputs=['y'])
def infer_shape(self, x_shape, indices_shape, value_shape):
validator.check('the dimension of x', len(x_shape),
'the dimension of indices', indices_shape[-1], Rel.GE)
if indices_shape[:-1] + x_shape[indices_shape[-1]:] != value_shape:
raise ValueError("For 'TensorScatterUpdate', input value are not match with input indices.")
return x_shape
def infer_dtype(self, x_dtype, indices_dtype, value_dtype):
validator.check_tensor_type_same({'indices': indices_dtype}, [mstype.int32], self.name)
args = {"x": x_dtype, "value": value_dtype}
validator.check_tensor_type_same(args, (mstype.bool_,) + mstype.number_type, self.name)
return x_dtype
class ScatterUpdate(PrimitiveWithInfer):
"""
Update tensor value by using input indices and value.
@ -2227,7 +2268,7 @@ class ScatterUpdate(PrimitiveWithInfer):
def infer_shape(self, x_shape, indices_shape, value_shape):
if indices_shape + x_shape[1:] != value_shape:
raise ValueError('Input value are not match with input indices.')
raise ValueError("For 'ScatterUpdate', input value are not match with input indices.")
return x_shape
def infer_dtype(self, x_dtype, indices_dtype, value_dtype):
@ -2277,7 +2318,7 @@ class ScatterNdUpdate(PrimitiveWithInfer):
validator.check('the dimension of x', len(x_shape),
'the dimension of indices', indices_shape[-1], Rel.GE)
if indices_shape[:-1] + x_shape[indices_shape[-1]:] != value_shape:
raise ValueError('Input value are not match with input indices.')
raise ValueError("For 'ScatterNdUpdate', input value are not match with input indices.")
return x_shape
def infer_dtype(self, x_dtype, indices_dtype, value_dtype):

View File

@ -34,6 +34,25 @@ from ....mindspore_test_framework.pipeline.gradient.compile_gradient \
import pipeline_for_compile_grad_ge_graph_for_case_by_case_config
def test_tensor_scatter_update():
class TensorScatterUpdateNet(nn.Cell):
"""TensorScatterUpdate net definition"""
def __init__(self):
super(TensorScatterUpdateNet, self).__init__()
self.tensor_scatter_update = P.TensorScatterUpdate()
def construct(self, x, i, u):
out = self.tensor_scatter_update(x, i, u)
return out
net = TensorScatterUpdateNet()
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
x = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)), mstype.float32)
indices = Tensor(np.array([[0, 0], [1, 1]], np.int32))
updates = Tensor(np.ones([2, 5], np.float32))
net(x, indices, updates)
class InputBackward(nn.Cell):
def __init__(self, network):
super(InputBackward, self).__init__()
@ -1460,6 +1479,12 @@ test_case_other_ops = [
'desc_inputs': (Tensor(np.ones((2, 2), np.int32)),
Tensor(np.ones((2,), np.int32))),
'desc_bprop': [([3, 3], {'dtype': np.int32})]}),
('TensorScatterUpdate', {
'block': P.TensorScatterUpdate(),
'desc_inputs': (Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)), mstype.float32),
Tensor(np.array([[0, 1], [1, 2]], np.int32)),
Tensor(np.ones([2, 5], np.float32) * 99)),
'desc_bprop': [([3, 4, 5], {'dtype': np.float32})]}),
('ScatterMax', {
'block': ScatterMax(),
'desc_inputs': (Tensor(np.array([[0, 0], [1, 1]], np.int32)),