forked from OSSInnovation/mindspore
!2661 add support dtype for scatter_add vm
Merge pull request !2661 from zhaozhenlong/op/scatter-add
This commit is contained in:
commit
6fb5538117
|
@ -31,6 +31,8 @@ scatter_add_op_info = TBERegOp("ScatterAdd") \
|
|||
.dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I8_Default, DataType.I8_Default) \
|
||||
.dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default, DataType.U8_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
|
|
|
@ -220,10 +220,10 @@ class ScatterMax(nn.Cell):
|
|||
class ScatterAdd(nn.Cell):
|
||||
"""ScatterAdd net definition"""
|
||||
|
||||
def __init__(self, ref_shape):
|
||||
def __init__(self, ref_shape, dtype=np.float32):
|
||||
super(ScatterAdd, self).__init__()
|
||||
self.scatter_add = P.ScatterAdd()
|
||||
self.ref = Parameter(Tensor(np.ones(ref_shape, np.float32)), name="ref")
|
||||
self.ref = Parameter(Tensor(np.ones(ref_shape, dtype)), name="ref")
|
||||
|
||||
def construct(self, indices, updates):
|
||||
out = self.scatter_add(self.ref, indices, updates)
|
||||
|
@ -1677,12 +1677,37 @@ test_case_other_ops = [
|
|||
'desc_inputs': (Tensor(np.array([2, 0, 5], np.int32)),
|
||||
Tensor(np.array([2.0, 3.0, 4.0], np.float32))),
|
||||
'skip': ['backward']}),
|
||||
('ScatterAddScalar', {
|
||||
'block': ScatterAdd((6,)),
|
||||
'desc_inputs': (Tensor(np.array([2], np.int32)),
|
||||
Tensor(np.array([2.0], np.float32))),
|
||||
'skip': ['backward']}),
|
||||
('ScatterAdd2d', {
|
||||
'block': ScatterAdd((3, 4)),
|
||||
'desc_inputs': (Tensor(np.array([[0, 1], [1, 2]], np.int32)),
|
||||
Tensor(np.array([[[1, 1, 1, 1], [2, 2, 2, 2]],
|
||||
[[3, 3, 3, 3], [4, 4, 4, 4]]], np.float32))),
|
||||
'skip': ['backward']}),
|
||||
('ScatterAddF16', {
|
||||
'block': ScatterAdd((6,), np.float16),
|
||||
'desc_inputs': (Tensor(np.array([2, 0, 5], np.int32)),
|
||||
Tensor(np.array([2.0, 3.0, 4.0], np.float16))),
|
||||
'skip': ['backward']}),
|
||||
('ScatterAddI8', {
|
||||
'block': ScatterAdd((6,), np.int8),
|
||||
'desc_inputs': (Tensor(np.array([2, 0, 5], np.int32)),
|
||||
Tensor(np.array([2, 3, 4], np.int8))),
|
||||
'skip': ['backward']}),
|
||||
('ScatterAddI32', {
|
||||
'block': ScatterAdd((6,), np.int32),
|
||||
'desc_inputs': (Tensor(np.array([2, 0, 5], np.int32)),
|
||||
Tensor(np.array([2, 3, 4], np.int32))),
|
||||
'skip': ['backward']}),
|
||||
('ScatterAddU8', {
|
||||
'block': ScatterAdd((6,), np.uint8),
|
||||
'desc_inputs': (Tensor(np.array([2, 0, 5], np.int32)),
|
||||
Tensor(np.array([2, 3, 4], np.uint8))),
|
||||
'skip': ['backward']}),
|
||||
('SmoothL1Loss', {
|
||||
'block': P.SmoothL1Loss(),
|
||||
'desc_inputs': [[256, 4], [256, 4]],
|
||||
|
|
Loading…
Reference in New Issue