forked from mindspore-Ecosystem/mindspore
!34131 scatter_min: add dynamic_shape case for ascend
Merge pull request !34131 from hujiahui8/scatter_min
This commit is contained in:
commit
0968b28871
|
@ -22,7 +22,7 @@ mindspore.ops.ScatterAdd
|
|||
**输入:**
|
||||
|
||||
- **input_x** (Parameter) - ScatterAdd的输入,任意维度的Parameter。
|
||||
- **indices** (Tensor) - 指定相加操作的索引,数据类型为mindspore.int32。
|
||||
- **indices** (Tensor) - 指定相加操作的索引,数据类型为mindspore.int32或者mindspore.int64。
|
||||
- **updates** (Tensor) - 指定与 `input_x` 相加操作的Tensor,数据类型与 `input_x` 相同,shape为 `indices.shape + x.shape[1:]` 。
|
||||
|
||||
**输出:**
|
||||
|
@ -32,6 +32,6 @@ mindspore.ops.ScatterAdd
|
|||
**异常:**
|
||||
|
||||
- **TypeError** - `use_locking` 不是bool。
|
||||
- **TypeError** - `indices` 不是int32。
|
||||
- **TypeError** - `indices` 不是int32或者int64。
|
||||
- **ValueError** - `updates` 的shape不等于 `indices.shape + x.shape[1:]` 。
|
||||
- **RuntimeError** - 当 `input_x` 和 `updates` 类型不一致,需要进行类型转换时,如果 `updates` 不支持转成参数 `input_x` 需要的数据类型,就会报错。
|
||||
|
|
|
@ -20,7 +20,7 @@ mindspore.ops.ScatterMax
|
|||
**输入:**
|
||||
|
||||
- **input_x** (Parameter)- ScatterMax的输入,任意维度的Parameter。
|
||||
- **indices** (Tensor) - 指定最大值操作的索引,数据类型必须为mindspore.int32。
|
||||
- **indices** (Tensor) - 指定最大值操作的索引,数据类型必须为mindspore.int32或者mindspore.int64。
|
||||
- **updates** (Tensor) - 指定与 `input_x` 取最大值操作的Tensor,数据类型与 `input_x` 相同,shape为 `indices.shape + x.shape[1:]` 。
|
||||
|
||||
|
||||
|
@ -31,6 +31,6 @@ mindspore.ops.ScatterMax
|
|||
**异常:**
|
||||
|
||||
- **TypeError** - `use_locking` 不是bool。
|
||||
- **TypeError** - `indices` 不是int32。
|
||||
- **TypeError** - `indices` 不是int32或者int64。
|
||||
- **ValueError** - `updates` 的shape不等于 `indices.shape + x.shape[1:]` 。
|
||||
- **RuntimeError** - 当 `input_x` 和 `updates` 类型不一致,需要进行类型转换时,如果 `updates` 不支持转成参数 `input_x` 需要的数据类型,就会报错。
|
|
@ -11,7 +11,7 @@ mindspore.ops.ScatterMin
|
|||
\text{input_x}[\text{indices}[i, ..., j], :]
|
||||
= min(\text{input_x}[\text{indices}[i, ..., j], :], \text{updates}[i, ..., j, :])
|
||||
|
||||
输入的 `input_x` 和 `updates` 遵循隐式类型转换规则,以确保数据类型一致。如果数据类型不同,则低精度数据类型将转换为高精度的数据类型。当参数的数据类型需要转换时,则会抛出RuntimeError异常。
|
||||
输入的 `input_x` 和 `updates` 遵循隐式类型转换规则,以确保数据类型一致。如果数据类型不同,则低精度数据类型将转换为高精度的数据类型。当`updates` 不支持转成 `input_x` 需要的数据类型时,则会抛出RuntimeError异常。
|
||||
|
||||
**参数:**
|
||||
|
||||
|
@ -20,7 +20,7 @@ mindspore.ops.ScatterMin
|
|||
**输入:**
|
||||
|
||||
- **input_x** (Parameter) - ScatterMin的输入,任意维度的Parameter。
|
||||
- **indices** (Tensor) - 指定最小值操作的索引,数据类型必须为mindspore.int32。
|
||||
- **indices** (Tensor) - 指定最小值操作的索引,数据类型必须为mindspore.int32或者mindspore.int64。
|
||||
- **updates** (Tensor) - 指定与 `input_x` 取最小值操作的Tensor,数据类型与 `input_x` 相同,shape为 `indices.shape + x.shape[1:]` 。
|
||||
|
||||
**输出:**
|
||||
|
@ -30,6 +30,6 @@ mindspore.ops.ScatterMin
|
|||
**异常:**
|
||||
|
||||
- **TypeError** - `use_locking` 不是bool。
|
||||
- **TypeError** - `indices` 不是int32。
|
||||
- **TypeError** - `indices` 不是int32或者int64。
|
||||
- **ValueError** - `updates` 的shape不等于 `indices.shape + x.shape[1:]` 。
|
||||
- **RuntimeError** - 当 `input_x` 和 `updates` 类型不一致,需要进行类型转换时,如果 `updates` 不支持转成参数 `input_x` 需要的数据类型,就会报错。
|
|
@ -19,7 +19,7 @@
|
|||
**输入:**
|
||||
|
||||
- **input_x** (Parameter) - ScatterSub的输入,任意维度的Parameter。
|
||||
- **indices** (Tensor) - 指定相减操作的索引,其数据类型必须为mindspore.int32。
|
||||
- **indices** (Tensor) - 指定相减操作的索引,其数据类型必须为mindspore.int32或者mindspore.int64。
|
||||
- **updates** (Tensor) - 指定与 `input_x` 相减的Tensor,其数据类型与 `input_x` 的相同,shape为 `indices_shape + x_shape[1:]` 。
|
||||
|
||||
**输出:**
|
||||
|
@ -29,6 +29,6 @@
|
|||
**异常:**
|
||||
|
||||
- **TypeError** - `use_locking` 不是bool。
|
||||
- **TypeError** - `indices` 不是int32。
|
||||
- **TypeError** - `indices` 不是int32或者int64。
|
||||
- **ValueError** - `updates` 的shape不是 `indices_shape + x_shape[1:]` 。
|
||||
- **RuntimeError** - 当 `input_x` 和 `updates` 类型不一致,需要进行类型转换时,如果 `updates` 不支持转成参数 `input_x` 需要的数据类型,就会报错。
|
|
@ -19,7 +19,7 @@
|
|||
**输入:**
|
||||
|
||||
- **input_x** (Parameter) - ScatterUpdate的输入,任意维度的Parameter。
|
||||
- **indices** (Tensor) - 指定更新操作的索引。数据类型为int32。如果索引中存在重复项,则更新的顺序无法得知。
|
||||
- **indices** (Tensor) - 指定更新操作的索引。数据类型为int32或者int64。如果索引中存在重复项,则更新的顺序无法得知。
|
||||
- **updates** (Tensor) - 指定与 `input_x` 更新操作的Tensor,其数据类型与 `input_x` 相同,shape为 `indices.shape + input_x.shape[1:]` 。
|
||||
|
||||
**输出:**
|
||||
|
@ -29,5 +29,5 @@
|
|||
**异常:**
|
||||
|
||||
- **TypeError** - `use_locking` 不是bool。
|
||||
- **TypeError** - `indices` 不是int32。
|
||||
- **TypeError** - `indices` 不是int32或者int64。
|
||||
- **RuntimeError** - 当 `input_x` 和 `updates` 类型不一致,需要进行类型转换时,如果 `updates` 不支持转成参数 `input_x` 需要的数据类型,就会报错。
|
|
@ -16,7 +16,7 @@ mindspore.ops.scatter_max
|
|||
**参数:**
|
||||
|
||||
- **input_x** (Parameter)- ScatterMax的输入,任意维度的Parameter。
|
||||
- **indices** (Tensor) - 指定最大值操作的索引,数据类型必须为mindspore.int32。
|
||||
- **indices** (Tensor) - 指定最大值操作的索引,数据类型必须为mindspore.int32或者mindspore.int64。
|
||||
- **updates** (Tensor) - 指定与 `input_x` 取最大值操作的Tensor,数据类型与 `input_x` 相同,shape为 `indices.shape + x.shape[1:]` 。
|
||||
|
||||
**输出:**
|
||||
|
@ -26,6 +26,6 @@ mindspore.ops.scatter_max
|
|||
**异常:**
|
||||
|
||||
- **TypeError** - `use_locking` 不是bool。
|
||||
- **TypeError** - `indices` 不是int32。
|
||||
- **TypeError** - `indices` 不是int32或者int64。
|
||||
- **ValueError** - `updates` 的shape不等于 `indices.shape + x.shape[1:]` 。
|
||||
- **RuntimeError** - 当 `input_x` 和 `updates` 类型不一致,需要进行类型转换时,如果 `updates` 不支持转成参数 `input_x` 需要的数据类型,就会报错。
|
|
@ -16,7 +16,7 @@ mindspore.ops.scatter_min
|
|||
**输入:**
|
||||
|
||||
- **input_x** (Parameter) - scatter_min的输入,任意维度的Parameter。
|
||||
- **indices** (Tensor) - 指定最小值操作的索引,数据类型必须为mindspore.int32。
|
||||
- **indices** (Tensor) - 指定最小值操作的索引,数据类型必须为mindspore.int32或者mindspore.int64。
|
||||
- **updates** (Tensor) - 指定与 `input_x` 取最小值操作的Tensor,数据类型与 `input_x` 相同,shape为 `indices.shape + input_x.shape[1:]` 。
|
||||
|
||||
**输出:**
|
||||
|
@ -25,6 +25,6 @@ mindspore.ops.scatter_min
|
|||
|
||||
**异常:**
|
||||
|
||||
- **TypeError** - `indices` 不是int32。
|
||||
- **TypeError** - `indices` 不是int32或者int64。
|
||||
- **ValueError** - `updates` 的shape不等于 `indices.shape + input_x.shape[1:]` 。
|
||||
- **RuntimeError** - 当 `input_x` 和 `updates` 类型不一致,需要进行类型转换时,如果 `updates` 不支持转成参数 `input_x` 需要的数据类型,就会报错。
|
|
@ -307,6 +307,20 @@ MS_REG_GPU_KERNEL_TWO(ScatterMax,
|
|||
ScatterFunctorKernelMod, uint8_t, int64_t)
|
||||
|
||||
// ScatterMin
|
||||
MS_REG_GPU_KERNEL_TWO(ScatterMin,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddOutputAttr(kNumberTypeFloat64),
|
||||
ScatterFunctorKernelMod, double, int)
|
||||
MS_REG_GPU_KERNEL_TWO(ScatterMin,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddOutputAttr(kNumberTypeFloat64),
|
||||
ScatterFunctorKernelMod, double, int64_t)
|
||||
MS_REG_GPU_KERNEL_TWO(ScatterMin,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
|
@ -335,6 +349,20 @@ MS_REG_GPU_KERNEL_TWO(ScatterMin,
|
|||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat16),
|
||||
ScatterFunctorKernelMod, half, int64_t)
|
||||
MS_REG_GPU_KERNEL_TWO(ScatterMin,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeInt64),
|
||||
ScatterFunctorKernelMod, int64_t, int)
|
||||
MS_REG_GPU_KERNEL_TWO(ScatterMin,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeInt64),
|
||||
ScatterFunctorKernelMod, int64_t, int64_t)
|
||||
MS_REG_GPU_KERNEL_TWO(ScatterMin,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
|
@ -349,33 +377,5 @@ MS_REG_GPU_KERNEL_TWO(ScatterMin,
|
|||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeInt32),
|
||||
ScatterFunctorKernelMod, int, int64_t)
|
||||
MS_REG_GPU_KERNEL_TWO(ScatterMin,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt8)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt8)
|
||||
.AddOutputAttr(kNumberTypeInt8),
|
||||
ScatterFunctorKernelMod, int8_t, int)
|
||||
MS_REG_GPU_KERNEL_TWO(ScatterMin,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt8)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt8)
|
||||
.AddOutputAttr(kNumberTypeInt8),
|
||||
ScatterFunctorKernelMod, int8_t, int64_t)
|
||||
MS_REG_GPU_KERNEL_TWO(ScatterMin,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeUInt8)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeUInt8)
|
||||
.AddOutputAttr(kNumberTypeUInt8),
|
||||
ScatterFunctorKernelMod, uint8_t, int)
|
||||
MS_REG_GPU_KERNEL_TWO(ScatterMin,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeUInt8)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeUInt8)
|
||||
.AddOutputAttr(kNumberTypeUInt8),
|
||||
ScatterFunctorKernelMod, uint8_t, int64_t)
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -128,6 +128,14 @@ template CUDA_LIB_EXPORT void ScatterFunc<half, int64_t>(enum ScatterFunctorType
|
|||
const size_t &inner_size, const size_t &indices_size,
|
||||
const int64_t *indices, const half *updates, half *input,
|
||||
cudaStream_t cuda_stream);
|
||||
template CUDA_LIB_EXPORT void ScatterFunc<double, int>(enum ScatterFunctorType func_type, int size_limit,
|
||||
const size_t &inner_size, const size_t &indices_size,
|
||||
const int *indices, const double *updates, double *input,
|
||||
cudaStream_t cuda_stream);
|
||||
template CUDA_LIB_EXPORT void ScatterFunc<double, int64_t>(enum ScatterFunctorType func_type, int64_t size_limit,
|
||||
const size_t &inner_size, const size_t &indices_size,
|
||||
const int64_t *indices, const double *updates, double *input,
|
||||
cudaStream_t cuda_stream);
|
||||
template CUDA_LIB_EXPORT void ScatterFunc<int, int>(enum ScatterFunctorType func_type, int size_limit,
|
||||
const size_t &inner_size, const size_t &indices_size,
|
||||
const int *indices, const int *updates, int *input,
|
||||
|
@ -136,6 +144,14 @@ template CUDA_LIB_EXPORT void ScatterFunc<int, int64_t>(enum ScatterFunctorType
|
|||
const size_t &inner_size, const size_t &indices_size,
|
||||
const int64_t *indices, const int *updates, int *input,
|
||||
cudaStream_t cuda_stream);
|
||||
template CUDA_LIB_EXPORT void ScatterFunc<int64_t, int>(enum ScatterFunctorType func_type, int size_limit,
|
||||
const size_t &inner_size, const size_t &indices_size,
|
||||
const int *indices, const int64_t *updates, int64_t *input,
|
||||
cudaStream_t cuda_stream);
|
||||
template CUDA_LIB_EXPORT void ScatterFunc<int64_t, int64_t>(enum ScatterFunctorType func_type, int64_t size_limit,
|
||||
const size_t &inner_size, const size_t &indices_size,
|
||||
const int64_t *indices, const int64_t *updates,
|
||||
int64_t *input, cudaStream_t cuda_stream);
|
||||
template CUDA_LIB_EXPORT void ScatterFunc<unsigned char, int>(enum ScatterFunctorType func_type, int size_limit,
|
||||
const size_t &inner_size, const size_t &indices_size,
|
||||
const int *indices, const unsigned char *updates,
|
||||
|
|
|
@ -40,7 +40,7 @@ abstract::ShapePtr ScatterMinInferShape(const PrimitivePtr &primitive, const std
|
|||
}
|
||||
|
||||
if (indices_shape_ptr->IsDynamic() || updates_shape_ptr->IsDynamic()) {
|
||||
return input_args[kInputIndex0]->BuildShape()->cast<abstract::ShapePtr>();
|
||||
return input_x_shape_ptr->cast<abstract::ShapePtr>();
|
||||
}
|
||||
|
||||
std::vector<int64_t> input_x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_x_shape_ptr)[kShape];
|
||||
|
@ -66,12 +66,11 @@ TypePtr ScatterMinInferType(const PrimitivePtr &primitive, const std::vector<Abs
|
|||
auto indiecs_type_ptr = input_args[kInputIndex1]->BuildType();
|
||||
auto updates_type_ptr = input_args[kInputIndex2]->BuildType();
|
||||
auto prim_name = primitive->name();
|
||||
std::set<TypePtr> type_set = {kInt32};
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("indices type", indiecs_type_ptr, type_set, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("input_x type", input_x_type_ptr, common_valid_types_with_complex,
|
||||
prim_name);
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("updates type", updates_type_ptr, common_valid_types_with_complex,
|
||||
prim_name);
|
||||
const std::set<TypePtr> indices_types = {kInt32, kInt64};
|
||||
const std::set<TypePtr> valid_types = {kInt32, kInt64, kFloat16, kFloat32, kFloat64};
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("indices type", indiecs_type_ptr, indices_types, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("input_x type", input_x_type_ptr, valid_types, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("updates type", updates_type_ptr, valid_types, prim_name);
|
||||
|
||||
std::map<std::string, TypePtr> type_dict;
|
||||
type_dict.emplace("input_x", input_x_type_ptr);
|
||||
|
|
|
@ -23,14 +23,19 @@ scatter_min_op_info = TBERegOp("ScatterMin") \
|
|||
.compute_cost(10) \
|
||||
.kernel_name("scatter_min") \
|
||||
.partial_flag(True) \
|
||||
.dynamic_compile_static(True) \
|
||||
.dynamic_shape(True) \
|
||||
.attr("use_locking", "optional", "bool", "all") \
|
||||
.input(0, "var", False, "required", "all") \
|
||||
.input(1, "indices", False, "required", "all") \
|
||||
.input(2, "updates", False, "required", "all") \
|
||||
.output(0, "var", False, "required", "all") \
|
||||
.dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F16_Default, DataType.I64_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.I64_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.I32_Default, DataType.I32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
|
|
|
@ -769,7 +769,7 @@ def scatter_min(input_x, indices, updates):
|
|||
Using given values to update tensor value through the min operation, along with the input indices.
|
||||
This operation outputs the `input_x` after the update is done, which makes it convenient to use the updated value.
|
||||
|
||||
for each `i, ..., j` in `indices.shape`:
|
||||
for each :math:`i, ..., j` in `indices.shape`:
|
||||
|
||||
.. math::
|
||||
|
||||
|
@ -781,21 +781,20 @@ def scatter_min(input_x, indices, updates):
|
|||
the relatively highest priority data type.
|
||||
|
||||
Args:
|
||||
- **input_x** (Parameter) - The target tensor, with data type of Parameter.
|
||||
input_x (Parameter): The target tensor, with data type of Parameter.
|
||||
The shape is :math:`(N,*)` where :math:`*` means,any number of additional dimensions.
|
||||
- **indices** (Tensor) - The index to do min operation whose data type must be mindspore.int32.
|
||||
- **updates** (Tensor) - The tensor doing the min operation with `input_x`,
|
||||
the data type is same as `input_x`, the shape is `indices.shape + x.shape[1:]`.
|
||||
indices (Tensor): The index to do min operation whose data type must be mindspore.int32 or mindspore.int64.
|
||||
updates (Tensor): The tensor doing the min operation with `input_x`,
|
||||
the data type is same as `input_x`, the shape is `indices.shape + input_x.shape[1:]`.
|
||||
|
||||
Outputs:
|
||||
Tensor, the updated `input_x`, has the same shape and type as `input_x`.
|
||||
|
||||
Raises:
|
||||
TypeError: If `indices` is not an int32.
|
||||
TypeError: If `use_locking` is not a bool.
|
||||
TypeError: If `indices` is not an int32 or an int64.
|
||||
RuntimeError: If the data type of `input_x` and `updates` conversion of Parameter
|
||||
is required when data type conversion of Parameter is not supported.
|
||||
ValueError: If the shape of `updates` is not equal to `indices.shape + x.shape[1:]`.
|
||||
ValueError: If the shape of `updates` is not equal to `indices.shape + input_x.shape[1:]`.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
|
|
@ -109,7 +109,7 @@ class _ScatterOpDynamic(PrimitiveWithCheck):
|
|||
self._check_scatter_shape(x_shape, indices_shape, updates_shape, self.name)
|
||||
|
||||
def check_dtype(self, x_dtype, indices_dtype, updates_dtype):
|
||||
validator.check_tensor_dtype_valid('indices', indices_dtype, [mstype.int32], self.name)
|
||||
validator.check_tensor_dtype_valid('indices', indices_dtype, [mstype.int32, mstype.int64], self.name)
|
||||
args = {"x": x_dtype, "updates": updates_dtype}
|
||||
validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
|
||||
|
||||
|
@ -4160,7 +4160,7 @@ class ScatterUpdate(_ScatterOpDynamic):
|
|||
Inputs:
|
||||
- **input_x** (Parameter) - The target tensor, with data type of Parameter.
|
||||
The shape is :math:`(N,*)` where :math:`*` means,any number of additional dimensions.
|
||||
- **indices** (Tensor) - The index of input tensor. With int32 data type.
|
||||
- **indices** (Tensor) - The index of input tensor. With int32 or int64 data type.
|
||||
If there are duplicates in indices, the order for updating is undefined.
|
||||
- **updates** (Tensor) - The tensor to update the input tensor, has the same type as input,
|
||||
and updates.shape = indices.shape + input_x.shape[1:].
|
||||
|
@ -4170,7 +4170,7 @@ class ScatterUpdate(_ScatterOpDynamic):
|
|||
|
||||
Raises:
|
||||
TypeError: If `use_locking` is not a bool.
|
||||
TypeError: If `indices` is not an int32.
|
||||
TypeError: If `indices` is not an int32 or an int64.
|
||||
RuntimeError: If the data type of `input_x` and `updates` conversion of Parameter
|
||||
is required when data type conversion of Parameter is not supported.
|
||||
|
||||
|
@ -4289,7 +4289,8 @@ class ScatterMax(_ScatterOpDynamic):
|
|||
Inputs:
|
||||
- **input_x** (Parameter) - The target tensor, with data type of Parameter.
|
||||
The shape is :math:`(N,*)` where :math:`*` means,any number of additional dimensions.
|
||||
- **indices** (Tensor) - The index to do max operation whose data type must be mindspore.int32.
|
||||
- **indices** (Tensor) - The index to do max operation whose data type must be mindspore.int32 or
|
||||
mindspore.int64.
|
||||
- **updates** (Tensor) - The tensor that performs the maximum operation with `input_x`,
|
||||
the data type is the same as `input_x`, the shape is `indices.shape + x.shape[1:]`.
|
||||
|
||||
|
@ -4298,7 +4299,7 @@ class ScatterMax(_ScatterOpDynamic):
|
|||
|
||||
Raises:
|
||||
TypeError: If `use_locking` is not a bool.
|
||||
TypeError: If `indices` is not an int32.
|
||||
TypeError: If `indices` is not an int32 or an int64.
|
||||
ValueError: If the shape of `updates` is not equal to `indices.shape + x.shape[1:]`.
|
||||
RuntimeError: If the data type of `input_x` and `updates` conversion of Parameter
|
||||
is required when data type conversion of Parameter is not supported.
|
||||
|
@ -4326,7 +4327,7 @@ class ScatterMin(_ScatterOpDynamic):
|
|||
Using given values to update tensor value through the min operation, along with the input indices.
|
||||
This operation outputs the `input_x` after the update is done, which makes it convenient to use the updated value.
|
||||
|
||||
for each `i, ..., j` in `indices.shape`:
|
||||
for each :math:`i, ..., j` in `indices.shape`:
|
||||
|
||||
.. math::
|
||||
|
||||
|
@ -4343,17 +4344,18 @@ class ScatterMin(_ScatterOpDynamic):
|
|||
Inputs:
|
||||
- **input_x** (Parameter) - The target tensor, with data type of Parameter.
|
||||
The shape is :math:`(N,*)` where :math:`*` means,any number of additional dimensions.
|
||||
- **indices** (Tensor) - The index to do min operation whose data type must be mindspore.int32.
|
||||
- **indices** (Tensor) - The index to do min operation whose data type must be mindspore.int32 or
|
||||
mindspore.int64.
|
||||
- **updates** (Tensor) - The tensor doing the min operation with `input_x`,
|
||||
the data type is same as `input_x`, the shape is `indices.shape + x.shape[1:]`.
|
||||
the data type is same as `input_x`, the shape is `indices.shape + input_x.shape[1:]`.
|
||||
|
||||
Outputs:
|
||||
Tensor, the updated `input_x`, has the same shape and type as `input_x`.
|
||||
|
||||
Raises:
|
||||
TypeError: If `use_locking` is not a bool.
|
||||
TypeError: If `indices` is not an int32.
|
||||
ValueError: If the shape of `updates` is not equal to `indices.shape + x.shape[1:]`.
|
||||
TypeError: If `indices` is not an int32 or an int64.
|
||||
ValueError: If the shape of `updates` is not equal to `indices.shape + input_x.shape[1:]`.
|
||||
RuntimeError: If the data type of `input_x` and `updates` conversion of Parameter
|
||||
is required when data type conversion of Parameter is not supported.
|
||||
|
||||
|
@ -4401,7 +4403,8 @@ class ScatterAdd(_ScatterOpDynamic):
|
|||
Inputs:
|
||||
- **input_x** (Parameter) - The target tensor, with data type of Parameter.
|
||||
The shape is :math:`(N,*)` where :math:`*` means,any number of additional dimensions.
|
||||
- **indices** (Tensor) - The index to do min operation whose data type must be mindspore.int32.
|
||||
- **indices** (Tensor) - The index to do min operation whose data type must be mindspore.int32 or
|
||||
mindspore.int64.
|
||||
- **updates** (Tensor) - The tensor doing the min operation with `input_x`,
|
||||
the data type is same as `input_x`, the shape is `indices.shape + x.shape[1:]`.
|
||||
|
||||
|
@ -4410,7 +4413,7 @@ class ScatterAdd(_ScatterOpDynamic):
|
|||
|
||||
Raises:
|
||||
TypeError: If `use_locking` is not a bool.
|
||||
TypeError: If `indices` is not an int32.
|
||||
TypeError: If `indices` is not an int32 or an int64.
|
||||
ValueError: If the shape of `updates` is not equal to `indices.shape + x.shape[1:]`.
|
||||
RuntimeError: If the data type of `input_x` and `updates` conversion of Parameter
|
||||
is required when data type conversion of Parameter is not supported.
|
||||
|
@ -4511,7 +4514,8 @@ class ScatterSub(_ScatterOpDynamic):
|
|||
Inputs:
|
||||
- **input_x** (Parameter) - The target tensor, with data type of Parameter.
|
||||
The shape is :math:`(N,*)` where :math:`*` means,any number of additional dimensions.
|
||||
- **indices** (Tensor) - The index to do min operation whose data type must be mindspore.int32.
|
||||
- **indices** (Tensor) - The index to do min operation whose data type must be mindspore.int32 or
|
||||
mindspore.int64.
|
||||
- **updates** (Tensor) - The tensor doing the min operation with `input_x`,
|
||||
the data type is same as `input_x`, the shape is `indices_shape + x_shape[1:]`.
|
||||
|
||||
|
@ -4520,7 +4524,7 @@ class ScatterSub(_ScatterOpDynamic):
|
|||
|
||||
Raises:
|
||||
TypeError: If `use_locking` is not a bool.
|
||||
TypeError: If `indices` is not an int32.
|
||||
TypeError: If `indices` is not an int32 or an int64.
|
||||
ValueError: If the shape of `updates` is not equal to `indices_shape + x_shape[1:]`.
|
||||
RuntimeError: If the data type of `input_x` and `updates` conversion of Parameter
|
||||
is required when data type conversion of Parameter is not supported.
|
||||
|
|
|
@ -0,0 +1,250 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import mindspore
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops as ops
|
||||
from mindspore import Tensor, Parameter, ParameterTuple
|
||||
|
||||
# all cases tested against dchip
|
||||
|
||||
|
||||
class TestScatterMinNet(nn.Cell):
|
||||
def __init__(self, inputx):
|
||||
super(TestScatterMinNet, self).__init__()
|
||||
|
||||
self.scatter_min = ops.ScatterMin()
|
||||
self.inputx = Parameter(inputx, name="inputx")
|
||||
|
||||
def construct(self, indices, updates):
|
||||
out = self.scatter_min(self.inputx, indices, updates)
|
||||
return out
|
||||
|
||||
|
||||
def scatter_min_forward(nptype):
|
||||
inputx = Tensor(np.arange(0, 9).reshape((3, 3)).astype(nptype))
|
||||
indices = Tensor(np.array([[[1, 0, 2], [2, 2, 0]], [[1, 0, 1], [2, 1, 2]]]).astype(np.int32))
|
||||
updates = Tensor(np.arange(34, 70).reshape((2, 2, 3, 3)).astype(nptype))
|
||||
|
||||
net = TestScatterMinNet(inputx)
|
||||
output = net(indices, updates)
|
||||
expected = inputx.asnumpy()
|
||||
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
|
||||
|
||||
|
||||
def scatter_min_dynamic_updates():
|
||||
inputx = Tensor(np.ones((4, 2, 3, 4)).astype(np.float32))
|
||||
indices = Tensor(np.array([[0, 2], [3, 1]]).astype(np.int32))
|
||||
updates = Tensor(np.arange(96).reshape((2, 2, 2, 3, 4)).astype(np.float32))
|
||||
updates_dy = Tensor(shape=(2, 2, 2, None, 4), dtype=mindspore.float32)
|
||||
|
||||
net = TestScatterMinNet(inputx)
|
||||
net.set_inputs(indices, updates_dy)
|
||||
output = net(indices, updates)
|
||||
expected = np.ones((4, 2, 3, 4)).astype(np.float32)
|
||||
expected[0][0][0][0] = 0.0
|
||||
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
|
||||
|
||||
|
||||
def scatter_min_dynamic_indices():
|
||||
inputx = Tensor(np.ones((4, 2, 3, 4)).astype(np.int32))
|
||||
indices = Tensor(np.array([[0, 2], [3, 1]]).astype(np.int32))
|
||||
indices_dy = Tensor(shape=(2, None), dtype=mindspore.int32)
|
||||
updates = Tensor(np.arange(96).reshape((2, 2, 2, 3, 4)).astype(np.int32))
|
||||
|
||||
net = TestScatterMinNet(inputx)
|
||||
net.set_inputs(indices_dy, updates)
|
||||
output = net(indices, updates)
|
||||
expected = np.ones((4, 2, 3, 4)).astype(np.int32)
|
||||
expected[0][0][0][0] = 0
|
||||
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
|
||||
|
||||
|
||||
class TestScatterMinGradNet(nn.Cell):
|
||||
def __init__(self, network):
|
||||
super(TestScatterMinGradNet, self).__init__()
|
||||
self.grad = ops.GradOperation(get_all=True, sens_param=True, get_by_list=True)
|
||||
self.network = network
|
||||
self.params = ParameterTuple(network.trainable_params())
|
||||
|
||||
def construct(self, indices, updates, dout):
|
||||
out = self.grad(self.network, self.params)(indices, updates, dout)
|
||||
return out
|
||||
|
||||
|
||||
def scatter_min_grad(nptype):
|
||||
inputx = Tensor(np.flip(np.arange(34, 46).reshape(3, 4).astype(nptype)))
|
||||
indices = Tensor(np.array([[[0, 1, 2], [2, 1, 0]], [[0, 0, 0], [2, 2, 2]]]).astype(np.int32))
|
||||
updates = Tensor(np.arange(63, 111).reshape((2, 2, 3, 4)).astype(nptype))
|
||||
dout = Tensor(np.flip(np.arange(0, 12).reshape((3, 4)).astype(nptype)))
|
||||
|
||||
net = TestScatterMinGradNet(TestScatterMinNet(inputx))
|
||||
output = net(indices, updates, dout)
|
||||
indices_grad = output[0][0]
|
||||
updates_grad = output[0][1]
|
||||
|
||||
indices_expected = np.array([[[0, 0, 0], [0, 0, 0]], [[0, 0, 0], [0, 0, 0]]]).astype(nptype)
|
||||
updates_expected = np.array(
|
||||
[
|
||||
[
|
||||
[
|
||||
[11, 10, 9, 8], [7, 6, 5, 4], [3, 2, 1, 0]
|
||||
],
|
||||
[
|
||||
[3, 2, 1, 0], [7, 6, 5, 4], [11, 10, 9, 8]
|
||||
]
|
||||
],
|
||||
[
|
||||
[
|
||||
[11, 10, 9, 8], [11, 10, 9, 8], [11, 10, 9, 8]
|
||||
],
|
||||
[
|
||||
[3, 2, 1, 0], [3, 2, 1, 0], [3, 2, 1, 0]
|
||||
]
|
||||
]
|
||||
]).astype(nptype)
|
||||
np.testing.assert_array_almost_equal(indices_grad, indices_expected)
|
||||
np.testing.assert_array_almost_equal(updates_grad, updates_expected)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_scatter_min_forward_float16():
|
||||
"""
|
||||
Feature: test scatter_min forward.
|
||||
Description: test float16 inputs.
|
||||
Expectation: the result match with numpy result
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
scatter_min_forward(np.float16)
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
|
||||
scatter_min_forward(np.float16)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_scatter_min_forward_float32():
|
||||
"""
|
||||
Feature: test scatter_min forward.
|
||||
Description: test float32 inputs.
|
||||
Expectation: the result match with numpy result
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
scatter_min_forward(np.float32)
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
|
||||
scatter_min_forward(np.float32)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_scatter_min_forward_int32():
|
||||
"""
|
||||
Feature: test scatter_min forward.
|
||||
Description: test int32 inputs.
|
||||
Expectation: the result match with numpy result
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
scatter_min_forward(np.int32)
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
|
||||
scatter_min_forward(np.int32)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_scatter_min_dynamic_indices():
|
||||
"""
|
||||
Feature: test scatter_min dynamic shape.
|
||||
Description: indices is dynamic shape.
|
||||
Expectation: the result match with numpy result
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
scatter_min_dynamic_indices()
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
|
||||
scatter_min_dynamic_indices()
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_scatter_min_dynamic_updates():
|
||||
"""
|
||||
Feature: test scatter_min dynamic shape.
|
||||
Description: updates is dynamic shape.
|
||||
Expectation: the result match with numpy result
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
scatter_min_dynamic_updates()
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
|
||||
scatter_min_dynamic_updates()
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_scatter_min_grad_float16():
|
||||
"""
|
||||
Feature: test scatter_min grad.
|
||||
Description: test float16 inputs.
|
||||
Expectation: the result match with numpy result
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
scatter_min_grad(np.float16)
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
|
||||
scatter_min_grad(np.float16)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_scatter_min_grad_float32():
|
||||
"""
|
||||
Feature: test scatter_min grad.
|
||||
Description: test float32 inputs.
|
||||
Expectation: the result match with numpy result
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
scatter_min_grad(np.float32)
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
|
||||
scatter_min_grad(np.float32)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_scatter_min_grad_int32():
|
||||
"""
|
||||
Feature: test scatter_min grad.
|
||||
Description: test int32 inputs.
|
||||
Expectation: the result match with numpy result
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
scatter_min_grad(np.int32)
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
|
||||
scatter_min_grad(np.int32)
|
|
@ -775,11 +775,6 @@ def test_scatter_func_disordered_dynamic_int8():
|
|||
).astype(np.int8)
|
||||
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
|
||||
|
||||
# min
|
||||
output = scatter_func_d_net("min", inputx, indices, updates)
|
||||
expected = np.flip(np.arange(34, 46).reshape(3, 4).astype(np.int8))
|
||||
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
|
@ -817,11 +812,6 @@ def test_scatter_func_disordered_dynamic_uint8():
|
|||
).astype(np.uint8)
|
||||
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
|
||||
|
||||
# min
|
||||
output = scatter_func_d_net("min", inputx, indices, updates)
|
||||
expected = np.flip(np.arange(34, 46).reshape(3, 4).astype(np.uint8))
|
||||
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
|
|
Loading…
Reference in New Issue