scatter_min: add c++ implementation to adapt to dynamic shape

This commit is contained in:
hujiahui8 2022-04-28 20:09:08 +08:00
parent 244e078ad3
commit 9cc94628fa
4 changed files with 183 additions and 13 deletions
mindspore
core/ops
python/mindspore/ops/function
tests/st/ops/gpu

View File

@ -0,0 +1,89 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/scatter_min.h"
#include <set>
#include <map>
#include <string>
#include "abstract/ops/primitive_infer_map.h"
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
namespace {
abstract::ShapePtr ScatterMinInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
BaseShapePtr input_x_shape_ptr = input_args[kInputIndex0]->BuildShape();
MS_EXCEPTION_IF_NULL(input_x_shape_ptr);
BaseShapePtr indices_shape_ptr = input_args[kInputIndex1]->BuildShape();
MS_EXCEPTION_IF_NULL(indices_shape_ptr);
BaseShapePtr updates_shape_ptr = input_args[kInputIndex2]->BuildShape();
MS_EXCEPTION_IF_NULL(updates_shape_ptr);
if (input_x_shape_ptr->IsDynamic() || indices_shape_ptr->IsDynamic() || updates_shape_ptr->IsDynamic()) {
return input_args[kInputIndex0]->BuildShape()->cast<abstract::ShapePtr>();
}
std::vector<int64_t> input_x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_x_shape_ptr)[kShape];
std::vector<int64_t> indices_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(indices_shape_ptr)[kShape];
std::vector<int64_t> updates_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(updates_shape_ptr)[kShape];
std::vector<int64_t> check_update_shape(indices_shape);
for (int64_t i = 1; i < SizeToLong(input_x_shape.size()); ++i) {
check_update_shape.push_back(input_x_shape[i]);
}
if (updates_shape != check_update_shape) {
MS_EXCEPTION(ValueError) << "For " << primitive->name() << ", "
<< "updates_shape = indices_shape + x_shape[1:], but got x_shape: "
<< input_x_shape_ptr->ToString() << ", indices_shape: " << indices_shape_ptr->ToString()
<< ", updates_shape: " << updates_shape_ptr->ToString() << ".";
}
auto output_shape = input_args[kInputIndex0]->BuildShape()->cast<abstract::ShapePtr>();
return output_shape;
}
TypePtr ScatterMinInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
auto input_x_type_ptr = input_args[kInputIndex0]->BuildType();
auto indiecs_type_ptr = input_args[kInputIndex1]->BuildType();
auto updates_type_ptr = input_args[kInputIndex2]->BuildType();
auto prim_name = primitive->name();
std::set<TypePtr> type_set = {kInt32};
(void)CheckAndConvertUtils::CheckTensorTypeValid("indices type", indiecs_type_ptr, type_set, prim_name);
(void)CheckAndConvertUtils::CheckTensorTypeValid("input_x type", input_x_type_ptr, common_valid_types_with_complex,
prim_name);
(void)CheckAndConvertUtils::CheckTensorTypeValid("updates type", updates_type_ptr, common_valid_types_with_complex,
prim_name);
std::map<std::string, TypePtr> type_dict;
type_dict.emplace("input_x", input_x_type_ptr);
type_dict.emplace("updates", updates_type_ptr);
return CheckAndConvertUtils::CheckTensorTypeSame(type_dict, common_valid_types, prim_name);
}
} // namespace
MIND_API_OPERATOR_IMPL(ScatterMin, BaseOperator);
AbstractBasePtr ScatterMinInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
const int64_t input_num = 3;
(void)CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, input_num, primitive->name());
auto infer_type = ScatterMinInferType(primitive, input_args);
auto infer_shape = ScatterMinInferShape(primitive, input_args);
return abstract::MakeAbstract(infer_shape, infer_type);
}
REGISTER_PRIMITIVE_EVAL_IMPL(ScatterMin, prim::kPrimScatterMin, ScatterMinInfer, nullptr, true);
} // namespace ops
} // namespace mindspore

View File

@ -0,0 +1,39 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_SCATTER_MIN_H_
#define MINDSPORE_CORE_OPS_SCATTER_MIN_H_
#include <vector>
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameScatterMin = "ScatterMin";
class MIND_API ScatterMin : public BaseOperator {
public:
MIND_API_BASE_MEMBER(ScatterMin);
/// \brief Constructor.
ScatterMin() : BaseOperator(kNameScatterMin) { InitIOName({"input_x", "indices", "updates"}, {"output"}); }
};
abstract::AbstractBasePtr ScatterMinInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_SCATTER_MIN_H_

View File

@ -626,18 +626,31 @@ def scatter_max(input_x, indices, updates):
return scatter_max_(input_x, indices, updates)
def scatter_min(input_x, indices, updates, use_locking=False):
r"""
scatter_min_ = P.ScatterMin()
def scatter_min(input_x, indices, updates):
"""
Updates the value of the input tensor through the minimum operation.
Using given values to update tensor value through the min operation, along with the input indices.
This operation outputs the `input_x` after the update is done, which makes it convenient to use the updated value.
for each `i, ..., j` in `indices.shape`:
.. math::
\text{input_x}[\text{indices}[i, ..., j], :]
= min(\text{input_x}[\text{indices}[i, ..., j], :], \text{updates}[i, ..., j, :])
Inputs of `input_x` and `updates` comply with the implicit type conversion rules to make the data types consistent.
If they have different data types, the lower priority data type will be converted to
the relatively highest priority data type.
Args:
- **input_x** (Parameter) - The target tensor, with data type of Parameter.
The shape is :math:`(N,*)` where :math:`*` means,any number of additional dimensions.
- **indices** (Tensor) - The index to do min operation whose data type must be mindspore.int32.
- **updates** (Tensor) - The tensor doing the min operation with `input_x`,
the data type is same as `input_x`, the shape is `indices.shape + x.shape[1:]`.
- use_locking (bool): Whether to protect the assignment by a lock. Default: False.
Outputs:
Tensor, the updated `input_x`, has the same shape and type as `input_x`.
@ -666,7 +679,7 @@ def scatter_min(input_x, indices, updates, use_locking=False):
[[0. 0. 0.]
[0. 0. 0.]]
"""
return P.ScatterMin(use_locking)(input_x, indices, updates)
return scatter_min_(input_x, indices, updates)
scatter_nd_ = P.ScatterNd()

View File

@ -722,6 +722,11 @@ def test_scatter_func_disordered_dynamic_int32():
).astype(np.int32)
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
# min
output = scatter_func_d_net("min", inputx, indices, updates)
expected = np.flip(np.arange(34, 46).reshape(3, 4).astype(np.int32))
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@ -756,6 +761,11 @@ def test_scatter_func_disordered_dynamic_int8():
).astype(np.int8)
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
# min
output = scatter_func_d_net("min", inputx, indices, updates)
expected = np.flip(np.arange(34, 46).reshape(3, 4).astype(np.int8))
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@ -786,20 +796,24 @@ def test_scatter_func_disordered_dynamic_uint8():
).astype(np.uint8)
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
# min
output = scatter_func_d_net("min", inputx, indices, updates)
expected = np.flip(np.arange(34, 46).reshape(3, 4).astype(np.uint8))
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_scatter_func_input_less_than_1_dynamic_float32():
inputx = Tensor(
np.array(
[
[0.214141, 0.415151, 0.51516],
[0.876542, 0.451611, 0.55112],
[0.111244, 0.633333, 0.34444],
]
).astype(np.float32)
)
inputx_np = np.array(
[
[0.214141, 0.415151, 0.51516],
[0.876542, 0.451611, 0.55112],
[0.111244, 0.633333, 0.34444],
]
).astype(np.float32)
inputx = Tensor(inputx_np)
indices = Tensor(np.array([[[1, 0, 2], [2, 2, 0]], [[1, 0, 1], [2, 1, 2]]]).astype(np.int32))
updates = Tensor(np.arange(34, 70).reshape((2, 2, 3, 3)).astype(np.float32))
@ -834,6 +848,11 @@ def test_scatter_func_input_less_than_1_dynamic_float32():
)
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
# min
output = scatter_func_d_net("min", inputx, indices, updates)
expected = inputx_np
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@ -872,6 +891,16 @@ def test_scatter_func_dynamic_two_inputs():
np.testing.assert_array_almost_equal(output_1.asnumpy(), expected_1)
np.testing.assert_array_almost_equal(output_2.asnumpy(), expected_2)
# min
output_1, output_2 = scatter_func_d2_net(
"min", inputx, indices_1, updates_1, indices_2, updates_2
)
expected_1 = np.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
expected_2 = expected_1
np.testing.assert_array_almost_equal(output_1.asnumpy(), expected_1)
np.testing.assert_array_almost_equal(output_2.asnumpy(), expected_2)
if __name__ == "__main__":
test_scatter_func_small_float32()
test_scatter_func_input_updated()