Add GPU support for the ScatterMin operator, and add a functional interface
parent fcabbe95b3
commit b8ec35c3b7
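With this commit, ScatterMin runs on GPU through the same ScatterFunctorKernelMod as the other scatter functors, and a functional front end ops.scatter_min is exposed. A minimal usage sketch, assuming a MindSpore build that includes this commit (it mirrors the docstring example added below):

import numpy as np
import mindspore
from mindspore import Tensor, Parameter, ops

# scatter_min takes the element-wise min of the rows of input_x selected by indices and updates.
input_x = Parameter(Tensor(np.zeros((2, 3)), mindspore.float32), name="input_x")
indices = Tensor(np.array([1, 0]), mindspore.int32)
update = Tensor(np.arange(6).reshape((2, 3)), mindspore.float32)
output = ops.scatter_min(input_x, indices, update)
print(output)
# [[0. 0. 0.]
#  [0. 0. 0.]]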
@@ -305,5 +305,77 @@ MS_REG_GPU_KERNEL_TWO(ScatterMax,
                         .AddInputAttr(kNumberTypeUInt8)
                         .AddOutputAttr(kNumberTypeUInt8),
                       ScatterFunctorKernelMod, uint8_t, int64_t)
+
+// ScatterMin
+MS_REG_GPU_KERNEL_TWO(ScatterMin,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeInt32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddOutputAttr(kNumberTypeFloat32),
+                      ScatterFunctorKernelMod, float, int)
+MS_REG_GPU_KERNEL_TWO(ScatterMin,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeInt64)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddOutputAttr(kNumberTypeFloat32),
+                      ScatterFunctorKernelMod, float, int64_t)
+MS_REG_GPU_KERNEL_TWO(ScatterMin,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddInputAttr(kNumberTypeInt32)
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddOutputAttr(kNumberTypeFloat16),
+                      ScatterFunctorKernelMod, half, int)
+MS_REG_GPU_KERNEL_TWO(ScatterMin,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddInputAttr(kNumberTypeInt64)
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddOutputAttr(kNumberTypeFloat16),
+                      ScatterFunctorKernelMod, half, int64_t)
+MS_REG_GPU_KERNEL_TWO(ScatterMin,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeInt32)
+                        .AddInputAttr(kNumberTypeInt32)
+                        .AddInputAttr(kNumberTypeInt32)
+                        .AddOutputAttr(kNumberTypeInt32),
+                      ScatterFunctorKernelMod, int, int)
+MS_REG_GPU_KERNEL_TWO(ScatterMin,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeInt32)
+                        .AddInputAttr(kNumberTypeInt64)
+                        .AddInputAttr(kNumberTypeInt32)
+                        .AddOutputAttr(kNumberTypeInt32),
+                      ScatterFunctorKernelMod, int, int64_t)
+MS_REG_GPU_KERNEL_TWO(ScatterMin,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeInt8)
+                        .AddInputAttr(kNumberTypeInt32)
+                        .AddInputAttr(kNumberTypeInt8)
+                        .AddOutputAttr(kNumberTypeInt8),
+                      ScatterFunctorKernelMod, int8_t, int)
+MS_REG_GPU_KERNEL_TWO(ScatterMin,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeInt8)
+                        .AddInputAttr(kNumberTypeInt64)
+                        .AddInputAttr(kNumberTypeInt8)
+                        .AddOutputAttr(kNumberTypeInt8),
+                      ScatterFunctorKernelMod, int8_t, int64_t)
+MS_REG_GPU_KERNEL_TWO(ScatterMin,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeUInt8)
+                        .AddInputAttr(kNumberTypeInt32)
+                        .AddInputAttr(kNumberTypeUInt8)
+                        .AddOutputAttr(kNumberTypeUInt8),
+                      ScatterFunctorKernelMod, uint8_t, int)
+MS_REG_GPU_KERNEL_TWO(ScatterMin,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeUInt8)
+                        .AddInputAttr(kNumberTypeInt64)
+                        .AddInputAttr(kNumberTypeUInt8)
+                        .AddOutputAttr(kNumberTypeUInt8),
+                      ScatterFunctorKernelMod, uint8_t, int64_t)
 }  // namespace kernel
 }  // namespace mindspore
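Each MS_REG_GPU_KERNEL_TWO call above binds one (value dtype, index dtype) pair to ScatterFunctorKernelMod. As a quick reference, the pairs registered for ScatterMin are summarized below (an illustrative Python list, not code from the commit; the variable name is hypothetical):

# (value dtype, index dtype) pairs registered for the GPU ScatterMin kernel
scatter_min_gpu_dtypes = [
    ("float32", "int32"), ("float32", "int64"),
    ("float16", "int32"), ("float16", "int64"),
    ("int32", "int32"), ("int32", "int64"),
    ("int8", "int32"), ("int8", "int64"),
    ("uint8", "int32"), ("uint8", "int64"),
]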
@@ -28,10 +28,8 @@ namespace mindspore {
 namespace kernel {

 static const std::map<std::string, ScatterFunctorType> kScatterFunctorTypeMap = {
-  {"ScatterUpdate", SCATTER_FUNC_UPDATE},
-  {"ScatterAdd", SCATTER_FUNC_ADD},
-  {"ScatterSub", SCATTER_FUNC_SUB},
-  {"ScatterMax", SCATTER_FUNC_MAX},
+  {"ScatterUpdate", SCATTER_FUNC_UPDATE}, {"ScatterAdd", SCATTER_FUNC_ADD}, {"ScatterSub", SCATTER_FUNC_SUB},
+  {"ScatterMax", SCATTER_FUNC_MAX}, {"ScatterMin", SCATTER_FUNC_MIN},
 };

 template <typename T, typename S>
@@ -61,7 +59,7 @@ class ScatterFunctorKernelMod : public DeprecatedNativeGpuKernelMod {
     auto iter = kScatterFunctorTypeMap.find(kernel_name);
     if (iter == kScatterFunctorTypeMap.end()) {
       MS_LOG(EXCEPTION) << "For '" << kernel_name << "Only support these scatter functors: ScatterUpdate, ScatterAdd, "
-                        << "ScatterSub or ScatterMax currently, but got " << kernel_name;
+                        << "ScatterSub, ScatterMax or ScatterMin currently, but got " << kernel_name;
     } else {
       scatter_functor_type_ = iter->second;
     }
@@ -61,6 +61,17 @@ __global__ void ScatterMaxKernel(const size_t inner_size, const size_t updates_size,
 }
 }

+template <typename T, typename S>
+__global__ void ScatterMinKernel(const size_t inner_size, const size_t updates_size, const S *indices, const T *updates,
+                                 T *input) {
+  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < updates_size; pos += blockDim.x * gridDim.x) {
+    const size_t index = pos / inner_size;
+    const size_t offset = pos % inner_size;
+    const size_t current_pos = indices[index] * inner_size + offset;
+    input[current_pos] = updates[pos] < input[current_pos] ? updates[pos] : input[current_pos];
+  }
+}
+
 template <typename T, typename S>
 void ScatterFunc(enum ScatterFunctorType func_type, const size_t &inner_size, const size_t &indices_size,
                  const S *indices, const T *updates, T *input, cudaStream_t cuda_stream) {
@@ -78,6 +89,9 @@ void ScatterFunc(enum ScatterFunctorType func_type, const size_t &inner_size, const size_t &indices_size,
     case SCATTER_FUNC_MAX:
       return ScatterMaxKernel<<<GET_BLOCKS(updates_size), GET_THREADS, 0, cuda_stream>>>(inner_size, updates_size,
                                                                                          indices, updates, input);
+    case SCATTER_FUNC_MIN:
+      return ScatterMinKernel<<<GET_BLOCKS(updates_size), GET_THREADS, 0, cuda_stream>>>(inner_size, updates_size,
+                                                                                         indices, updates, input);
     default:
       break;
   }
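The kernel launches one thread per element of updates: pos / inner_size selects the indices entry and pos % inner_size the offset inside the target slice, so current_pos = indices[index] * inner_size + offset addresses the flattened input. A NumPy sketch of the same update rule (illustrative only; the function and variable names are hypothetical, not part of the commit):

import numpy as np

def scatter_min_ref(input_x, indices, updates):
    """NumPy reference for the ScatterMinKernel update rule above."""
    inner_size = input_x[0].size           # elements per slice along axis 0
    flat_in = input_x.reshape(-1)          # flat view, updated in place
    flat_up = updates.reshape(-1)
    flat_idx = indices.reshape(-1)
    for pos in range(flat_up.size):        # one GPU thread per update element
        index = pos // inner_size          # which indices entry this element uses
        offset = pos % inner_size          # position inside the target slice
        current_pos = flat_idx[index] * inner_size + offset
        flat_in[current_pos] = min(flat_up[pos], flat_in[current_pos])
    return input_x

x = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], np.float32)
print(scatter_min_ref(x, np.array([1, 0]), np.zeros((2, 3), np.float32)))
# [[0. 0. 0.]
#  [0. 0. 0.]]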
@@ -23,6 +23,7 @@ enum ScatterFunctorType {
   SCATTER_FUNC_ADD,
   SCATTER_FUNC_SUB,
   SCATTER_FUNC_MAX,
+  SCATTER_FUNC_MIN,
   SCATTER_FUNC_INVALID_TYPE = 255
 };

@@ -859,6 +859,17 @@ def get_bprop_scatter_max(self):
     return bprop


+@bprop_getters.register(P.ScatterMin)
+def get_bprop_scatter_min(self):
+    """Generate bprop for ScatterMin"""
+    gather = P.Gather()
+
+    def bprop(x, indices, update, out, dout):
+        return dout, zeros_like(indices), gather(dout, indices, 0)
+
+    return bprop
+
+
 @bprop_getters.register(P.Argmax)
 def get_bprop_argmax(self):
     """Generate bprop for Argmax"""
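The rule mirrors get_bprop_scatter_max above: dout flows through to x unchanged, indices get a zero gradient, and the updates gradient gathers dout along axis 0 at indices. A small NumPy shape check (illustrative only; not part of the commit):

import numpy as np

dout = np.arange(6.0).reshape(2, 3)   # gradient w.r.t. the (2, 3) output
indices = np.array([[0, 0], [1, 1]])  # same indices as the bprop test below
d_update = dout[indices]              # NumPy analogue of gather(dout, indices, 0)
print(d_update.shape)                 # (2, 2, 3) == indices.shape + x.shape[1:]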
@@ -9,6 +9,6 @@ y
 bprop.12:x*
 bprop.12:out*
 bprop.12:dout2
-bprop.12:[CNode]14:3:@056bff1e3c57347dea806dc7c1b2b798b13c876c828efa18a94acc87833ba12dPb&
+bprop.12:[CNode]14:3:@95457f6f5c75edb42385b9b7906aa1c6ab0f9d7116b7f1fe413ac74a65097938Pb&
 S-Prim-MakeTuple:4S-Prim-MakeTuplebH
 #S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]h
@@ -9,6 +9,6 @@ z
 bprop.15:x*
 bprop.15:out*
 bprop.15:dout2
-bprop.15:[CNode]17:3:@056bff1e3c57347dea806dc7c1b2b798b13c876c828efa18a94acc87833ba12dPb&
+bprop.15:[CNode]17:3:@95457f6f5c75edb42385b9b7906aa1c6ab0f9d7116b7f1fe413ac74a65097938Pb&
 S-Prim-MakeTuple:4S-Prim-MakeTuplebH
 #S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]h
@@ -9,6 +9,6 @@ z
 bprop.18:x*
 bprop.18:out*
 bprop.18:dout2
-bprop.18:[CNode]20:3:@056bff1e3c57347dea806dc7c1b2b798b13c876c828efa18a94acc87833ba12dPbH
-#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]b&
-S-Prim-MakeTuple:4S-Prim-MakeTupleh
+bprop.18:[CNode]20:3:@95457f6f5c75edb42385b9b7906aa1c6ab0f9d7116b7f1fe413ac74a65097938Pb&
+S-Prim-MakeTuple:4S-Prim-MakeTuplebH
+#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]h
@@ -9,6 +9,6 @@ z
 bprop.24:x*
 bprop.24:out*
 bprop.24:dout2
-bprop.24:[CNode]26:3:@056bff1e3c57347dea806dc7c1b2b798b13c876c828efa18a94acc87833ba12dPbH
+bprop.24:[CNode]26:3:@95457f6f5c75edb42385b9b7906aa1c6ab0f9d7116b7f1fe413ac74a65097938PbH
 #S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]b&
 S-Prim-MakeTuple:4S-Prim-MakeTupleh
@@ -5,5 +5,5 @@ m
 bprop.1:x*
 bprop.1:out*
 bprop.1:dout2
-bprop.1:[CNode]2:1:@056bff1e3c57347dea806dc7c1b2b798b13c876c828efa18a94acc87833ba12dPb&
+bprop.1:[CNode]2:1:@95457f6f5c75edb42385b9b7906aa1c6ab0f9d7116b7f1fe413ac74a65097938Pb&
 S-Prim-MakeTuple:2S-Prim-MakeTupleh
@@ -7,6 +7,6 @@ s
 bprop.6:x*
 bprop.6:out*
 bprop.6:dout2
-bprop.6:[CNode]8:3:@056bff1e3c57347dea806dc7c1b2b798b13c876c828efa18a94acc87833ba12dPb&
-S-Prim-MakeTuple:4S-Prim-MakeTuplebH
-#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]h
+bprop.6:[CNode]8:3:@95457f6f5c75edb42385b9b7906aa1c6ab0f9d7116b7f1fe413ac74a65097938PbH
+#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]b&
+S-Prim-MakeTuple:4S-Prim-MakeTupleh
@@ -7,6 +7,6 @@ s
 bprop.3:x*
 bprop.3:out*
 bprop.3:dout2
-bprop.3:[CNode]5:3:@056bff1e3c57347dea806dc7c1b2b798b13c876c828efa18a94acc87833ba12dPb&
-S-Prim-MakeTuple:4S-Prim-MakeTuplebH
-#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]h
+bprop.3:[CNode]5:3:@95457f6f5c75edb42385b9b7906aa1c6ab0f9d7116b7f1fe413ac74a65097938PbH
+#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]b&
+S-Prim-MakeTuple:4S-Prim-MakeTupleh
@@ -9,6 +9,6 @@ z
 bprop.27:x*
 bprop.27:out*
 bprop.27:dout2
-bprop.27:[CNode]29:3:@056bff1e3c57347dea806dc7c1b2b798b13c876c828efa18a94acc87833ba12dPbH
-#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]b&
-S-Prim-MakeTuple:4S-Prim-MakeTupleh
+bprop.27:[CNode]29:3:@95457f6f5c75edb42385b9b7906aa1c6ab0f9d7116b7f1fe413ac74a65097938Pb&
+S-Prim-MakeTuple:4S-Prim-MakeTuplebH
+#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]h
Binary file not shown.
Binary file not shown.
@@ -27,9 +27,9 @@ bprop.30:x*
 bprop.30:y*
 bprop.30:out*
 bprop.30:dout2
-bprop.30:[CNode]36:8:@056bff1e3c57347dea806dc7c1b2b798b13c876c828efa18a94acc87833ba12dPbH
-#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]b&
-S-Prim-MakeTuple:9S-Prim-MakeTuplebv
+bprop.30:[CNode]36:8:@95457f6f5c75edb42385b9b7906aa1c6ab0f9d7116b7f1fe413ac74a65097938Pbv
 S-Prim-Select:5
S-Prim-Select
 output_names€ŠZoutput€3
-input_names€ŠZ condition€ŠZx€ŠZy€h
+input_names€ŠZ condition€ŠZx€ŠZy€bH
+#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]b&
+S-Prim-MakeTuple:9S-Prim-MakeTupleh
@@ -9,6 +9,6 @@ z
 bprop.21:x*
 bprop.21:out*
 bprop.21:dout2
-bprop.21:[CNode]23:3:@056bff1e3c57347dea806dc7c1b2b798b13c876c828efa18a94acc87833ba12dPb&
+bprop.21:[CNode]23:3:@95457f6f5c75edb42385b9b7906aa1c6ab0f9d7116b7f1fe413ac74a65097938Pb&
 S-Prim-MakeTuple:4S-Prim-MakeTuplebH
 #S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]h
@@ -7,6 +7,6 @@ v
 bprop.9:x*
 bprop.9:out*
 bprop.9:dout2
-bprop.9:[CNode]11:3:@056bff1e3c57347dea806dc7c1b2b798b13c876c828efa18a94acc87833ba12dPbH
-#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]b&
-S-Prim-MakeTuple:4S-Prim-MakeTupleh
+bprop.9:[CNode]11:3:@95457f6f5c75edb42385b9b7906aa1c6ab0f9d7116b7f1fe413ac74a65097938Pb&
+S-Prim-MakeTuple:4S-Prim-MakeTuplebH
+#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]h
@@ -29,9 +29,9 @@ serializable_bprop_ops = [P.ReLU, P.Identity, inner.Range, P.OnesLike, P.ZerosLike,
                           P.LinSpace, P.DropoutGenMask, P.OneHot, P.Assign, P.IOU, P.BNTrainingReduce, P.Equal,
                           P.NotEqual, P.Greater, P.GreaterEqual, P.Less, P.LessEqual, P.LogicalAnd, P.LogicalOr,
                           P.ReduceAll, P.ReduceAny, P.DropoutDoMask, P.DType, P.Shape, P.DynamicShape, P.Rank,
-                          P.Select, P.ScatterMax, G.ReluGrad, _constants.kTupleGetItem, P.FloorDiv, P.TruncateDiv,
-                          P.Minimum, P.Maximum, P.IsNan, P.IsInf, P.ReLUV2, "Depend", "stop_gradient", "Switch",
-                          "UpdateState", "Load"]
+                          P.Select, P.ScatterMax, P.ScatterMin, G.ReluGrad, _constants.kTupleGetItem, P.FloorDiv,
+                          P.TruncateDiv, P.Minimum, P.Maximum, P.IsNan, P.IsInf, P.ReLUV2, "Depend", "stop_gradient",
+                          "Switch", "UpdateState", "Load"]


 def run_generate():
@@ -22,7 +22,8 @@ A collection of function to build neural networks or to compute functions.
 from . import array_func, parameter_func, math_func
 from .array_func import (unique, eye, fill, fill_, tile, size, ones, ones_like, shape, shape_, dyn_shape, rank,
                          reshape, reshape_, tensor_slice, slice, scalar_to_array, scalar_to_tensor, tuple_to_array,
-                         expand_dims, transpose, scatter_nd, gather, gather_d, gather_nd, scalar_cast, masked_fill)
+                         expand_dims, transpose, scatter_nd, gather, gather_d, gather_nd, scalar_cast, masked_fill,
+                         scatter_min)
 from .parameter_func import assign, assign_add, assign_sub, index_add
 from .math_func import (addn, absolute, abs, tensor_add, add, neg_tensor, neg, tensor_lt, less, tensor_le, le,
                         tensor_gt, gt, tensor_ge, ge, tensor_sub, sub, tensor_mul, mul, tensor_div, div,
@@ -587,6 +587,49 @@ def transpose(input_x, input_perm):
     return transpose_(input_x, input_perm)


+def scatter_min(input_x, indices, updates, use_locking=False):
+    r"""
+    Using given values to update tensor value through the min operation, along with the input indices.
+    This operation outputs `input_x` after the update is done, which makes it convenient to use the updated value.
+
+    Args:
+        - **input_x** (Parameter) - The target tensor, with data type of Parameter.
+          The shape is :math:`(N, *)` where :math:`*` means, any number of additional dimensions.
+        - **indices** (Tensor) - The index to do min operation whose data type must be mindspore.int32.
+        - **updates** (Tensor) - The tensor doing the min operation with `input_x`,
+          the data type is the same as `input_x`, the shape is `indices.shape + input_x.shape[1:]`.
+        - **use_locking** (bool) - Whether to protect the assignment by a lock. Default: False.
+
+    Outputs:
+        Tensor, the updated `input_x`, has the same shape and type as `input_x`.
+
+    Raises:
+        TypeError: If `indices` is not an int32.
+        TypeError: If `use_locking` is not a bool.
+        RuntimeError: If the data type of `input_x` and `updates` conversion of Parameter
+            is required when data type conversion of Parameter is not supported.
+        ValueError: If the shape of `updates` is not equal to `indices.shape + input_x.shape[1:]`.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU`` ``CPU``
+
+    Examples:
+        >>> import numpy as np
+        >>> import mindspore
+        >>> from mindspore import Tensor, Parameter
+        >>> from mindspore import ops
+        >>> input_x = Parameter(Tensor(np.zeros((2, 3)), mindspore.float32), name="input_x")
+        >>> indices = Tensor(np.array([1, 0]), mindspore.int32)
+        >>> update = Tensor(np.arange(6).reshape((2, 3)), mindspore.float32)
+        >>> output = ops.scatter_min(input_x, indices, update)
+        >>> print(output)
+        [[0. 0. 0.]
+         [0. 0. 0.]]
+    """
+    return P.ScatterMin(use_locking)(input_x, indices, updates)
+
+
 scatter_nd_ = P.ScatterNd()
 def scatter_nd(indices, updates, shape):
     r"""
@@ -1040,6 +1083,7 @@ __all__ = [
     'gather',
     'gather_d',
     'gather_nd',
-    'masked_fill'
+    'masked_fill',
+    'scatter_min'
 ]
 __all__.sort()
@@ -980,6 +980,7 @@ tensor_operator_registry.register('logical_not', P.LogicalNot)
 tensor_operator_registry.register('sum', P.ReduceSum)
 tensor_operator_registry.register('split', P.Split)
 tensor_operator_registry.register('index_add', P.IndexAdd)
+tensor_operator_registry.register('scatter_min', P.ScatterMin)
 # ms cannot support Tensor(True) compare
 tensor_operator_registry.register('__eq__', equal)
 tensor_operator_registry.register('__ne__', not_equal)
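Entries in tensor_operator_registry back the corresponding Tensor methods. Assuming the Tensor.scatter_min method hook is wired up in this version (the tensor.py side is not shown in this diff, so this is an assumption), the registration above would enable something like:

import numpy as np
import mindspore
from mindspore import Tensor, Parameter

x = Parameter(Tensor(np.ones((2, 3)), mindspore.float32), name="x")
indices = Tensor(np.array([0, 1]), mindspore.int32)
updates = Tensor(np.zeros((2, 3)), mindspore.float32)
# Hypothetical: resolved through tensor_operator_registry['scatter_min'] -> P.ScatterMin.
out = x.scatter_min(indices, updates)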
@@ -4301,7 +4301,7 @@ class ScatterMin(_ScatterOp):
         is required when data type conversion of Parameter is not supported.

     Supported Platforms:
-        ``Ascend`` ``CPU``
+        ``Ascend`` ``GPU`` ``CPU``

     Examples:
         >>> input_x = Parameter(Tensor(np.array([[0.0, 1.0, 2.0], [0.0, 0.0, 0.0]]), mindspore.float32),
@@ -29,6 +29,7 @@ func_map = {
     "add": P.ScatterAdd,
     "sub": P.ScatterSub,
     "max": P.ScatterMax,
+    "min": P.ScatterMin,
 }


@@ -130,6 +131,11 @@ def test_scatter_func_small_float32():
     expected = np.array([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]])
     np.testing.assert_array_almost_equal(output.asnumpy(), expected)

+    # min
+    output = scatter_func_net("min", inputx, indices, updates)
+    expected = np.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
+    np.testing.assert_array_almost_equal(output.asnumpy(), expected)
+

 @pytest.mark.level0
 @pytest.mark.platform_x86_gpu_training
@@ -164,6 +170,12 @@ def test_scatter_func_input_updated():
     expected = np.array([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]])
     np.testing.assert_array_almost_equal(net.inputx.asnumpy(), expected)

+    # min
+    net = TestScatterFuncNet("min", lock, inputx, indices, updates)
+    net()
+    expected = np.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
+    np.testing.assert_array_almost_equal(net.inputx.asnumpy(), expected)
+

 @pytest.mark.level0
 @pytest.mark.platform_x86_gpu_training
@@ -329,6 +341,12 @@ def test_scatter_func_large_shape_float32():
     )
     np.testing.assert_array_almost_equal(output.asnumpy(), expected)

+    # min
+    output = scatter_func_net("min", inputx, indices, updates)
+    expected = np.ones((4, 2, 3, 4)).astype(np.float32)
+    expected[0][0][0][0] = 0.0
+    np.testing.assert_array_almost_equal(output.asnumpy(), expected)
+

 @pytest.mark.level0
 @pytest.mark.platform_x86_gpu_training
@@ -358,6 +376,11 @@ def test_scatter_func_small_float32_use_locking_false():
     expected = np.array([[3.0, 4.0, 5.0], [0.0, 1.0, 2.0]])
     np.testing.assert_array_almost_equal(output.asnumpy(), expected)

+    # min
+    output = scatter_func_use_locking_false_net("min", inputx, indices, updates)
+    expected = np.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
+    np.testing.assert_array_almost_equal(output.asnumpy(), expected)
+

 @pytest.mark.level0
 @pytest.mark.platform_x86_gpu_training
@@ -418,6 +441,11 @@ def test_scatter_func_input_less_than_1_float32():
     )
     np.testing.assert_array_almost_equal(output.asnumpy(), expected)

+    # min
+    output = scatter_func_net("min", inputx, indices, updates)
+    expected = inputx.asnumpy()
+    np.testing.assert_array_almost_equal(output.asnumpy(), expected)
+

 @pytest.mark.level0
 @pytest.mark.platform_x86_gpu_training
@@ -447,6 +475,11 @@ def test_scatter_func_float16():
     expected = np.array([[6.0, 1.0, 2.0], [3.0, 4.0, 5.0]])
     np.testing.assert_array_almost_equal(output.asnumpy(), expected)

+    # min
+    output = scatter_func_net("min", inputx, indices, updates)
+    expected = np.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
+    np.testing.assert_array_almost_equal(output.asnumpy(), expected)
+

 @pytest.mark.level0
 @pytest.mark.platform_x86_gpu_training
@@ -496,6 +529,11 @@ def test_scatter_func_large_float16():
     ])
     np.testing.assert_array_almost_equal(output.asnumpy(), expected)

+    # min
+    output = scatter_func_net("min", inputx, indices, updates)
+    expected = np.zeros((2, 3, 4)).astype(np.float16)
+    np.testing.assert_array_almost_equal(output.asnumpy(), expected)
+

 @pytest.mark.level0
 @pytest.mark.platform_x86_gpu_training
@@ -541,6 +579,11 @@ def test_scatter_func_disordered_float16():
     )
     np.testing.assert_array_almost_equal(output.asnumpy(), expected)

+    # min
+    output = scatter_func_net("min", inputx, indices, updates)
+    expected = np.flip(np.arange(34, 46).reshape(3, 4).astype(np.float16))
+    np.testing.assert_array_almost_equal(output.asnumpy(), expected)
+

 @pytest.mark.level0
 @pytest.mark.platform_x86_gpu_training
@@ -590,6 +633,11 @@ def test_scatter_func_large_int32():
     ])
     np.testing.assert_array_almost_equal(output.asnumpy(), expected)

+    # min
+    output = scatter_func_net("min", inputx, indices, updates)
+    expected = np.zeros((2, 3, 4)).astype(np.int32)
+    np.testing.assert_array_almost_equal(output.asnumpy(), expected)
+

 @pytest.mark.level0
 @pytest.mark.platform_x86_gpu_training
@@ -635,6 +683,11 @@ def test_scatter_func_disordered_int32():
     )
     np.testing.assert_array_almost_equal(output.asnumpy(), expected)

+    # min
+    output = scatter_func_net("min", inputx, indices, updates)
+    expected = np.flip(np.arange(34, 46).reshape(3, 4).astype(np.int32))
+    np.testing.assert_array_almost_equal(output.asnumpy(), expected)
+

 @pytest.mark.level0
 @pytest.mark.platform_x86_gpu_training
@@ -385,6 +385,30 @@ def test_scatter_max():
     grad.compile(indices, updates)


+def test_scatter_min():
+    """
+    Feature: Bprop pre-compilation.
+    Description: Compile the backward graph for the scatter_min op.
+    Expectation: Load the bprop mindir successfully.
+    """
+
+    class ScatterMinNet(nn.Cell):
+        def __init__(self):
+            super(ScatterMinNet, self).__init__()
+            self.scatter_min = P.ScatterMin()
+            self.input_x = Parameter(Tensor(np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), mstype.float32),
+                                     name="input_x")
+
+        def construct(self, indices, updates):
+            return self.scatter_min(self.input_x, indices, updates)
+
+    indices = Tensor(np.array([[0, 0], [1, 1]]), mstype.int32)
+    updates = Tensor(np.ones([2, 2, 3]) * 88, mstype.float32)
+    scatter_min = ScatterMinNet()
+    grad = GradNet(scatter_min)
+    grad.compile(indices, updates)
+
+
 def test_relu_grad():
     """
     Feature: Bprop pre-compilation.