forked from mindspore-Ecosystem/mindspore
develop op ScatterMax and dock ge process
This commit is contained in:
parent a5572f1517
commit ac86996746
@@ -102,6 +102,7 @@ const char kNameReLU6Grad[] = "ReLU6Grad";
 const char kNameElu[] = "Elu";
 const char kNameEluGrad[] = "EluGrad";
 const char kNameScatterNdUpdate[] = "ScatterNdUpdate";
+const char kNameScatterMax[] = "ScatterMax";
 const char kNameNMSWithMask[] = "NMSWithMask";
 const char kNameCheckValid[] = "CheckValid";
 const char kNameSmoothL1Loss[] = "SmoothL1Loss";
@@ -253,6 +254,7 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_map()
     {string(kNameZerosLike), ADPT_DESC(ZerosLike)},
     {string(kNameOnesLike), ADPT_DESC(OnesLike)},
     {string(kNameScatterNdUpdate), ADPT_DESC(ScatterNdUpdate)},
+    {string(kNameScatterMax), ADPT_DESC(ScatterMax)},
     {string(kNameNMSWithMask), ADPT_DESC(NMSWithMask)},
     {string(kNameCheckValid), ADPT_DESC(CheckValid)},
     {string(kNameSmoothL1Loss), ADPT_DESC(SmoothL1Loss)},
@@ -530,6 +530,11 @@ INPUT_MAP(ScatterNdUpdate) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(indices)}, {3, INPUT_DESC(updates)}};
 ATTR_MAP(ScatterNdUpdate) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
 OUTPUT_MAP(ScatterNdUpdate) = {{0, OUTPUT_DESC(var)}};
+
+// ScatterMax
+INPUT_MAP(ScatterMax) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(indices)}, {3, INPUT_DESC(updates)}};
+ATTR_MAP(ScatterMax) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
+OUTPUT_MAP(ScatterMax) = {{0, OUTPUT_DESC(var)}};

 // CheckValid
 INPUT_MAP(CheckValid) = {{1, INPUT_DESC(bbox_tensor)}, {2, INPUT_DESC(img_metas)}};
 ATTR_MAP(CheckValid) = EMPTY_ATTR_MAP;
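The three macros above make up the GE adapter for an op: INPUT_MAP binds each MindSpore argument position to a named GE operator input, ATTR_MAP forwards primitive attributes together with a type trait, and OUTPUT_MAP names the GE outputs. As rough orientation only, the tables the converter consults end up recording something like the following (a hypothetical Python rendering, not a real MindSpore structure):

scatter_max_adapter = {
    "inputs": {1: "var", 2: "indices", 3: "updates"},  # MindSpore arg position -> GE input name
    "attrs": {"use_locking": bool},                    # primitive attr -> GE attr, with type trait
    "outputs": {0: "var"},                             # GE output index -> GE output name
}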
@@ -136,6 +136,8 @@ DECLARE_OP_ADAPTER(OnesLike)
 DECLARE_OP_USE_OUTPUT(OnesLike)
 DECLARE_OP_ADAPTER(ScatterNdUpdate)
 DECLARE_OP_USE_OUTPUT(ScatterNdUpdate)
+DECLARE_OP_ADAPTER(ScatterMax)
+DECLARE_OP_USE_OUTPUT(ScatterMax)
 DECLARE_OP_ADAPTER(NMSWithMask)
 DECLARE_OP_USE_OUTPUT(NMSWithMask)
 DECLARE_OP_ADAPTER(Unpack)
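Together with the kName constant, the ADPT_DESC registration, and the adapter maps, these declarations complete the GE docking for ScatterMax. For readers new to the op, here is a minimal NumPy sketch of the intended semantics (a reference sketch, not code from this commit; it assumes duplicate indices reduce with max, which is consistent with the docstring example below):

import numpy as np

def scatter_max_reference(var, indices, updates):
    # Elementwise: out[indices[i]] = max(out[indices[i]], updates[i]),
    # where i ranges over every position in `indices`.
    out = var.copy()
    flat_idx = indices.reshape(-1)
    flat_upd = updates.reshape((-1,) + var.shape[1:])
    for pos, row in enumerate(flat_idx):
        out[row] = np.maximum(out[row], flat_upd[pos])
    return out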
@@ -24,7 +24,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
                         Fill, GatherNd, GatherV2, InvertPermutation,
                         IsInstance, IsSubClass, ArgMaxWithValue, OnesLike, ZerosLike,
                         Rank, Reshape, ResizeNearestNeighbor, ArgMinWithValue,
-                        SameTypeShape,
+                        SameTypeShape, ScatterMax,
                         ScalarToArray, ScalarToTensor, ScatterNd, ScatterNdUpdate, Select,
                         Shape, Size, Slice, Split,
                         Squeeze, StridedSlice, Tile,
@@ -184,6 +184,7 @@ __all__ = [
     'BoundingBoxDecode',
     'L2Normalize',
     'ScatterNd',
+    'ScatterMax',
     'ResizeNearestNeighbor',
     'Pad',
     'MirrorPad',
@@ -1953,7 +1953,7 @@ class ScatterNdUpdate(PrimitiveWithInfer):
     Using given values to update tensor value, along with the input indices.

     Args:
-        use_locking (bool): Whether protect the assignment by a lock. Defaule: True.
+        use_locking (bool): Whether to protect the assignment by a lock. Default: True.

     Inputs:
         - **input_x** (Tensor) - The target tensor.
@@ -1995,6 +1995,53 @@ class ScatterNdUpdate(PrimitiveWithInfer):
         return x_dtype


+class ScatterMax(PrimitiveWithInfer):
+    """
+    Updates the value of the input tensor through the max operation.
+
+    Using given values to update tensor value through the max operation, along with the input indices.
+
+    Args:
+        use_locking (bool): Whether to protect the assignment by a lock. Default: True.
+
+    Inputs:
+        - **input_x** (Tensor) - The target tensor.
+        - **indices** (Tensor) - The indices to do the max operation, whose data type must be int.
+        - **updates** (Tensor) - The tensor to do the max operation with `input_x`. Its data type
+          is the same as `input_x`, and its shape is `indices_shape + x_shape[1:]`.
+
+    Outputs:
+        Tensor, has the same shape and data type as `input_x`.
+
+    Examples:
+        >>> input_x = Tensor(np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), mindspore.float32)
+        >>> indices = Tensor(np.array([[0, 0], [1, 1]]), mindspore.int32)
+        >>> update = Tensor(np.ones([2, 2, 3]) * 88, mindspore.float32)
+        >>> scatter_max = P.ScatterMax()
+        >>> output = scatter_max(input_x, indices, update)
+        [[88.0, 88.0, 88.0], [88.0, 88.0, 88.0]]
+    """
+
+    @prim_attr_register
+    def __init__(self, use_locking=True):
+        """Init ScatterMax"""
+        self.init_prim_io_names(inputs=['x', 'indices', 'updates'], outputs=['y'])
+        validator.check_value_type('use_locking', use_locking, (bool,), self.name)
+
+    def infer_shape(self, x_shape, indices_shape, updates_shape):
+        if updates_shape and updates_shape != indices_shape + x_shape[1:]:
+            raise ValueError(f"For '{self.name}', the shape of updates should be [] or "
+                             f"updates_shape = indices_shape + x_shape[1:], but got x_shape: {x_shape}, "
+                             f"indices_shape: {indices_shape}, updates_shape: {updates_shape}.")
+        return x_shape
+
+    def infer_dtype(self, x_dtype, indices_dtype, updates_dtype):
+        validator.check_tensor_type_same({'indices': indices_dtype}, mstype.int_type, self.name)
+        args = {"x": x_dtype, "updates": updates_dtype}
+        validator.check_tensor_type_same(args, mstype.number_type, self.name)
+        return x_dtype
+
+
 class SpaceToDepth(PrimitiveWithInfer):
     r"""
     Rearrange blocks of spatial data into depth.
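Pulling the pieces together, a self-contained usage sketch of the new primitive (assembled from the docstring example and the test net further down; running it needs a backend that provides the ScatterMax kernel, the target lives in a Parameter so the in-place update is observable, and `ScatterMaxNet` is a name chosen here for illustration):

import numpy as np
from mindspore import Tensor, Parameter, context, nn
from mindspore.ops import operations as P

class ScatterMaxNet(nn.Cell):
    def __init__(self):
        super(ScatterMaxNet, self).__init__()
        self.scatter_max = P.ScatterMax()
        # Target tensor held as a Parameter; ScatterMax updates it in place.
        self.var = Parameter(Tensor(np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], np.float32)), name="var")

    def construct(self, indices, updates):
        return self.scatter_max(self.var, indices, updates)

context.set_context(mode=context.GRAPH_MODE)
indices = Tensor(np.array([[0, 0], [1, 1]], np.int32))
updates = Tensor(np.ones([2, 2, 3], np.float32) * 88)
print(ScatterMaxNet()(indices, updates))  # expected: [[88. 88. 88.] [88. 88. 88.]]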
@@ -15,7 +15,7 @@
 """ test ops """
 import functools
 import numpy as np
-from mindspore import ops
+from mindspore import ops, Parameter, context
 from mindspore.ops import functional as F
 from mindspore.ops import operations as P
 from mindspore.ops.operations import _grad_ops as G
@@ -150,7 +150,7 @@ class CumSumNet(nn.Cell):


 class SummaryNet(nn.Cell):
-    def __init__(self,):
+    def __init__(self):
         super(SummaryNet, self).__init__()
         self.s = P.ScalarSummary()
         self.add = P.TensorAdd()
@@ -161,7 +161,7 @@ class SummaryNet(nn.Cell):


 class HistogramSummaryNet(nn.Cell):
-    def __init__(self,):
+    def __init__(self):
         super(HistogramSummaryNet, self).__init__()
         self.summary = P.HistogramSummary()
         self.add = P.TensorAdd()
@@ -173,6 +173,19 @@ class HistogramSummaryNet(nn.Cell):
         return out


+class ScatterMax(nn.Cell):
+    """ScatterMax net definition"""
+
+    def __init__(self):
+        super(ScatterMax, self).__init__()
+        self.scatter_max = P.ScatterMax()
+        self.ref = Parameter(Tensor(np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], np.float32)), name="ref")
+
+    def construct(self, indices, updates):
+        out = self.scatter_max(self.ref, indices, updates)
+        return out
+
+
 test_case_math_ops = [
     ('Neg', {
         'block': P.Neg(),
@@ -833,7 +846,8 @@ test_case_nn_ops = [
         'block': CumSumNet(),
         'desc_const': [0],
         'desc_inputs': [Tensor(np.array([[3, 4, 6, 10], [1, 6, 7, 9], [4, 3, 8, 7], [1, 3, 7, 9]]).astype(np.float16))],
-        'desc_bprop': [Tensor(np.array([[3, 4, 6, 10],[1, 6, 7, 9],[4, 3, 8, 7],[1, 3, 7, 9]]).astype(np.float16))]}),
+        'desc_bprop': [
+            Tensor(np.array([[3, 4, 6, 10], [1, 6, 7, 9], [4, 3, 8, 7], [1, 3, 7, 9]]).astype(np.float16))]}),
     ('OneHot', {
         'block': P.OneHot(),
         'desc_const': [3, Tensor(1.0, mstype.float32), Tensor(0.0, mstype.float32)],
@@ -1117,6 +1131,11 @@ test_case_other_ops = [
         'desc_inputs': (Tensor(np.ones((2, 2), np.int32)),
                         Tensor(np.ones((2,), np.int32))),
         'desc_bprop': [([3, 3], {'dtype': np.int32})]}),
+    ('ScatterMax', {
+        'block': ScatterMax(),
+        'desc_inputs': (Tensor(np.array([[0, 0], [1, 1]], np.int32)),
+                        Tensor(np.ones([2, 2, 3], np.float32) * 99)),
+        'skip': ['backward']}),
     ('SmoothL1Loss', {
         'block': P.SmoothL1Loss(),
         'desc_inputs': [[256, 4], [256, 4]],
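The ScatterMax entry's inputs satisfy the shape contract enforced by infer_shape; checked by hand (plain Python, no graph build, shapes taken from the net's ref Parameter and the desc_inputs above):

x_shape = [2, 3]           # self.ref in the ScatterMax net
indices_shape = [2, 2]     # first desc_input
updates_shape = [2, 2, 3]  # second desc_input
# infer_shape requires updates_shape == indices_shape + x_shape[1:]
assert updates_shape == indices_shape + x_shape[1:]  # [2, 2] + [3] -> [2, 2, 3]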
@@ -1165,12 +1184,10 @@ test_backward_exec_case = filter(lambda x: 'skip' not in x[1] or
                                  'backward' not in x[1]['skip'], test_case)


-import mindspore.context as context
-
 @non_graph_engine
 @mindspore_test(pipeline_for_compile_forward_ge_graph_for_case_by_case_config)
 def test_exec():
-    context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
+    context.set_context(mode=context.GRAPH_MODE)
     return test_exec_case