Support ScatterMin operator on GPU, and add a functional interface

hujiahui8 2022-04-12 20:15:43 +08:00
parent fcabbe95b3
commit b8ec35c3b7
25 changed files with 254 additions and 35 deletions
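A minimal usage sketch of the functional interface this commit adds (an editor's illustration against the diff below, not part of the commit; the device target and tensor values are assumptions):

import numpy as np
import mindspore
from mindspore import Tensor, Parameter, ops, context

# Sketch only: exercises the new ops.scatter_min functional interface on GPU.
context.set_context(device_target="GPU")  # assumes a GPU build; Ascend/CPU also supported
input_x = Parameter(Tensor(np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), mindspore.float32), name="input_x")
indices = Tensor(np.array([0, 1]), mindspore.int32)
updates = Tensor(np.zeros((2, 3)), mindspore.float32)
output = ops.scatter_min(input_x, indices, updates)  # element-wise min at the indexed rows
print(output)
# [[0. 0. 0.]
#  [0. 0. 0.]]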

View File

@ -305,5 +305,77 @@ MS_REG_GPU_KERNEL_TWO(ScatterMax,
.AddInputAttr(kNumberTypeUInt8)
.AddOutputAttr(kNumberTypeUInt8),
ScatterFunctorKernelMod, uint8_t, int64_t)
// ScatterMin
MS_REG_GPU_KERNEL_TWO(ScatterMin,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
ScatterFunctorKernelMod, float, int)
MS_REG_GPU_KERNEL_TWO(ScatterMin,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
ScatterFunctorKernelMod, float, int64_t)
MS_REG_GPU_KERNEL_TWO(ScatterMin,
KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16),
ScatterFunctorKernelMod, half, int)
MS_REG_GPU_KERNEL_TWO(ScatterMin,
KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16),
ScatterFunctorKernelMod, half, int64_t)
MS_REG_GPU_KERNEL_TWO(ScatterMin,
KernelAttr()
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt32),
ScatterFunctorKernelMod, int, int)
MS_REG_GPU_KERNEL_TWO(ScatterMin,
KernelAttr()
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt32),
ScatterFunctorKernelMod, int, int64_t)
MS_REG_GPU_KERNEL_TWO(ScatterMin,
KernelAttr()
.AddInputAttr(kNumberTypeInt8)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt8)
.AddOutputAttr(kNumberTypeInt8),
ScatterFunctorKernelMod, int8_t, int)
MS_REG_GPU_KERNEL_TWO(ScatterMin,
KernelAttr()
.AddInputAttr(kNumberTypeInt8)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt8)
.AddOutputAttr(kNumberTypeInt8),
ScatterFunctorKernelMod, int8_t, int64_t)
MS_REG_GPU_KERNEL_TWO(ScatterMin,
KernelAttr()
.AddInputAttr(kNumberTypeUInt8)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeUInt8)
.AddOutputAttr(kNumberTypeUInt8),
ScatterFunctorKernelMod, uint8_t, int)
MS_REG_GPU_KERNEL_TWO(ScatterMin,
KernelAttr()
.AddInputAttr(kNumberTypeUInt8)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeUInt8)
.AddOutputAttr(kNumberTypeUInt8),
ScatterFunctorKernelMod, uint8_t, int64_t)
} // namespace kernel
} // namespace mindspore

View File

@ -28,10 +28,8 @@ namespace mindspore {
namespace kernel {
static const std::map<std::string, ScatterFunctorType> kScatterFunctorTypeMap = {
{"ScatterUpdate", SCATTER_FUNC_UPDATE},
{"ScatterAdd", SCATTER_FUNC_ADD},
{"ScatterSub", SCATTER_FUNC_SUB},
{"ScatterMax", SCATTER_FUNC_MAX},
{"ScatterUpdate", SCATTER_FUNC_UPDATE}, {"ScatterAdd", SCATTER_FUNC_ADD}, {"ScatterSub", SCATTER_FUNC_SUB},
{"ScatterMax", SCATTER_FUNC_MAX}, {"ScatterMin", SCATTER_FUNC_MIN},
};
template <typename T, typename S>
@ -61,7 +59,7 @@ class ScatterFunctorKernelMod : public DeprecatedNativeGpuKernelMod {
auto iter = kScatterFunctorTypeMap.find(kernel_name);
if (iter == kScatterFunctorTypeMap.end()) {
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', only support these scatter functors: ScatterUpdate, ScatterAdd, "
<< "ScatterSub or ScatterMax currently, but got " << kernel_name;
<< "ScatterSub, ScatterMax or ScatterMin currently, but got " << kernel_name;
} else {
scatter_functor_type_ = iter->second;
}

View File

@ -61,6 +61,17 @@ __global__ void ScatterMaxKernel(const size_t inner_size, const size_t updates_s
}
}
template <typename T, typename S>
__global__ void ScatterMinKernel(const size_t inner_size, const size_t updates_size, const S *indices, const T *updates,
T *input) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < updates_size; pos += blockDim.x * gridDim.x) {
const size_t index = pos / inner_size;
const size_t offset = pos % inner_size;
const size_t current_pos = indices[index] * inner_size + offset;
input[current_pos] = updates[pos] < input[current_pos] ? updates[pos] : input[current_pos];
}
}
template <typename T, typename S>
void ScatterFunc(enum ScatterFunctorType func_type, const size_t &inner_size, const size_t &indices_size,
const S *indices, const T *updates, T *input, cudaStream_t cuda_stream) {
@ -78,6 +89,9 @@ void ScatterFunc(enum ScatterFunctorType func_type, const size_t &inner_size, co
case SCATTER_FUNC_MAX:
return ScatterMaxKernel<<<GET_BLOCKS(updates_size), GET_THREADS, 0, cuda_stream>>>(inner_size, updates_size,
indices, updates, input);
case SCATTER_FUNC_MIN:
return ScatterMinKernel<<<GET_BLOCKS(updates_size), GET_THREADS, 0, cuda_stream>>>(inner_size, updates_size,
indices, updates, input);
default:
break;
}
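Editor's note: as a plain reference for the CUDA kernel above (an illustration, not part of the commit), the scatter-min semantics in NumPy terms — `pos / inner_size` selects the `indices` entry and `pos % inner_size` the offset within the selected slice:

import numpy as np

def scatter_min_reference(input_arr, indices, updates):
    # Mirrors ScatterMinKernel: take the element-wise minimum of each
    # `updates` slice and the `input` slice selected by `indices`.
    # Duplicate indices are applied sequentially in this reference.
    out = input_arr.copy()
    for i, idx in enumerate(indices):
        out[idx] = np.minimum(out[idx], updates[i])
    return out

# scatter_min_reference(np.ones((2, 3)), [0, 1], np.zeros((2, 3)))
# -> [[0. 0. 0.]
#     [0. 0. 0.]]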

View File

@ -23,6 +23,7 @@ enum ScatterFunctorType {
SCATTER_FUNC_ADD,
SCATTER_FUNC_SUB,
SCATTER_FUNC_MAX,
SCATTER_FUNC_MIN,
SCATTER_FUNC_INVALID_TYPE = 255
};

View File

@ -859,6 +859,17 @@ def get_bprop_scatter_max(self):
return bprop
@bprop_getters.register(P.ScatterMin)
def get_bprop_scatter_min(self):
"""Generate bprop for ScatterMin"""
gather = P.Gather()
def bprop(x, indices, update, out, dout):
return dout, zeros_like(indices), gather(dout, indices, 0)
return bprop
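Editor's note on reading the bprop above (an illustration, not part of the diff): the gradient w.r.t. `x` is `dout` passed through unchanged, `indices` receives zeros, and the gradient w.r.t. `update` gathers the rows of `dout` selected by `indices`:

import numpy as np

dout = np.arange(6.0).reshape(2, 3)  # incoming gradient w.r.t. the op output
indices = np.array([1, 0])
d_update = dout[indices]             # NumPy equivalent of gather(dout, indices, 0)
# d_update == [[3. 4. 5.]
#              [0. 1. 2.]]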
@bprop_getters.register(P.Argmax)
def get_bprop_argmax(self):
"""Generate bprop for Argmax"""

View File

@ -9,6 +9,6 @@ y
bprop.12:x*
bprop.12:out*
bprop.12:dout2
bprop.12:[CNode]14:3:@056bff1e3c57347dea806dc7c1b2b798b13c876c828efa18a94acc87833ba12dPb&
bprop.12:[CNode]14:3:@95457f6f5c75edb42385b9b7906aa1c6ab0f9d7116b7f1fe413ac74a65097938Pb&
S-Prim-MakeTuple:4S-Prim-MakeTuplebH
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]h

View File

@ -9,6 +9,6 @@ z
bprop.15:x*
bprop.15:out*
bprop.15:dout2
bprop.15:[CNode]17:3:@056bff1e3c57347dea806dc7c1b2b798b13c876c828efa18a94acc87833ba12dPb&
bprop.15:[CNode]17:3:@95457f6f5c75edb42385b9b7906aa1c6ab0f9d7116b7f1fe413ac74a65097938Pb&
S-Prim-MakeTuple:4S-Prim-MakeTuplebH
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]h

View File

@ -9,6 +9,6 @@ z
bprop.18:x*
bprop.18:out*
bprop.18:dout2
bprop.18:[CNode]20:3:@056bff1e3c57347dea806dc7c1b2b798b13c876c828efa18a94acc87833ba12dPbH
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]b&
S-Prim-MakeTuple:4S-Prim-MakeTupleh
bprop.18:[CNode]20:3:@95457f6f5c75edb42385b9b7906aa1c6ab0f9d7116b7f1fe413ac74a65097938Pb&
S-Prim-MakeTuple:4S-Prim-MakeTuplebH
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]h

View File

@ -9,6 +9,6 @@ z
bprop.24:x*
bprop.24:out*
bprop.24:dout2
bprop.24:[CNode]26:3:@056bff1e3c57347dea806dc7c1b2b798b13c876c828efa18a94acc87833ba12dPbH
bprop.24:[CNode]26:3:@95457f6f5c75edb42385b9b7906aa1c6ab0f9d7116b7f1fe413ac74a65097938PbH
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]b&
S-Prim-MakeTuple:4S-Prim-MakeTupleh

View File

@ -5,5 +5,5 @@ m
bprop.1:x*
bprop.1:out*
bprop.1:dout2
bprop.1:[CNode]2:1:@056bff1e3c57347dea806dc7c1b2b798b13c876c828efa18a94acc87833ba12dPb&
bprop.1:[CNode]2:1:@95457f6f5c75edb42385b9b7906aa1c6ab0f9d7116b7f1fe413ac74a65097938Pb&
S-Prim-MakeTuple:2S-Prim-MakeTupleh

View File

@ -7,6 +7,6 @@ s
bprop.6:x*
bprop.6:out*
bprop.6:dout2
bprop.6:[CNode]8:3:@056bff1e3c57347dea806dc7c1b2b798b13c876c828efa18a94acc87833ba12dPb&
S-Prim-MakeTuple:4S-Prim-MakeTuplebH
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]h
bprop.6:[CNode]8:3:@95457f6f5c75edb42385b9b7906aa1c6ab0f9d7116b7f1fe413ac74a65097938PbH
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]b&
S-Prim-MakeTuple:4S-Prim-MakeTupleh

View File

@ -7,6 +7,6 @@ s
bprop.3:x*
bprop.3:out*
bprop.3:dout2
bprop.3:[CNode]5:3:@056bff1e3c57347dea806dc7c1b2b798b13c876c828efa18a94acc87833ba12dPb&
S-Prim-MakeTuple:4S-Prim-MakeTuplebH
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]h
bprop.3:[CNode]5:3:@95457f6f5c75edb42385b9b7906aa1c6ab0f9d7116b7f1fe413ac74a65097938PbH
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]b&
S-Prim-MakeTuple:4S-Prim-MakeTupleh

View File

@ -9,6 +9,6 @@ z
bprop.27:x*
bprop.27:out*
bprop.27:dout2
bprop.27:[CNode]29:3:@056bff1e3c57347dea806dc7c1b2b798b13c876c828efa18a94acc87833ba12dPbH
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]b&
S-Prim-MakeTuple:4S-Prim-MakeTupleh
bprop.27:[CNode]29:3:@95457f6f5c75edb42385b9b7906aa1c6ab0f9d7116b7f1fe413ac74a65097938Pb&
S-Prim-MakeTuple:4S-Prim-MakeTuplebH
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]h

View File

@ -27,9 +27,9 @@ bprop.30:x*
bprop.30:y*
bprop.30:out*
bprop.30:dout2
bprop.30:[CNode]36:8:@056bff1e3c57347dea806dc7c1b2b798b13c876c828efa18a94acc87833ba12dPbH
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]b&
S-Prim-MakeTuple:9S-Prim-MakeTuplebv
bprop.30:[CNode]36:8:@95457f6f5c75edb42385b9b7906aa1c6ab0f9d7116b7f1fe413ac74a65097938Pbv
S-Prim-Select:5 S-Prim-Select
output_names€Š Zoutput€3
input_names€ŠZ condition€ŠZx€ŠZy€h
input_names€ŠZ condition€ŠZx€ŠZy€bH
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]b&
S-Prim-MakeTuple:9S-Prim-MakeTupleh

View File

@ -9,6 +9,6 @@ z
bprop.21:x*
bprop.21:out*
bprop.21:dout2
bprop.21:[CNode]23:3:@056bff1e3c57347dea806dc7c1b2b798b13c876c828efa18a94acc87833ba12dPb&
bprop.21:[CNode]23:3:@95457f6f5c75edb42385b9b7906aa1c6ab0f9d7116b7f1fe413ac74a65097938Pb&
S-Prim-MakeTuple:4S-Prim-MakeTuplebH
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]h

View File

@ -7,6 +7,6 @@ v
bprop.9:x*
bprop.9:out*
bprop.9:dout2
bprop.9:[CNode]11:3:@056bff1e3c57347dea806dc7c1b2b798b13c876c828efa18a94acc87833ba12dPbH
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]b&
S-Prim-MakeTuple:4S-Prim-MakeTupleh
bprop.9:[CNode]11:3:@95457f6f5c75edb42385b9b7906aa1c6ab0f9d7116b7f1fe413ac74a65097938Pb&
S-Prim-MakeTuple:4S-Prim-MakeTuplebH
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]h

View File

@ -29,9 +29,9 @@ serializable_bprop_ops = [P.ReLU, P.Identity, inner.Range, P.OnesLike, P.ZerosLi
P.LinSpace, P.DropoutGenMask, P.OneHot, P.Assign, P.IOU, P.BNTrainingReduce, P.Equal,
P.NotEqual, P.Greater, P.GreaterEqual, P.Less, P.LessEqual, P.LogicalAnd, P.LogicalOr,
P.ReduceAll, P.ReduceAny, P.DropoutDoMask, P.DType, P.Shape, P.DynamicShape, P.Rank,
P.Select, P.ScatterMax, G.ReluGrad, _constants.kTupleGetItem, P.FloorDiv, P.TruncateDiv,
P.Minimum, P.Maximum, P.IsNan, P.IsInf, P.ReLUV2, "Depend", "stop_gradient", "Switch",
"UpdateState", "Load"]
P.Select, P.ScatterMax, P.ScatterMin, G.ReluGrad, _constants.kTupleGetItem, P.FloorDiv,
P.TruncateDiv, P.Minimum, P.Maximum, P.IsNan, P.IsInf, P.ReLUV2, "Depend", "stop_gradient",
"Switch", "UpdateState", "Load"]
def run_generate():

View File

@ -22,7 +22,8 @@ A collection of function to build neural networks or to compute functions.
from . import array_func, parameter_func, math_func
from .array_func import (unique, eye, fill, fill_, tile, size, ones, ones_like, shape, shape_, dyn_shape, rank,
reshape, reshape_, tensor_slice, slice, scalar_to_array, scalar_to_tensor, tuple_to_array,
expand_dims, transpose, scatter_nd, gather, gather_d, gather_nd, scalar_cast, masked_fill)
expand_dims, transpose, scatter_nd, gather, gather_d, gather_nd, scalar_cast, masked_fill,
scatter_min)
from .parameter_func import assign, assign_add, assign_sub, index_add
from .math_func import (addn, absolute, abs, tensor_add, add, neg_tensor, neg, tensor_lt, less, tensor_le, le,
tensor_gt, gt, tensor_ge, ge, tensor_sub, sub, tensor_mul, mul, tensor_div, div,

View File

@ -587,6 +587,49 @@ def transpose(input_x, input_perm):
return transpose_(input_x, input_perm)
def scatter_min(input_x, indices, updates, use_locking=False):
r"""
Updates the value of the input tensor through the min operation, using the given update values and indices.
This operation outputs the `input_x` after the update is done, which makes it convenient to use the updated value.
Args:
input_x (Parameter): The target tensor, with data type of Parameter.
The shape is :math:`(N, *)` where :math:`*` means any number of additional dimensions.
indices (Tensor): The indices to do the min operation, whose data type must be mindspore.int32.
updates (Tensor): The tensor that performs the min operation with `input_x`,
with the same data type as `input_x` and the shape `indices.shape + input_x.shape[1:]`.
use_locking (bool): Whether to protect the assignment with a lock. Default: False.
Outputs:
Tensor, the updated `input_x`, has the same shape and type as `input_x`.
Raises:
TypeError: If the dtype of `indices` is not int32.
TypeError: If `use_locking` is not a bool.
RuntimeError: If a data type conversion between `input_x` and `updates` is required
but data type conversion of Parameter is not supported.
ValueError: If the shape of `updates` is not equal to `indices.shape + input_x.shape[1:]`.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> import numpy as np
>>> import mindspore
>>> from mindspore import Tensor, Parameter
>>> from mindspore import ops
>>> input_x = Parameter(Tensor(np.zeros((2, 3)), mindspore.float32), name="input_x")
>>> indices = Tensor(np.array([1, 0]), mindspore.int32)
>>> update = Tensor(np.arange(6).reshape((2, 3)), mindspore.float32)
>>> output = ops.scatter_min(input_x, indices, update)
>>> print(output)
[[0. 0. 0.]
[0. 0. 0.]]
"""
return P.ScatterMin(use_locking)(input_x, indices, updates)
scatter_nd_ = P.ScatterNd()
def scatter_nd(indices, updates, shape):
r"""
@ -1040,6 +1083,7 @@ __all__ = [
'gather',
'gather_d',
'gather_nd',
'masked_fill'
'masked_fill',
'scatter_min'
]
__all__.sort()

View File

@ -980,6 +980,7 @@ tensor_operator_registry.register('logical_not', P.LogicalNot)
tensor_operator_registry.register('sum', P.ReduceSum)
tensor_operator_registry.register('split', P.Split)
tensor_operator_registry.register('index_add', P.IndexAdd)
tensor_operator_registry.register('scatter_min', P.ScatterMin)
# ms cannot support Tensor(True) compare
tensor_operator_registry.register('__eq__', equal)
tensor_operator_registry.register('__ne__', not_equal)

View File

@ -4301,7 +4301,7 @@ class ScatterMin(_ScatterOp):
is required when data type conversion of Parameter is not supported.
Supported Platforms:
``Ascend`` ``CPU``
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> input_x = Parameter(Tensor(np.array([[0.0, 1.0, 2.0], [0.0, 0.0, 0.0]]), mindspore.float32),

View File

@ -29,6 +29,7 @@ func_map = {
"add": P.ScatterAdd,
"sub": P.ScatterSub,
"max": P.ScatterMax,
"min": P.ScatterMin,
}
@ -130,6 +131,11 @@ def test_scatter_func_small_float32():
expected = np.array([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]])
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
# min
output = scatter_func_net("min", inputx, indices, updates)
expected = np.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@ -164,6 +170,12 @@ def test_scatter_func_input_updated():
expected = np.array([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]])
np.testing.assert_array_almost_equal(net.inputx.asnumpy(), expected)
# min
net = TestScatterFuncNet("min", lock, inputx, indices, updates)
net()
expected = np.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
np.testing.assert_array_almost_equal(net.inputx.asnumpy(), expected)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@ -329,6 +341,12 @@ def test_scatter_func_large_shape_float32():
)
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
# min
output = scatter_func_net("min", inputx, indices, updates)
expected = np.ones((4, 2, 3, 4)).astype(np.float32)
expected[0][0][0][0] = 0.0
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@ -358,6 +376,11 @@ def test_scatter_func_small_float32_use_locking_false():
expected = np.array([[3.0, 4.0, 5.0], [0.0, 1.0, 2.0]])
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
# min
output = scatter_func_use_locking_false_net("min", inputx, indices, updates)
expected = np.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@ -418,6 +441,11 @@ def test_scatter_func_input_less_than_1_float32():
)
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
# min
output = scatter_func_net("min", inputx, indices, updates)
expected = inputx.asnumpy()
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@ -447,6 +475,11 @@ def test_scatter_func_float16():
expected = np.array([[6.0, 1.0, 2.0], [3.0, 4.0, 5.0]])
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
# min
output = scatter_func_net("min", inputx, indices, updates)
expected = np.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@ -496,6 +529,11 @@ def test_scatter_func_large_float16():
])
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
# min
output = scatter_func_net("min", inputx, indices, updates)
expected = np.zeros((2, 3, 4)).astype(np.float16)
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@ -541,6 +579,11 @@ def test_scatter_func_disordered_float16():
)
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
# min
output = scatter_func_net("min", inputx, indices, updates)
expected = np.flip(np.arange(34, 46).reshape(3, 4).astype(np.float16))
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@ -590,6 +633,11 @@ def test_scatter_func_large_int32():
])
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
# min
output = scatter_func_net("min", inputx, indices, updates)
expected = np.zeros((2, 3, 4)).astype(np.int32)
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@ -635,6 +683,11 @@ def test_scatter_func_disordered_int32():
)
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
# min
output = scatter_func_net("min", inputx, indices, updates)
expected = np.flip(np.arange(34, 46).reshape(3, 4).astype(np.int32))
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training

View File

@ -385,6 +385,30 @@ def test_scatter_max():
grad.compile(indices, updates)
def test_scatter_min():
"""
Feature: Bprop pre-compilation.
Description: Compile the backward graph for the scatter_min op.
Expectation: Load the bprop mindir successfully.
"""
class ScatterMinNet(nn.Cell):
def __init__(self):
super(ScatterMinNet, self).__init__()
self.scatter_min = P.ScatterMin()
self.input_x = Parameter(Tensor(np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), mstype.float32),
name="input_x")
def construct(self, indices, updates):
return self.scatter_min(self.input_x, indices, updates)
indices = Tensor(np.array([[0, 0], [1, 1]]), mstype.int32)
updates = Tensor(np.ones([2, 2, 3]) * 88, mstype.float32)
scatter_min = ScatterMinNet()
grad = GradNet(scatter_min)
grad.compile(indices, updates)
def test_relu_grad():
"""
Feature: Bprop pre-compilation.