forked from mindspore-Ecosystem/mindspore
develop TensorScatterUpdate op and access ge and vm
This commit is contained in:
parent
ac5878b4b7
commit
8aee6b07d6
|
@ -103,6 +103,7 @@ const char kNameReLU6[] = "ReLU6";
|
|||
const char kNameReLU6Grad[] = "ReLU6Grad";
|
||||
const char kNameElu[] = "Elu";
|
||||
const char kNameEluGrad[] = "EluGrad";
|
||||
const char kNameTensorScatterUpdate[] = "TensorScatterUpdate";
|
||||
const char kNameScatterUpdate[] = "ScatterUpdate";
|
||||
const char kNameScatterNdUpdate[] = "ScatterNdUpdate";
|
||||
const char kNameScatterMax[] = "ScatterMax";
|
||||
|
@ -261,6 +262,7 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma
|
|||
{string(kNameResizeBilinear), ADPT_DESC(ResizeBilinearV2D)},
|
||||
{string(kNameZerosLike), ADPT_DESC(ZerosLike)},
|
||||
{string(kNameOnesLike), ADPT_DESC(OnesLike)},
|
||||
{string(kNameTensorScatterUpdate), ADPT_DESC(TensorScatterUpdate)},
|
||||
{string(kNameScatterUpdate), ADPT_DESC(ScatterUpdate)},
|
||||
{string(kNameScatterNdUpdate), ADPT_DESC(ScatterNdUpdate)},
|
||||
{string(kNameScatterMax), ADPT_DESC(ScatterMax)},
|
||||
|
|
|
@ -525,6 +525,11 @@ INPUT_MAP(Unpack) = {{1, INPUT_DESC(x)}};
|
|||
ATTR_MAP(Unpack) = {{"axis", ATTR_DESC(axis, AnyTraits<int>())}, {"num", ATTR_DESC(num, AnyTraits<int>())}};
|
||||
DYN_OUTPUT_MAP(Unpack) = {{0, DYN_OUTPUT_DESC(y)}};
|
||||
|
||||
// TensorScatterUpdate
|
||||
INPUT_MAP(TensorScatterUpdate) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(indices)}, {3, INPUT_DESC(updates)}};
|
||||
ATTR_MAP(TensorScatterUpdate) = EMPTY_ATTR_MAP;
|
||||
OUTPUT_MAP(TensorScatterUpdate) = {{0, OUTPUT_DESC(y)}};
|
||||
|
||||
// ScatterUpdate
|
||||
INPUT_MAP(ScatterUpdate) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(indices)}, {3, INPUT_DESC(updates)}};
|
||||
ATTR_MAP(ScatterUpdate) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
|
||||
|
|
|
@ -134,6 +134,8 @@ DECLARE_OP_ADAPTER(ZerosLike)
|
|||
DECLARE_OP_USE_OUTPUT(ZerosLike)
|
||||
DECLARE_OP_ADAPTER(OnesLike)
|
||||
DECLARE_OP_USE_OUTPUT(OnesLike)
|
||||
DECLARE_OP_ADAPTER(TensorScatterUpdate)
|
||||
DECLARE_OP_USE_OUTPUT(TensorScatterUpdate)
|
||||
DECLARE_OP_ADAPTER(ScatterUpdate)
|
||||
DECLARE_OP_USE_OUTPUT(ScatterUpdate)
|
||||
DECLARE_OP_ADAPTER(ScatterNdUpdate)
|
||||
|
|
|
@ -456,6 +456,20 @@ def get_bprop_scatter_nd_update(self):
|
|||
return bprop
|
||||
|
||||
|
||||
@bprop_getters.register(P.TensorScatterUpdate)
|
||||
def get_bprop_tensor_scatter_update(self):
|
||||
"""Generate bprop for TensorScatterUpdate"""
|
||||
gather_nd = P.GatherNd()
|
||||
tensor_scatter_update = P.TensorScatterUpdate()
|
||||
|
||||
def bprop(x, indices, update, out, dout):
|
||||
x_grad = tensor_scatter_update(dout, indices, zeros_like(update))
|
||||
update_grad = gather_nd(dout, indices)
|
||||
return x_grad, zeros_like(indices), update_grad
|
||||
|
||||
return bprop
|
||||
|
||||
|
||||
@bprop_getters.register(P.Argmax)
|
||||
def get_bprop_argmax(self):
|
||||
"""Generate bprop for Argmax"""
|
||||
|
|
|
@ -251,3 +251,4 @@ from .lamb_next_right import _lamb_next_right_tbe
|
|||
from .sparse_gather_v2 import _sparse_gather_v2_tbe
|
||||
from .data_format_dim_map import _data_format_dim_map_tbe
|
||||
from .histogram_fixed_width import _histogram_fixed_width_tbe
|
||||
from .tensor_scatter_update import _tensor_scatter_update_tbe
|
||||
|
|
|
@ -31,7 +31,7 @@ scatter_nd_update_op_info = TBERegOp("ScatterNdUpdate") \
|
|||
.dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I8_Default, DataType.I8_Default) \
|
||||
.dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default, DataType.U8_Default,) \
|
||||
.dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default, DataType.U8_Default) \
|
||||
.dtype_format(DataType.BOOL_Default, DataType.I32_Default, DataType.BOOL_Default, DataType.BOOL_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
|
|
@ -31,7 +31,7 @@ scatter_update_op_info = TBERegOp("ScatterUpdate") \
|
|||
.dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I8_Default, DataType.I8_Default) \
|
||||
.dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default, DataType.U8_Default,) \
|
||||
.dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default, DataType.U8_Default) \
|
||||
.dtype_format(DataType.BOOL_Default, DataType.I32_Default, DataType.BOOL_Default, DataType.BOOL_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
|
|
@ -0,0 +1,41 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""TensorScatterUpdate op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
tensor_scatter_update_op_info = TBERegOp("TensorScatterUpdate") \
|
||||
.fusion_type("ELEMWISE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("tensor_scatter_update.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("tensor_scatter_update") \
|
||||
.partial_flag(True) \
|
||||
.input(0, "x", False, "required", "all") \
|
||||
.input(1, "indices", False, "required", "all") \
|
||||
.input(1, "updates", False, "required", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I8_Default, DataType.I8_Default) \
|
||||
.dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default, DataType.U8_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(tensor_scatter_update_op_info)
|
||||
def _tensor_scatter_update_tbe():
|
||||
"""TensorScatterUpdate TBE register"""
|
||||
return
|
|
@ -27,7 +27,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
|
|||
SameTypeShape, ScatterAdd, ScatterMax, ScatterUpdate,
|
||||
ScalarToArray, ScalarToTensor, ScatterNd, ScatterNdUpdate, Select,
|
||||
Shape, Size, Slice, Split, EmbeddingLookup,
|
||||
Squeeze, StridedSlice, Tile,
|
||||
Squeeze, StridedSlice, Tile, TensorScatterUpdate,
|
||||
Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin,
|
||||
UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch, BatchToSpace,
|
||||
SpaceToBatchND, BatchToSpaceND, BroadcastTo)
|
||||
|
@ -211,6 +211,7 @@ __all__ = [
|
|||
'Pad',
|
||||
'MirrorPad',
|
||||
'GatherNd',
|
||||
'TensorScatterUpdate',
|
||||
'ScatterUpdate',
|
||||
'ScatterNdUpdate',
|
||||
'Floor',
|
||||
|
|
|
@ -2187,6 +2187,47 @@ class GatherNd(PrimitiveWithInfer):
|
|||
return x_dtype
|
||||
|
||||
|
||||
class TensorScatterUpdate(PrimitiveWithInfer):
|
||||
"""
|
||||
Update tensor value by using input indices and value.
|
||||
|
||||
Using given values to update tensor value, along with the input indices.
|
||||
|
||||
Inputs:
|
||||
- **input_x** (Tensor) - The target tensor.
|
||||
- **indices** (Tensor) - The index of input tensor whose data type is int32.
|
||||
- **update** (Tensor) - The tensor to update the input tensor, has the same type as input,
|
||||
and update.shape = indices.shape + input_x.shape[1:].
|
||||
|
||||
Outputs:
|
||||
Tensor, has the same shape and type as `input_x`.
|
||||
|
||||
Examples:
|
||||
>>> input_x = Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mindspore.float32)
|
||||
>>> indices = Tensor(np.array([[0, 0], [1, 1]]), mindspore.int32)
|
||||
>>> update = Tensor(np.array([1.0, 2.2]), mindspore.float32)
|
||||
>>> op = P.TensorScatterUpdate()
|
||||
>>> output = op(input_x, indices, update)
|
||||
"""
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
"""Init TensorScatterUpdate"""
|
||||
self.init_prim_io_names(inputs=['x', 'indices', 'value'], outputs=['y'])
|
||||
|
||||
def infer_shape(self, x_shape, indices_shape, value_shape):
|
||||
validator.check('the dimension of x', len(x_shape),
|
||||
'the dimension of indices', indices_shape[-1], Rel.GE)
|
||||
if indices_shape[:-1] + x_shape[indices_shape[-1]:] != value_shape:
|
||||
raise ValueError("For 'TensorScatterUpdate', input value are not match with input indices.")
|
||||
return x_shape
|
||||
|
||||
def infer_dtype(self, x_dtype, indices_dtype, value_dtype):
|
||||
validator.check_tensor_type_same({'indices': indices_dtype}, [mstype.int32], self.name)
|
||||
args = {"x": x_dtype, "value": value_dtype}
|
||||
validator.check_tensor_type_same(args, (mstype.bool_,) + mstype.number_type, self.name)
|
||||
return x_dtype
|
||||
|
||||
|
||||
class ScatterUpdate(PrimitiveWithInfer):
|
||||
"""
|
||||
Update tensor value by using input indices and value.
|
||||
|
@ -2227,7 +2268,7 @@ class ScatterUpdate(PrimitiveWithInfer):
|
|||
|
||||
def infer_shape(self, x_shape, indices_shape, value_shape):
|
||||
if indices_shape + x_shape[1:] != value_shape:
|
||||
raise ValueError('Input value are not match with input indices.')
|
||||
raise ValueError("For 'ScatterUpdate', input value are not match with input indices.")
|
||||
return x_shape
|
||||
|
||||
def infer_dtype(self, x_dtype, indices_dtype, value_dtype):
|
||||
|
@ -2277,7 +2318,7 @@ class ScatterNdUpdate(PrimitiveWithInfer):
|
|||
validator.check('the dimension of x', len(x_shape),
|
||||
'the dimension of indices', indices_shape[-1], Rel.GE)
|
||||
if indices_shape[:-1] + x_shape[indices_shape[-1]:] != value_shape:
|
||||
raise ValueError('Input value are not match with input indices.')
|
||||
raise ValueError("For 'ScatterNdUpdate', input value are not match with input indices.")
|
||||
return x_shape
|
||||
|
||||
def infer_dtype(self, x_dtype, indices_dtype, value_dtype):
|
||||
|
|
|
@ -34,6 +34,25 @@ from ....mindspore_test_framework.pipeline.gradient.compile_gradient \
|
|||
import pipeline_for_compile_grad_ge_graph_for_case_by_case_config
|
||||
|
||||
|
||||
def test_tensor_scatter_update():
|
||||
class TensorScatterUpdateNet(nn.Cell):
|
||||
"""TensorScatterUpdate net definition"""
|
||||
|
||||
def __init__(self):
|
||||
super(TensorScatterUpdateNet, self).__init__()
|
||||
self.tensor_scatter_update = P.TensorScatterUpdate()
|
||||
|
||||
def construct(self, x, i, u):
|
||||
out = self.tensor_scatter_update(x, i, u)
|
||||
return out
|
||||
net = TensorScatterUpdateNet()
|
||||
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
|
||||
x = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)), mstype.float32)
|
||||
indices = Tensor(np.array([[0, 0], [1, 1]], np.int32))
|
||||
updates = Tensor(np.ones([2, 5], np.float32))
|
||||
net(x, indices, updates)
|
||||
|
||||
|
||||
class InputBackward(nn.Cell):
|
||||
def __init__(self, network):
|
||||
super(InputBackward, self).__init__()
|
||||
|
@ -1460,6 +1479,12 @@ test_case_other_ops = [
|
|||
'desc_inputs': (Tensor(np.ones((2, 2), np.int32)),
|
||||
Tensor(np.ones((2,), np.int32))),
|
||||
'desc_bprop': [([3, 3], {'dtype': np.int32})]}),
|
||||
('TensorScatterUpdate', {
|
||||
'block': P.TensorScatterUpdate(),
|
||||
'desc_inputs': (Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)), mstype.float32),
|
||||
Tensor(np.array([[0, 1], [1, 2]], np.int32)),
|
||||
Tensor(np.ones([2, 5], np.float32) * 99)),
|
||||
'desc_bprop': [([3, 4, 5], {'dtype': np.float32})]}),
|
||||
('ScatterMax', {
|
||||
'block': ScatterMax(),
|
||||
'desc_inputs': (Tensor(np.array([[0, 0], [1, 1]], np.int32)),
|
||||
|
|
Loading…
Reference in New Issue