forked from mindspore-Ecosystem/mindspore
!1973 develop TensorScatterUpdate op and access ge and vm
Merge pull request !1973 from zhangbuxue/develop_TensorScatterUpdate_op_and_access_ge_and_vm
This commit is contained in:
commit
b285c8fa9d
|
@ -103,6 +103,7 @@ const char kNameReLU6[] = "ReLU6";
|
||||||
const char kNameReLU6Grad[] = "ReLU6Grad";
|
const char kNameReLU6Grad[] = "ReLU6Grad";
|
||||||
const char kNameElu[] = "Elu";
|
const char kNameElu[] = "Elu";
|
||||||
const char kNameEluGrad[] = "EluGrad";
|
const char kNameEluGrad[] = "EluGrad";
|
||||||
|
const char kNameTensorScatterUpdate[] = "TensorScatterUpdate";
|
||||||
const char kNameScatterUpdate[] = "ScatterUpdate";
|
const char kNameScatterUpdate[] = "ScatterUpdate";
|
||||||
const char kNameScatterNdUpdate[] = "ScatterNdUpdate";
|
const char kNameScatterNdUpdate[] = "ScatterNdUpdate";
|
||||||
const char kNameScatterMax[] = "ScatterMax";
|
const char kNameScatterMax[] = "ScatterMax";
|
||||||
|
@ -261,6 +262,7 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma
|
||||||
{string(kNameResizeBilinear), ADPT_DESC(ResizeBilinearV2D)},
|
{string(kNameResizeBilinear), ADPT_DESC(ResizeBilinearV2D)},
|
||||||
{string(kNameZerosLike), ADPT_DESC(ZerosLike)},
|
{string(kNameZerosLike), ADPT_DESC(ZerosLike)},
|
||||||
{string(kNameOnesLike), ADPT_DESC(OnesLike)},
|
{string(kNameOnesLike), ADPT_DESC(OnesLike)},
|
||||||
|
{string(kNameTensorScatterUpdate), ADPT_DESC(TensorScatterUpdate)},
|
||||||
{string(kNameScatterUpdate), ADPT_DESC(ScatterUpdate)},
|
{string(kNameScatterUpdate), ADPT_DESC(ScatterUpdate)},
|
||||||
{string(kNameScatterNdUpdate), ADPT_DESC(ScatterNdUpdate)},
|
{string(kNameScatterNdUpdate), ADPT_DESC(ScatterNdUpdate)},
|
||||||
{string(kNameScatterMax), ADPT_DESC(ScatterMax)},
|
{string(kNameScatterMax), ADPT_DESC(ScatterMax)},
|
||||||
|
|
|
@ -525,6 +525,11 @@ INPUT_MAP(Unpack) = {{1, INPUT_DESC(x)}};
|
||||||
ATTR_MAP(Unpack) = {{"axis", ATTR_DESC(axis, AnyTraits<int>())}, {"num", ATTR_DESC(num, AnyTraits<int>())}};
|
ATTR_MAP(Unpack) = {{"axis", ATTR_DESC(axis, AnyTraits<int>())}, {"num", ATTR_DESC(num, AnyTraits<int>())}};
|
||||||
DYN_OUTPUT_MAP(Unpack) = {{0, DYN_OUTPUT_DESC(y)}};
|
DYN_OUTPUT_MAP(Unpack) = {{0, DYN_OUTPUT_DESC(y)}};
|
||||||
|
|
||||||
|
// TensorScatterUpdate
|
||||||
|
INPUT_MAP(TensorScatterUpdate) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(indices)}, {3, INPUT_DESC(updates)}};
|
||||||
|
ATTR_MAP(TensorScatterUpdate) = EMPTY_ATTR_MAP;
|
||||||
|
OUTPUT_MAP(TensorScatterUpdate) = {{0, OUTPUT_DESC(y)}};
|
||||||
|
|
||||||
// ScatterUpdate
|
// ScatterUpdate
|
||||||
INPUT_MAP(ScatterUpdate) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(indices)}, {3, INPUT_DESC(updates)}};
|
INPUT_MAP(ScatterUpdate) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(indices)}, {3, INPUT_DESC(updates)}};
|
||||||
ATTR_MAP(ScatterUpdate) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
|
ATTR_MAP(ScatterUpdate) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
|
||||||
|
|
|
@ -134,6 +134,8 @@ DECLARE_OP_ADAPTER(ZerosLike)
|
||||||
DECLARE_OP_USE_OUTPUT(ZerosLike)
|
DECLARE_OP_USE_OUTPUT(ZerosLike)
|
||||||
DECLARE_OP_ADAPTER(OnesLike)
|
DECLARE_OP_ADAPTER(OnesLike)
|
||||||
DECLARE_OP_USE_OUTPUT(OnesLike)
|
DECLARE_OP_USE_OUTPUT(OnesLike)
|
||||||
|
DECLARE_OP_ADAPTER(TensorScatterUpdate)
|
||||||
|
DECLARE_OP_USE_OUTPUT(TensorScatterUpdate)
|
||||||
DECLARE_OP_ADAPTER(ScatterUpdate)
|
DECLARE_OP_ADAPTER(ScatterUpdate)
|
||||||
DECLARE_OP_USE_OUTPUT(ScatterUpdate)
|
DECLARE_OP_USE_OUTPUT(ScatterUpdate)
|
||||||
DECLARE_OP_ADAPTER(ScatterNdUpdate)
|
DECLARE_OP_ADAPTER(ScatterNdUpdate)
|
||||||
|
|
|
@ -456,6 +456,20 @@ def get_bprop_scatter_nd_update(self):
|
||||||
return bprop
|
return bprop
|
||||||
|
|
||||||
|
|
||||||
|
@bprop_getters.register(P.TensorScatterUpdate)
|
||||||
|
def get_bprop_tensor_scatter_update(self):
|
||||||
|
"""Generate bprop for TensorScatterUpdate"""
|
||||||
|
gather_nd = P.GatherNd()
|
||||||
|
tensor_scatter_update = P.TensorScatterUpdate()
|
||||||
|
|
||||||
|
def bprop(x, indices, update, out, dout):
|
||||||
|
x_grad = tensor_scatter_update(dout, indices, zeros_like(update))
|
||||||
|
update_grad = gather_nd(dout, indices)
|
||||||
|
return x_grad, zeros_like(indices), update_grad
|
||||||
|
|
||||||
|
return bprop
|
||||||
|
|
||||||
|
|
||||||
@bprop_getters.register(P.Argmax)
|
@bprop_getters.register(P.Argmax)
|
||||||
def get_bprop_argmax(self):
|
def get_bprop_argmax(self):
|
||||||
"""Generate bprop for Argmax"""
|
"""Generate bprop for Argmax"""
|
||||||
|
|
|
@ -255,3 +255,4 @@ from .lamb_next_right import _lamb_next_right_tbe
|
||||||
from .sparse_gather_v2 import _sparse_gather_v2_tbe
|
from .sparse_gather_v2 import _sparse_gather_v2_tbe
|
||||||
from .data_format_dim_map import _data_format_dim_map_tbe
|
from .data_format_dim_map import _data_format_dim_map_tbe
|
||||||
from .histogram_fixed_width import _histogram_fixed_width_tbe
|
from .histogram_fixed_width import _histogram_fixed_width_tbe
|
||||||
|
from .tensor_scatter_update import _tensor_scatter_update_tbe
|
||||||
|
|
|
@ -31,7 +31,7 @@ scatter_nd_update_op_info = TBERegOp("ScatterNdUpdate") \
|
||||||
.dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default, DataType.F16_Default) \
|
.dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||||
.dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default, DataType.F32_Default) \
|
.dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||||
.dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I8_Default, DataType.I8_Default) \
|
.dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I8_Default, DataType.I8_Default) \
|
||||||
.dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default, DataType.U8_Default,) \
|
.dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default, DataType.U8_Default) \
|
||||||
.dtype_format(DataType.BOOL_Default, DataType.I32_Default, DataType.BOOL_Default, DataType.BOOL_Default) \
|
.dtype_format(DataType.BOOL_Default, DataType.I32_Default, DataType.BOOL_Default, DataType.BOOL_Default) \
|
||||||
.get_op_info()
|
.get_op_info()
|
||||||
|
|
||||||
|
|
|
@ -31,7 +31,7 @@ scatter_update_op_info = TBERegOp("ScatterUpdate") \
|
||||||
.dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default, DataType.F16_Default) \
|
.dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||||
.dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default, DataType.F32_Default) \
|
.dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||||
.dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I8_Default, DataType.I8_Default) \
|
.dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I8_Default, DataType.I8_Default) \
|
||||||
.dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default, DataType.U8_Default,) \
|
.dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default, DataType.U8_Default) \
|
||||||
.dtype_format(DataType.BOOL_Default, DataType.I32_Default, DataType.BOOL_Default, DataType.BOOL_Default) \
|
.dtype_format(DataType.BOOL_Default, DataType.I32_Default, DataType.BOOL_Default, DataType.BOOL_Default) \
|
||||||
.get_op_info()
|
.get_op_info()
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,41 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""TensorScatterUpdate op"""
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||||
|
|
||||||
|
tensor_scatter_update_op_info = TBERegOp("TensorScatterUpdate") \
|
||||||
|
.fusion_type("ELEMWISE") \
|
||||||
|
.async_flag(False) \
|
||||||
|
.binfile_name("tensor_scatter_update.so") \
|
||||||
|
.compute_cost(10) \
|
||||||
|
.kernel_name("tensor_scatter_update") \
|
||||||
|
.partial_flag(True) \
|
||||||
|
.input(0, "x", False, "required", "all") \
|
||||||
|
.input(1, "indices", False, "required", "all") \
|
||||||
|
.input(1, "updates", False, "required", "all") \
|
||||||
|
.output(0, "y", False, "required", "all") \
|
||||||
|
.dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||||
|
.dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I8_Default, DataType.I8_Default) \
|
||||||
|
.dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default, DataType.U8_Default) \
|
||||||
|
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
|
||||||
|
@op_info_register(tensor_scatter_update_op_info)
|
||||||
|
def _tensor_scatter_update_tbe():
|
||||||
|
"""TensorScatterUpdate TBE register"""
|
||||||
|
return
|
|
@ -27,7 +27,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
|
||||||
SameTypeShape, ScatterAdd, ScatterMax, ScatterUpdate,
|
SameTypeShape, ScatterAdd, ScatterMax, ScatterUpdate,
|
||||||
ScalarToArray, ScalarToTensor, ScatterNd, ScatterNdUpdate, Select,
|
ScalarToArray, ScalarToTensor, ScatterNd, ScatterNdUpdate, Select,
|
||||||
Shape, Size, Slice, Split, EmbeddingLookup,
|
Shape, Size, Slice, Split, EmbeddingLookup,
|
||||||
Squeeze, StridedSlice, Tile,
|
Squeeze, StridedSlice, Tile, TensorScatterUpdate,
|
||||||
Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin,
|
Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin,
|
||||||
UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch, BatchToSpace,
|
UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch, BatchToSpace,
|
||||||
SpaceToBatchND, BatchToSpaceND, BroadcastTo)
|
SpaceToBatchND, BatchToSpaceND, BroadcastTo)
|
||||||
|
@ -212,6 +212,7 @@ __all__ = [
|
||||||
'Pad',
|
'Pad',
|
||||||
'MirrorPad',
|
'MirrorPad',
|
||||||
'GatherNd',
|
'GatherNd',
|
||||||
|
'TensorScatterUpdate',
|
||||||
'ScatterUpdate',
|
'ScatterUpdate',
|
||||||
'ScatterNdUpdate',
|
'ScatterNdUpdate',
|
||||||
'Floor',
|
'Floor',
|
||||||
|
|
|
@ -2187,6 +2187,47 @@ class GatherNd(PrimitiveWithInfer):
|
||||||
return x_dtype
|
return x_dtype
|
||||||
|
|
||||||
|
|
||||||
|
class TensorScatterUpdate(PrimitiveWithInfer):
|
||||||
|
"""
|
||||||
|
Update tensor value by using input indices and value.
|
||||||
|
|
||||||
|
Using given values to update tensor value, along with the input indices.
|
||||||
|
|
||||||
|
Inputs:
|
||||||
|
- **input_x** (Tensor) - The target tensor.
|
||||||
|
- **indices** (Tensor) - The index of input tensor whose data type is int32.
|
||||||
|
- **update** (Tensor) - The tensor to update the input tensor, has the same type as input,
|
||||||
|
and update.shape = indices.shape + input_x.shape[1:].
|
||||||
|
|
||||||
|
Outputs:
|
||||||
|
Tensor, has the same shape and type as `input_x`.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> input_x = Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mindspore.float32)
|
||||||
|
>>> indices = Tensor(np.array([[0, 0], [1, 1]]), mindspore.int32)
|
||||||
|
>>> update = Tensor(np.array([1.0, 2.2]), mindspore.float32)
|
||||||
|
>>> op = P.TensorScatterUpdate()
|
||||||
|
>>> output = op(input_x, indices, update)
|
||||||
|
"""
|
||||||
|
@prim_attr_register
|
||||||
|
def __init__(self):
|
||||||
|
"""Init TensorScatterUpdate"""
|
||||||
|
self.init_prim_io_names(inputs=['x', 'indices', 'value'], outputs=['y'])
|
||||||
|
|
||||||
|
def infer_shape(self, x_shape, indices_shape, value_shape):
|
||||||
|
validator.check('the dimension of x', len(x_shape),
|
||||||
|
'the dimension of indices', indices_shape[-1], Rel.GE)
|
||||||
|
if indices_shape[:-1] + x_shape[indices_shape[-1]:] != value_shape:
|
||||||
|
raise ValueError("For 'TensorScatterUpdate', input value are not match with input indices.")
|
||||||
|
return x_shape
|
||||||
|
|
||||||
|
def infer_dtype(self, x_dtype, indices_dtype, value_dtype):
|
||||||
|
validator.check_tensor_type_same({'indices': indices_dtype}, [mstype.int32], self.name)
|
||||||
|
args = {"x": x_dtype, "value": value_dtype}
|
||||||
|
validator.check_tensor_type_same(args, (mstype.bool_,) + mstype.number_type, self.name)
|
||||||
|
return x_dtype
|
||||||
|
|
||||||
|
|
||||||
class ScatterUpdate(PrimitiveWithInfer):
|
class ScatterUpdate(PrimitiveWithInfer):
|
||||||
"""
|
"""
|
||||||
Update tensor value by using input indices and value.
|
Update tensor value by using input indices and value.
|
||||||
|
@ -2227,7 +2268,7 @@ class ScatterUpdate(PrimitiveWithInfer):
|
||||||
|
|
||||||
def infer_shape(self, x_shape, indices_shape, value_shape):
|
def infer_shape(self, x_shape, indices_shape, value_shape):
|
||||||
if indices_shape + x_shape[1:] != value_shape:
|
if indices_shape + x_shape[1:] != value_shape:
|
||||||
raise ValueError('Input value are not match with input indices.')
|
raise ValueError("For 'ScatterUpdate', input value are not match with input indices.")
|
||||||
return x_shape
|
return x_shape
|
||||||
|
|
||||||
def infer_dtype(self, x_dtype, indices_dtype, value_dtype):
|
def infer_dtype(self, x_dtype, indices_dtype, value_dtype):
|
||||||
|
@ -2277,7 +2318,7 @@ class ScatterNdUpdate(PrimitiveWithInfer):
|
||||||
validator.check('the dimension of x', len(x_shape),
|
validator.check('the dimension of x', len(x_shape),
|
||||||
'the dimension of indices', indices_shape[-1], Rel.GE)
|
'the dimension of indices', indices_shape[-1], Rel.GE)
|
||||||
if indices_shape[:-1] + x_shape[indices_shape[-1]:] != value_shape:
|
if indices_shape[:-1] + x_shape[indices_shape[-1]:] != value_shape:
|
||||||
raise ValueError('Input value are not match with input indices.')
|
raise ValueError("For 'ScatterNdUpdate', input value are not match with input indices.")
|
||||||
return x_shape
|
return x_shape
|
||||||
|
|
||||||
def infer_dtype(self, x_dtype, indices_dtype, value_dtype):
|
def infer_dtype(self, x_dtype, indices_dtype, value_dtype):
|
||||||
|
|
|
@ -34,6 +34,25 @@ from ....mindspore_test_framework.pipeline.gradient.compile_gradient \
|
||||||
import pipeline_for_compile_grad_ge_graph_for_case_by_case_config
|
import pipeline_for_compile_grad_ge_graph_for_case_by_case_config
|
||||||
|
|
||||||
|
|
||||||
|
def test_tensor_scatter_update():
|
||||||
|
class TensorScatterUpdateNet(nn.Cell):
|
||||||
|
"""TensorScatterUpdate net definition"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super(TensorScatterUpdateNet, self).__init__()
|
||||||
|
self.tensor_scatter_update = P.TensorScatterUpdate()
|
||||||
|
|
||||||
|
def construct(self, x, i, u):
|
||||||
|
out = self.tensor_scatter_update(x, i, u)
|
||||||
|
return out
|
||||||
|
net = TensorScatterUpdateNet()
|
||||||
|
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
|
||||||
|
x = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)), mstype.float32)
|
||||||
|
indices = Tensor(np.array([[0, 0], [1, 1]], np.int32))
|
||||||
|
updates = Tensor(np.ones([2, 5], np.float32))
|
||||||
|
net(x, indices, updates)
|
||||||
|
|
||||||
|
|
||||||
class InputBackward(nn.Cell):
|
class InputBackward(nn.Cell):
|
||||||
def __init__(self, network):
|
def __init__(self, network):
|
||||||
super(InputBackward, self).__init__()
|
super(InputBackward, self).__init__()
|
||||||
|
@ -1537,6 +1556,12 @@ test_case_other_ops = [
|
||||||
'desc_inputs': (Tensor(np.ones((2, 2), np.int32)),
|
'desc_inputs': (Tensor(np.ones((2, 2), np.int32)),
|
||||||
Tensor(np.ones((2,), np.int32))),
|
Tensor(np.ones((2,), np.int32))),
|
||||||
'desc_bprop': [([3, 3], {'dtype': np.int32})]}),
|
'desc_bprop': [([3, 3], {'dtype': np.int32})]}),
|
||||||
|
('TensorScatterUpdate', {
|
||||||
|
'block': P.TensorScatterUpdate(),
|
||||||
|
'desc_inputs': (Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)), mstype.float32),
|
||||||
|
Tensor(np.array([[0, 1], [1, 2]], np.int32)),
|
||||||
|
Tensor(np.ones([2, 5], np.float32) * 99)),
|
||||||
|
'desc_bprop': [([3, 4, 5], {'dtype': np.float32})]}),
|
||||||
('ScatterMax', {
|
('ScatterMax', {
|
||||||
'block': ScatterMax(),
|
'block': ScatterMax(),
|
||||||
'desc_inputs': (Tensor(np.array([[0, 0], [1, 1]], np.int32)),
|
'desc_inputs': (Tensor(np.array([[0, 0], [1, 1]], np.int32)),
|
||||||
|
|
Loading…
Reference in New Issue