diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/scatter_functor_gpu_kernel.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/scatter_functor_gpu_kernel.cc index 9ce28078f35..0da85534fe9 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/scatter_functor_gpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/scatter_functor_gpu_kernel.cc @@ -305,5 +305,77 @@ MS_REG_GPU_KERNEL_TWO(ScatterMax, .AddInputAttr(kNumberTypeUInt8) .AddOutputAttr(kNumberTypeUInt8), ScatterFunctorKernelMod, uint8_t, int64_t) + +// ScatterMin +MS_REG_GPU_KERNEL_TWO(ScatterMin, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + ScatterFunctorKernelMod, float, int) +MS_REG_GPU_KERNEL_TWO(ScatterMin, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + ScatterFunctorKernelMod, float, int64_t) +MS_REG_GPU_KERNEL_TWO(ScatterMin, + KernelAttr() + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16), + ScatterFunctorKernelMod, half, int) +MS_REG_GPU_KERNEL_TWO(ScatterMin, + KernelAttr() + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16), + ScatterFunctorKernelMod, half, int64_t) +MS_REG_GPU_KERNEL_TWO(ScatterMin, + KernelAttr() + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeInt32), + ScatterFunctorKernelMod, int, int) +MS_REG_GPU_KERNEL_TWO(ScatterMin, + KernelAttr() + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeInt32), + ScatterFunctorKernelMod, int, int64_t) +MS_REG_GPU_KERNEL_TWO(ScatterMin, + KernelAttr() + .AddInputAttr(kNumberTypeInt8) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt8) + .AddOutputAttr(kNumberTypeInt8), + ScatterFunctorKernelMod, int8_t, int) +MS_REG_GPU_KERNEL_TWO(ScatterMin, + KernelAttr() + .AddInputAttr(kNumberTypeInt8) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt8) + .AddOutputAttr(kNumberTypeInt8), + ScatterFunctorKernelMod, int8_t, int64_t) +MS_REG_GPU_KERNEL_TWO(ScatterMin, + KernelAttr() + .AddInputAttr(kNumberTypeUInt8) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeUInt8) + .AddOutputAttr(kNumberTypeUInt8), + ScatterFunctorKernelMod, uint8_t, int) +MS_REG_GPU_KERNEL_TWO(ScatterMin, + KernelAttr() + .AddInputAttr(kNumberTypeUInt8) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeUInt8) + .AddOutputAttr(kNumberTypeUInt8), + ScatterFunctorKernelMod, uint8_t, int64_t) } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/scatter_functor_gpu_kernel.h b/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/scatter_functor_gpu_kernel.h index c8ea77cc4a1..761d9127fde 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/scatter_functor_gpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/scatter_functor_gpu_kernel.h @@ -28,10 +28,8 @@ namespace mindspore { namespace kernel { static const std::map kScatterFunctorTypeMap = { - {"ScatterUpdate", SCATTER_FUNC_UPDATE}, - {"ScatterAdd", SCATTER_FUNC_ADD}, - {"ScatterSub", SCATTER_FUNC_SUB}, - {"ScatterMax", SCATTER_FUNC_MAX}, + {"ScatterUpdate", SCATTER_FUNC_UPDATE}, {"ScatterAdd", SCATTER_FUNC_ADD}, {"ScatterSub", SCATTER_FUNC_SUB}, + {"ScatterMax", SCATTER_FUNC_MAX}, {"ScatterMin", SCATTER_FUNC_MIN}, }; template @@ -61,7 +59,7 @@ class ScatterFunctorKernelMod : public DeprecatedNativeGpuKernelMod { auto iter = kScatterFunctorTypeMap.find(kernel_name); if (iter == kScatterFunctorTypeMap.end()) { MS_LOG(EXCEPTION) << "For '" << kernel_name << "Only support these scatter functors: ScatterUpdate, ScatterAdd, " - << "ScatterSub or ScatterMax currently, but got " << kernel_name; + << "ScatterSub, ScatterMax or ScatterMin currently, but got " << kernel_name; } else { scatter_functor_type_ = iter->second; } diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/scatter_functor_impl.cu b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/scatter_functor_impl.cu index aa2e18631a0..11f80ca2749 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/scatter_functor_impl.cu +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/scatter_functor_impl.cu @@ -61,6 +61,17 @@ __global__ void ScatterMaxKernel(const size_t inner_size, const size_t updates_s } } +template +__global__ void ScatterMinKernel(const size_t inner_size, const size_t updates_size, const S *indices, const T *updates, + T *input) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < updates_size; pos += blockDim.x * gridDim.x) { + const size_t index = pos / inner_size; + const size_t offset = pos % inner_size; + const size_t current_pos = indices[index] * inner_size + offset; + input[current_pos] = updates[pos] < input[current_pos] ? updates[pos] : input[current_pos]; + } +} + template void ScatterFunc(enum ScatterFunctorType func_type, const size_t &inner_size, const size_t &indices_size, const S *indices, const T *updates, T *input, cudaStream_t cuda_stream) { @@ -78,6 +89,9 @@ void ScatterFunc(enum ScatterFunctorType func_type, const size_t &inner_size, co case SCATTER_FUNC_MAX: return ScatterMaxKernel<<>>(inner_size, updates_size, indices, updates, input); + case SCATTER_FUNC_MIN: + return ScatterMinKernel<<>>(inner_size, updates_size, + indices, updates, input); default: break; } diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/scatter_functor_impl.cuh b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/scatter_functor_impl.cuh index 663eee7ab56..f053a404598 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/scatter_functor_impl.cuh +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/scatter_functor_impl.cuh @@ -23,6 +23,7 @@ enum ScatterFunctorType { SCATTER_FUNC_ADD, SCATTER_FUNC_SUB, SCATTER_FUNC_MAX, + SCATTER_FUNC_MIN, SCATTER_FUNC_INVALID_TYPE = 255 }; diff --git a/mindspore/python/mindspore/ops/_grad/grad_array_ops.py b/mindspore/python/mindspore/ops/_grad/grad_array_ops.py index c2991d2975b..a6e69c65de8 100644 --- a/mindspore/python/mindspore/ops/_grad/grad_array_ops.py +++ b/mindspore/python/mindspore/ops/_grad/grad_array_ops.py @@ -859,6 +859,17 @@ def get_bprop_scatter_max(self): return bprop +@bprop_getters.register(P.ScatterMin) +def get_bprop_scatter_min(self): + """Generate bprop for ScatterMin""" + gather = P.Gather() + + def bprop(x, indices, update, out, dout): + return dout, zeros_like(indices), gather(dout, indices, 0) + + return bprop + + @bprop_getters.register(P.Argmax) def get_bprop_argmax(self): """Generate bprop for Argmax""" diff --git a/mindspore/python/mindspore/ops/bprop_mindir/Argmax_bprop.mindir b/mindspore/python/mindspore/ops/bprop_mindir/Argmax_bprop.mindir index f4e2801208c..d77520a9cb8 100644 --- a/mindspore/python/mindspore/ops/bprop_mindir/Argmax_bprop.mindir +++ b/mindspore/python/mindspore/ops/bprop_mindir/Argmax_bprop.mindir @@ -9,6 +9,6 @@ y bprop.12:x* bprop.12:out* bprop.12:dout2 -bprop.12:[CNode]14:3:@056bff1e3c57347dea806dc7c1b2b798b13c876c828efa18a94acc87833ba12dPb& +bprop.12:[CNode]14:3:@95457f6f5c75edb42385b9b7906aa1c6ab0f9d7116b7f1fe413ac74a65097938Pb& S-Prim-MakeTuple:4S-Prim-MakeTuplebH #S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]h \ No newline at end of file diff --git a/mindspore/python/mindspore/ops/bprop_mindir/Argmin_bprop.mindir b/mindspore/python/mindspore/ops/bprop_mindir/Argmin_bprop.mindir index 6c8fe7536af..96492dab288 100644 --- a/mindspore/python/mindspore/ops/bprop_mindir/Argmin_bprop.mindir +++ b/mindspore/python/mindspore/ops/bprop_mindir/Argmin_bprop.mindir @@ -9,6 +9,6 @@ z bprop.15:x* bprop.15:out* bprop.15:dout2 -bprop.15:[CNode]17:3:@056bff1e3c57347dea806dc7c1b2b798b13c876c828efa18a94acc87833ba12dPb& +bprop.15:[CNode]17:3:@95457f6f5c75edb42385b9b7906aa1c6ab0f9d7116b7f1fe413ac74a65097938Pb& S-Prim-MakeTuple:4S-Prim-MakeTuplebH #S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]h \ No newline at end of file diff --git a/mindspore/python/mindspore/ops/bprop_mindir/DType_bprop.mindir b/mindspore/python/mindspore/ops/bprop_mindir/DType_bprop.mindir index 23957fa99f6..e030e629db6 100644 --- a/mindspore/python/mindspore/ops/bprop_mindir/DType_bprop.mindir +++ b/mindspore/python/mindspore/ops/bprop_mindir/DType_bprop.mindir @@ -9,6 +9,6 @@ z bprop.18:x* bprop.18:out* bprop.18:dout2 -bprop.18:[CNode]20:3:@056bff1e3c57347dea806dc7c1b2b798b13c876c828efa18a94acc87833ba12dPbH -#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]b& -S-Prim-MakeTuple:4S-Prim-MakeTupleh \ No newline at end of file +bprop.18:[CNode]20:3:@95457f6f5c75edb42385b9b7906aa1c6ab0f9d7116b7f1fe413ac74a65097938Pb& +S-Prim-MakeTuple:4S-Prim-MakeTuplebH +#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]h \ No newline at end of file diff --git a/mindspore/python/mindspore/ops/bprop_mindir/DynamicShape_bprop.mindir b/mindspore/python/mindspore/ops/bprop_mindir/DynamicShape_bprop.mindir index d4aa25a82a4..60b58cddc34 100644 --- a/mindspore/python/mindspore/ops/bprop_mindir/DynamicShape_bprop.mindir +++ b/mindspore/python/mindspore/ops/bprop_mindir/DynamicShape_bprop.mindir @@ -9,6 +9,6 @@ z bprop.24:x* bprop.24:out* bprop.24:dout2 -bprop.24:[CNode]26:3:@056bff1e3c57347dea806dc7c1b2b798b13c876c828efa18a94acc87833ba12dPbH +bprop.24:[CNode]26:3:@95457f6f5c75edb42385b9b7906aa1c6ab0f9d7116b7f1fe413ac74a65097938PbH #S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]b& S-Prim-MakeTuple:4S-Prim-MakeTupleh \ No newline at end of file diff --git a/mindspore/python/mindspore/ops/bprop_mindir/Identity_bprop.mindir b/mindspore/python/mindspore/ops/bprop_mindir/Identity_bprop.mindir index ce270a2c50e..89b7061a5e6 100644 --- a/mindspore/python/mindspore/ops/bprop_mindir/Identity_bprop.mindir +++ b/mindspore/python/mindspore/ops/bprop_mindir/Identity_bprop.mindir @@ -5,5 +5,5 @@ m bprop.1:x* bprop.1:out* bprop.1:dout2 -bprop.1:[CNode]2:1:@056bff1e3c57347dea806dc7c1b2b798b13c876c828efa18a94acc87833ba12dPb& +bprop.1:[CNode]2:1:@95457f6f5c75edb42385b9b7906aa1c6ab0f9d7116b7f1fe413ac74a65097938Pb& S-Prim-MakeTuple:2S-Prim-MakeTupleh \ No newline at end of file diff --git a/mindspore/python/mindspore/ops/bprop_mindir/OnesLike_bprop.mindir b/mindspore/python/mindspore/ops/bprop_mindir/OnesLike_bprop.mindir index ba57685bddf..e16c6cdef12 100644 --- a/mindspore/python/mindspore/ops/bprop_mindir/OnesLike_bprop.mindir +++ b/mindspore/python/mindspore/ops/bprop_mindir/OnesLike_bprop.mindir @@ -7,6 +7,6 @@ s bprop.6:x* bprop.6:out* bprop.6:dout2 -bprop.6:[CNode]8:3:@056bff1e3c57347dea806dc7c1b2b798b13c876c828efa18a94acc87833ba12dPb& -S-Prim-MakeTuple:4S-Prim-MakeTuplebH -#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]h \ No newline at end of file +bprop.6:[CNode]8:3:@95457f6f5c75edb42385b9b7906aa1c6ab0f9d7116b7f1fe413ac74a65097938PbH +#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]b& +S-Prim-MakeTuple:4S-Prim-MakeTupleh \ No newline at end of file diff --git a/mindspore/python/mindspore/ops/bprop_mindir/Range_bprop.mindir b/mindspore/python/mindspore/ops/bprop_mindir/Range_bprop.mindir index 1a58e426ffd..d40c5d4ee3d 100644 --- a/mindspore/python/mindspore/ops/bprop_mindir/Range_bprop.mindir +++ b/mindspore/python/mindspore/ops/bprop_mindir/Range_bprop.mindir @@ -7,6 +7,6 @@ s bprop.3:x* bprop.3:out* bprop.3:dout2 -bprop.3:[CNode]5:3:@056bff1e3c57347dea806dc7c1b2b798b13c876c828efa18a94acc87833ba12dPb& -S-Prim-MakeTuple:4S-Prim-MakeTuplebH -#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]h \ No newline at end of file +bprop.3:[CNode]5:3:@95457f6f5c75edb42385b9b7906aa1c6ab0f9d7116b7f1fe413ac74a65097938PbH +#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]b& +S-Prim-MakeTuple:4S-Prim-MakeTupleh \ No newline at end of file diff --git a/mindspore/python/mindspore/ops/bprop_mindir/Rank_bprop.mindir b/mindspore/python/mindspore/ops/bprop_mindir/Rank_bprop.mindir index 73743007f86..b9b1149d4bd 100644 --- a/mindspore/python/mindspore/ops/bprop_mindir/Rank_bprop.mindir +++ b/mindspore/python/mindspore/ops/bprop_mindir/Rank_bprop.mindir @@ -9,6 +9,6 @@ z bprop.27:x* bprop.27:out* bprop.27:dout2 -bprop.27:[CNode]29:3:@056bff1e3c57347dea806dc7c1b2b798b13c876c828efa18a94acc87833ba12dPbH -#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]b& -S-Prim-MakeTuple:4S-Prim-MakeTupleh \ No newline at end of file +bprop.27:[CNode]29:3:@95457f6f5c75edb42385b9b7906aa1c6ab0f9d7116b7f1fe413ac74a65097938Pb& +S-Prim-MakeTuple:4S-Prim-MakeTuplebH +#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]h \ No newline at end of file diff --git a/mindspore/python/mindspore/ops/bprop_mindir/ScatterMax_bprop.mindir b/mindspore/python/mindspore/ops/bprop_mindir/ScatterMax_bprop.mindir index 2d73d38b408..88c1013b264 100644 Binary files a/mindspore/python/mindspore/ops/bprop_mindir/ScatterMax_bprop.mindir and b/mindspore/python/mindspore/ops/bprop_mindir/ScatterMax_bprop.mindir differ diff --git a/mindspore/python/mindspore/ops/bprop_mindir/ScatterMin_bprop.mindir b/mindspore/python/mindspore/ops/bprop_mindir/ScatterMin_bprop.mindir new file mode 100644 index 00000000000..b72cd1f7892 Binary files /dev/null and b/mindspore/python/mindspore/ops/bprop_mindir/ScatterMin_bprop.mindir differ diff --git a/mindspore/python/mindspore/ops/bprop_mindir/Select_bprop.mindir b/mindspore/python/mindspore/ops/bprop_mindir/Select_bprop.mindir index dde240d7e6e..7c2538177a8 100644 --- a/mindspore/python/mindspore/ops/bprop_mindir/Select_bprop.mindir +++ b/mindspore/python/mindspore/ops/bprop_mindir/Select_bprop.mindir @@ -27,9 +27,9 @@ bprop.30:x* bprop.30:y* bprop.30:out* bprop.30:dout2 -bprop.30:[CNode]36:8:@056bff1e3c57347dea806dc7c1b2b798b13c876c828efa18a94acc87833ba12dPbH -#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]b& -S-Prim-MakeTuple:9S-Prim-MakeTuplebv +bprop.30:[CNode]36:8:@95457f6f5c75edb42385b9b7906aa1c6ab0f9d7116b7f1fe413ac74a65097938Pbv S-Prim-Select:5 S-Prim-Select output_names€Š Zoutput€3 - input_names€ŠZ condition€ŠZx€ŠZy€h \ No newline at end of file + input_names€ŠZ condition€ŠZx€ŠZy€bH +#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]b& +S-Prim-MakeTuple:9S-Prim-MakeTupleh \ No newline at end of file diff --git a/mindspore/python/mindspore/ops/bprop_mindir/Shape_bprop.mindir b/mindspore/python/mindspore/ops/bprop_mindir/Shape_bprop.mindir index 9331062fdfc..24ef85753c2 100644 --- a/mindspore/python/mindspore/ops/bprop_mindir/Shape_bprop.mindir +++ b/mindspore/python/mindspore/ops/bprop_mindir/Shape_bprop.mindir @@ -9,6 +9,6 @@ z bprop.21:x* bprop.21:out* bprop.21:dout2 -bprop.21:[CNode]23:3:@056bff1e3c57347dea806dc7c1b2b798b13c876c828efa18a94acc87833ba12dPb& +bprop.21:[CNode]23:3:@95457f6f5c75edb42385b9b7906aa1c6ab0f9d7116b7f1fe413ac74a65097938Pb& S-Prim-MakeTuple:4S-Prim-MakeTuplebH #S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]h \ No newline at end of file diff --git a/mindspore/python/mindspore/ops/bprop_mindir/ZerosLike_bprop.mindir b/mindspore/python/mindspore/ops/bprop_mindir/ZerosLike_bprop.mindir index 517e2e98fa6..dc69fdb053a 100644 --- a/mindspore/python/mindspore/ops/bprop_mindir/ZerosLike_bprop.mindir +++ b/mindspore/python/mindspore/ops/bprop_mindir/ZerosLike_bprop.mindir @@ -7,6 +7,6 @@ v bprop.9:x* bprop.9:out* bprop.9:dout2 -bprop.9:[CNode]11:3:@056bff1e3c57347dea806dc7c1b2b798b13c876c828efa18a94acc87833ba12dPbH -#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]b& -S-Prim-MakeTuple:4S-Prim-MakeTupleh \ No newline at end of file +bprop.9:[CNode]11:3:@95457f6f5c75edb42385b9b7906aa1c6ab0f9d7116b7f1fe413ac74a65097938Pb& +S-Prim-MakeTuple:4S-Prim-MakeTuplebH +#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]h \ No newline at end of file diff --git a/mindspore/python/mindspore/ops/bprop_mindir/generate_mindir.py b/mindspore/python/mindspore/ops/bprop_mindir/generate_mindir.py index d24b00ae13b..397b4b1c207 100644 --- a/mindspore/python/mindspore/ops/bprop_mindir/generate_mindir.py +++ b/mindspore/python/mindspore/ops/bprop_mindir/generate_mindir.py @@ -29,9 +29,9 @@ serializable_bprop_ops = [P.ReLU, P.Identity, inner.Range, P.OnesLike, P.ZerosLi P.LinSpace, P.DropoutGenMask, P.OneHot, P.Assign, P.IOU, P.BNTrainingReduce, P.Equal, P.NotEqual, P.Greater, P.GreaterEqual, P.Less, P.LessEqual, P.LogicalAnd, P.LogicalOr, P.ReduceAll, P.ReduceAny, P.DropoutDoMask, P.DType, P.Shape, P.DynamicShape, P.Rank, - P.Select, P.ScatterMax, G.ReluGrad, _constants.kTupleGetItem, P.FloorDiv, P.TruncateDiv, - P.Minimum, P.Maximum, P.IsNan, P.IsInf, P.ReLUV2, "Depend", "stop_gradient", "Switch", - "UpdateState", "Load"] + P.Select, P.ScatterMax, P.ScatterMin, G.ReluGrad, _constants.kTupleGetItem, P.FloorDiv, + P.TruncateDiv, P.Minimum, P.Maximum, P.IsNan, P.IsInf, P.ReLUV2, "Depend", "stop_gradient", + "Switch", "UpdateState", "Load"] def run_generate(): diff --git a/mindspore/python/mindspore/ops/function/__init__.py b/mindspore/python/mindspore/ops/function/__init__.py index 30c6341c2c9..104cb71b367 100644 --- a/mindspore/python/mindspore/ops/function/__init__.py +++ b/mindspore/python/mindspore/ops/function/__init__.py @@ -22,7 +22,8 @@ A collection of function to build neural networks or to compute functions. from . import array_func, parameter_func, math_func from .array_func import (unique, eye, fill, fill_, tile, size, ones, ones_like, shape, shape_, dyn_shape, rank, reshape, reshape_, tensor_slice, slice, scalar_to_array, scalar_to_tensor, tuple_to_array, - expand_dims, transpose, scatter_nd, gather, gather_d, gather_nd, scalar_cast, masked_fill) + expand_dims, transpose, scatter_nd, gather, gather_d, gather_nd, scalar_cast, masked_fill, + scatter_min) from .parameter_func import assign, assign_add, assign_sub, index_add from .math_func import (addn, absolute, abs, tensor_add, add, neg_tensor, neg, tensor_lt, less, tensor_le, le, tensor_gt, gt, tensor_ge, ge, tensor_sub, sub, tensor_mul, mul, tensor_div, div, diff --git a/mindspore/python/mindspore/ops/function/array_func.py b/mindspore/python/mindspore/ops/function/array_func.py index 33e34eac37e..cc70ba999af 100644 --- a/mindspore/python/mindspore/ops/function/array_func.py +++ b/mindspore/python/mindspore/ops/function/array_func.py @@ -587,6 +587,49 @@ def transpose(input_x, input_perm): return transpose_(input_x, input_perm) +def scatter_min(input_x, indices, updates, use_locking=False): + r""" + Using given values to update tensor value through the min operation, along with the input indices. + This operation outputs the `input_x` after the update is done, which makes it convenient to use the updated value. + + Args: + - **input_x** (Parameter) - The target tensor, with data type of Parameter. + The shape is :math:`(N,*)` where :math:`*` means,any number of additional dimensions. + - **indices** (Tensor) - The index to do min operation whose data type must be mindspore.int32. + - **updates** (Tensor) - The tensor doing the min operation with `input_x`, + the data type is same as `input_x`, the shape is `indices.shape + x.shape[1:]`. + - use_locking (bool): Whether to protect the assignment by a lock. Default: False. + + Outputs: + Tensor, the updated `input_x`, has the same shape and type as `input_x`. + + Raises: + TypeError: If `indices` is not an int32. + TypeError: If `use_locking` is not a bool. + RuntimeError: If the data type of `input_x` and `updates` conversion of Parameter + is required when data type conversion of Parameter is not supported. + ValueError: If the shape of `updates` is not equal to `indices.shape + x.shape[1:]`. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import numpy as np + >>> import mindspore + >>> from mindspore import Tensor, Parameter + >>> from mindspore import ops + >>> input_x = Parameter(Tensor(np.zeros((2, 3)), mindspore.float32), name="input_x") + >>> indices = Tensor(np.array([1, 0]), mindspore.int32) + >>> update = Tensor(np.arange(6).reshape((2, 3)), mindspore.float32) + >>> scatter_min = ops.ScatterMin() + >>> output = scatter_min(input_x, indices, update) + >>> print(output) + [[0. 0. 0.] + [0. 0. 0.]] + """ + return P.ScatterMin(use_locking)(input_x, indices, updates) + + scatter_nd_ = P.ScatterNd() def scatter_nd(indices, updates, shape): r""" @@ -1040,6 +1083,7 @@ __all__ = [ 'gather', 'gather_d', 'gather_nd', - 'masked_fill' + 'masked_fill', + 'scatter_min' ] __all__.sort() diff --git a/mindspore/python/mindspore/ops/functional.py b/mindspore/python/mindspore/ops/functional.py index ec33959431d..bee097156b0 100644 --- a/mindspore/python/mindspore/ops/functional.py +++ b/mindspore/python/mindspore/ops/functional.py @@ -980,6 +980,7 @@ tensor_operator_registry.register('logical_not', P.LogicalNot) tensor_operator_registry.register('sum', P.ReduceSum) tensor_operator_registry.register('split', P.Split) tensor_operator_registry.register('index_add', P.IndexAdd) +tensor_operator_registry.register('scatter_min', P.ScatterMin) # ms cannot support Tensor(True) compare tensor_operator_registry.register('__eq__', equal) tensor_operator_registry.register('__ne__', not_equal) diff --git a/mindspore/python/mindspore/ops/operations/array_ops.py b/mindspore/python/mindspore/ops/operations/array_ops.py index df01537fd60..dbd3e69f717 100755 --- a/mindspore/python/mindspore/ops/operations/array_ops.py +++ b/mindspore/python/mindspore/ops/operations/array_ops.py @@ -4301,7 +4301,7 @@ class ScatterMin(_ScatterOp): is required when data type conversion of Parameter is not supported. Supported Platforms: - ``Ascend`` ``CPU`` + ``Ascend`` ``GPU`` ``CPU`` Examples: >>> input_x = Parameter(Tensor(np.array([[0.0, 1.0, 2.0], [0.0, 0.0, 0.0]]), mindspore.float32), diff --git a/tests/st/ops/gpu/test_scatter_func_op.py b/tests/st/ops/gpu/test_scatter_func_op.py index 6090b934752..9e76d028406 100644 --- a/tests/st/ops/gpu/test_scatter_func_op.py +++ b/tests/st/ops/gpu/test_scatter_func_op.py @@ -29,6 +29,7 @@ func_map = { "add": P.ScatterAdd, "sub": P.ScatterSub, "max": P.ScatterMax, + "min": P.ScatterMin, } @@ -130,6 +131,11 @@ def test_scatter_func_small_float32(): expected = np.array([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]) np.testing.assert_array_almost_equal(output.asnumpy(), expected) + # min + output = scatter_func_net("min", inputx, indices, updates) + expected = np.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]) + np.testing.assert_array_almost_equal(output.asnumpy(), expected) + @pytest.mark.level0 @pytest.mark.platform_x86_gpu_training @@ -164,6 +170,12 @@ def test_scatter_func_input_updated(): expected = np.array([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]) np.testing.assert_array_almost_equal(net.inputx.asnumpy(), expected) + # min + net = TestScatterFuncNet("min", lock, inputx, indices, updates) + net() + expected = np.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]) + np.testing.assert_array_almost_equal(net.inputx.asnumpy(), expected) + @pytest.mark.level0 @pytest.mark.platform_x86_gpu_training @@ -329,6 +341,12 @@ def test_scatter_func_large_shape_float32(): ) np.testing.assert_array_almost_equal(output.asnumpy(), expected) + # min + output = scatter_func_net("min", inputx, indices, updates) + expected = np.ones((4, 2, 3, 4)).astype(np.float32) + expected[0][0][0][0] = 0.0 + np.testing.assert_array_almost_equal(output.asnumpy(), expected) + @pytest.mark.level0 @pytest.mark.platform_x86_gpu_training @@ -358,6 +376,11 @@ def test_scatter_func_small_float32_use_locking_false(): expected = np.array([[3.0, 4.0, 5.0], [0.0, 1.0, 2.0]]) np.testing.assert_array_almost_equal(output.asnumpy(), expected) + # min + output = scatter_func_use_locking_false_net("min", inputx, indices, updates) + expected = np.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]) + np.testing.assert_array_almost_equal(output.asnumpy(), expected) + @pytest.mark.level0 @pytest.mark.platform_x86_gpu_training @@ -418,6 +441,11 @@ def test_scatter_func_input_less_than_1_float32(): ) np.testing.assert_array_almost_equal(output.asnumpy(), expected) + # min + output = scatter_func_net("min", inputx, indices, updates) + expected = inputx.asnumpy() + np.testing.assert_array_almost_equal(output.asnumpy(), expected) + @pytest.mark.level0 @pytest.mark.platform_x86_gpu_training @@ -447,6 +475,11 @@ def test_scatter_func_float16(): expected = np.array([[6.0, 1.0, 2.0], [3.0, 4.0, 5.0]]) np.testing.assert_array_almost_equal(output.asnumpy(), expected) + # min + output = scatter_func_net("min", inputx, indices, updates) + expected = np.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]) + np.testing.assert_array_almost_equal(output.asnumpy(), expected) + @pytest.mark.level0 @pytest.mark.platform_x86_gpu_training @@ -496,6 +529,11 @@ def test_scatter_func_large_float16(): ]) np.testing.assert_array_almost_equal(output.asnumpy(), expected) + # min + output = scatter_func_net("min", inputx, indices, updates) + expected = np.zeros((2, 3, 4)).astype(np.float16) + np.testing.assert_array_almost_equal(output.asnumpy(), expected) + @pytest.mark.level0 @pytest.mark.platform_x86_gpu_training @@ -541,6 +579,11 @@ def test_scatter_func_disordered_float16(): ) np.testing.assert_array_almost_equal(output.asnumpy(), expected) + # min + output = scatter_func_net("min", inputx, indices, updates) + expected = np.flip(np.arange(34, 46).reshape(3, 4).astype(np.float16)) + np.testing.assert_array_almost_equal(output.asnumpy(), expected) + @pytest.mark.level0 @pytest.mark.platform_x86_gpu_training @@ -590,6 +633,11 @@ def test_scatter_func_large_int32(): ]) np.testing.assert_array_almost_equal(output.asnumpy(), expected) + # min + output = scatter_func_net("min", inputx, indices, updates) + expected = np.zeros((2, 3, 4)).astype(np.int32) + np.testing.assert_array_almost_equal(output.asnumpy(), expected) + @pytest.mark.level0 @pytest.mark.platform_x86_gpu_training @@ -635,6 +683,11 @@ def test_scatter_func_disordered_int32(): ) np.testing.assert_array_almost_equal(output.asnumpy(), expected) + # min + output = scatter_func_net("min", inputx, indices, updates) + expected = np.flip(np.arange(34, 46).reshape(3, 4).astype(np.int32)) + np.testing.assert_array_almost_equal(output.asnumpy(), expected) + @pytest.mark.level0 @pytest.mark.platform_x86_gpu_training diff --git a/tests/ut/python/optimizer/test_bprop_mindir.py b/tests/ut/python/optimizer/test_bprop_mindir.py index 640d7e8918f..197afd9311d 100644 --- a/tests/ut/python/optimizer/test_bprop_mindir.py +++ b/tests/ut/python/optimizer/test_bprop_mindir.py @@ -385,6 +385,30 @@ def test_scatter_max(): grad.compile(indices, updates) +def test_scatter_min(): + """ + Feature: Bprop pre-compilation. + Description: Compile the backward graph for the scatter_min op. + Expectation: Load the bprop mindir successfully. + """ + + class ScatterMinNet(nn.Cell): + def __init__(self): + super(ScatterMinNet, self).__init__() + self.scatter_min = P.ScatterMin() + self.input_x = Parameter(Tensor(np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), mstype.float32), + name="input_x") + + def construct(self, indices, updates): + return self.scatter_min(self.input_x, indices, updates) + + indices = Tensor(np.array([[0, 0], [1, 1]]), mstype.int32) + updates = Tensor(np.ones([2, 2, 3]) * 88, mstype.float32) + scatter_min = ScatterMinNet() + grad = GradNet(scatter_min) + grad.compile(indices, updates) + + def test_relu_grad(): """ Feature: Bprop pre-compilation.