ScatterMax support dynamic shape

This commit is contained in:
polyhedral 2022-04-29 17:30:15 +08:00
parent 974ee35a9e
commit b177a68bfb
5 changed files with 185 additions and 13 deletions

View File

@ -0,0 +1,95 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/scatter_max.h"
#include <set>
#include <map>
#include <string>
#include "abstract/ops/primitive_infer_map.h"
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
namespace {
abstract::ShapePtr ScatterMaxInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
BaseShapePtr input_x_shape_ptr = input_args[kInputIndex0]->BuildShape();
MS_EXCEPTION_IF_NULL(input_x_shape_ptr);
BaseShapePtr indices_shape_ptr = input_args[kInputIndex1]->BuildShape();
MS_EXCEPTION_IF_NULL(indices_shape_ptr);
BaseShapePtr updates_shape_ptr = input_args[kInputIndex2]->BuildShape();
MS_EXCEPTION_IF_NULL(updates_shape_ptr);
if (input_x_shape_ptr->IsDynamic()) {
MS_EXCEPTION(ValueError) << "For " << primitive->name() << ", "
<< "the 'input_x' does not support dynamic shape, but got the shape of 'input_x' is "
<< input_x_shape_ptr->ToString();
}
if (indices_shape_ptr->IsDynamic() || updates_shape_ptr->IsDynamic()) {
return input_args[kInputIndex0]->BuildShape()->cast<abstract::ShapePtr>();
}
std::vector<int64_t> input_x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_x_shape_ptr)[kShape];
std::vector<int64_t> indices_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(indices_shape_ptr)[kShape];
std::vector<int64_t> updates_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(updates_shape_ptr)[kShape];
std::vector<int64_t> check_update_shape(indices_shape);
for (int64_t i = 1; i < SizeToLong(input_x_shape.size()); ++i) {
check_update_shape.push_back(input_x_shape[i]);
}
if (updates_shape != check_update_shape) {
MS_EXCEPTION(ValueError) << "For " << primitive->name() << ", "
<< "updates_shape = indices_shape + x_shape[1:], but got x_shape: "
<< input_x_shape_ptr->ToString() << ", indices_shape: " << indices_shape_ptr->ToString()
<< ", updates_shape: " << updates_shape_ptr->ToString() << ".";
}
auto output_shape = input_args[kInputIndex0]->BuildShape()->cast<abstract::ShapePtr>();
return output_shape;
}
TypePtr ScatterMaxInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
auto input_x_type_ptr = input_args[kInputIndex0]->BuildType();
auto indiecs_type_ptr = input_args[kInputIndex1]->BuildType();
auto updates_type_ptr = input_args[kInputIndex2]->BuildType();
auto prim_name = primitive->name();
std::set<TypePtr> type_set = {kInt32};
(void)CheckAndConvertUtils::CheckTensorTypeValid("indices type", indiecs_type_ptr, type_set, prim_name);
(void)CheckAndConvertUtils::CheckTensorTypeValid("input_x type", input_x_type_ptr, common_valid_types_with_complex,
prim_name);
(void)CheckAndConvertUtils::CheckTensorTypeValid("updates type", updates_type_ptr, common_valid_types_with_complex,
prim_name);
std::map<std::string, TypePtr> type_dict;
type_dict.emplace("input_x", input_x_type_ptr);
type_dict.emplace("updates", updates_type_ptr);
return CheckAndConvertUtils::CheckTensorTypeSame(type_dict, common_valid_types, prim_name);
}
} // namespace
MIND_API_OPERATOR_IMPL(ScatterMax, BaseOperator);
AbstractBasePtr ScatterMaxInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
const int64_t input_num = 3;
(void)CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, input_num, primitive->name());
auto infer_type = ScatterMaxInferType(primitive, input_args);
auto infer_shape = ScatterMaxInferShape(primitive, input_args);
return abstract::MakeAbstract(infer_shape, infer_type);
}
REGISTER_PRIMITIVE_EVAL_IMPL(ScatterMax, prim::kPrimScatterMax, ScatterMaxInfer, nullptr, true);
} // namespace ops
} // namespace mindspore

View File

@ -0,0 +1,39 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_SCATTER_MAX_H_
#define MINDSPORE_CORE_OPS_SCATTER_MAX_H_
#include <vector>
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameScatterMax = "ScatterMax";
class MIND_API ScatterMax : public BaseOperator {
public:
MIND_API_BASE_MEMBER(ScatterMax);
/// \brief Constructor.
ScatterMax() : BaseOperator(kNameScatterMax) { InitIOName({"input_x", "indices", "updates"}, {"output"}); }
};
abstract::AbstractBasePtr ScatterMaxInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_SCATTER_MAX_H_

View File

@ -33,7 +33,13 @@ abstract::ShapePtr ScatterMinInferShape(const PrimitivePtr &primitive, const std
BaseShapePtr updates_shape_ptr = input_args[kInputIndex2]->BuildShape();
MS_EXCEPTION_IF_NULL(updates_shape_ptr);
if (input_x_shape_ptr->IsDynamic() || indices_shape_ptr->IsDynamic() || updates_shape_ptr->IsDynamic()) {
if (input_x_shape_ptr->IsDynamic()) {
MS_EXCEPTION(ValueError) << "For " << primitive->name() << ", "
<< "the 'input_x' does not support dynamic shape, but got the shape of 'input_x' is "
<< input_x_shape_ptr->ToString();
}
if (indices_shape_ptr->IsDynamic() || updates_shape_ptr->IsDynamic()) {
return input_args[kInputIndex0]->BuildShape()->cast<abstract::ShapePtr>();
}

View File

@ -4230,7 +4230,7 @@ class ScatterNdUpdate(Primitive):
self.add_prim_attr('side_effect_mem', True)
class ScatterMax(_ScatterOp):
class ScatterMax(_ScatterOpDynamic):
r"""
Updates the value of the input tensor through the maximum operation.
@ -4284,7 +4284,7 @@ class ScatterMax(_ScatterOp):
"""
class ScatterMin(_ScatterOp):
class ScatterMin(_ScatterOpDynamic):
r"""
Updates the value of the input tensor through the minimum operation.

View File

@ -722,6 +722,13 @@ def test_scatter_func_disordered_dynamic_int32():
).astype(np.int32)
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
# max
output = scatter_func_d_net("max", inputx, indices, updates)
expected = np.array(
[[95.0, 96.0, 97.0, 98.0], [67.0, 68.0, 69.0, 70.0], [99.0, 100.0, 101.0, 102.0]]
).astype(np.int32)
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
# min
output = scatter_func_d_net("min", inputx, indices, updates)
expected = np.flip(np.arange(34, 46).reshape(3, 4).astype(np.int32))
@ -761,6 +768,13 @@ def test_scatter_func_disordered_dynamic_int8():
).astype(np.int8)
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
# max
output = scatter_func_d_net("max", inputx, indices, updates)
expected = np.array(
[[95.0, 96.0, 97.0, 98.0], [67.0, 68.0, 69.0, 70.0], [99.0, 100.0, 101.0, 102.0]]
).astype(np.int8)
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
# min
output = scatter_func_d_net("min", inputx, indices, updates)
expected = np.flip(np.arange(34, 46).reshape(3, 4).astype(np.int8))
@ -796,6 +810,13 @@ def test_scatter_func_disordered_dynamic_uint8():
).astype(np.uint8)
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
# max
output = scatter_func_d_net("max", inputx, indices, updates)
expected = np.array(
[[95.0, 96.0, 97.0, 98.0], [67.0, 68.0, 69.0, 70.0], [99.0, 100.0, 101.0, 102.0]]
).astype(np.uint8)
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
# min
output = scatter_func_d_net("min", inputx, indices, updates)
expected = np.flip(np.arange(34, 46).reshape(3, 4).astype(np.uint8))
@ -848,6 +869,13 @@ def test_scatter_func_input_less_than_1_dynamic_float32():
)
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
# max
output = scatter_func_d_net("max", inputx, indices, updates)
expected = np.array(
[[37.0, 38.0, 39.0], [34.0, 35.0, 66.0], [67.0, 68.0, 69.0],], dtype=np.float32,
)
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
# min
output = scatter_func_d_net("min", inputx, indices, updates)
expected = inputx_np
@ -891,6 +919,15 @@ def test_scatter_func_dynamic_two_inputs():
np.testing.assert_array_almost_equal(output_1.asnumpy(), expected_1)
np.testing.assert_array_almost_equal(output_2.asnumpy(), expected_2)
# max
output_1, output_2 = scatter_func_d2_net(
"max", inputx, indices_1, updates_1, indices_2, updates_2
)
expected_1 = np.array([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]])
expected_2 = np.array([[17.0, 16.0, 15.0], [11.0, 10.0, 9.0]])
np.testing.assert_array_almost_equal(output_1.asnumpy(), expected_1)
np.testing.assert_array_almost_equal(output_2.asnumpy(), expected_2)
# min
output_1, output_2 = scatter_func_d2_net(
"min", inputx, indices_1, updates_1, indices_2, updates_2
@ -902,13 +939,8 @@ def test_scatter_func_dynamic_two_inputs():
if __name__ == "__main__":
test_scatter_func_small_float32()
test_scatter_func_input_updated()
test_scatter_func_large_shape_float32()
test_scatter_func_small_float32_use_locking_false()
test_scatter_func_input_less_than_1_float32()
test_scatter_func_float16()
test_scatter_func_large_float16()
test_scatter_func_disordered_float16()
test_scatter_func_large_int32()
test_scatter_func_disordered_int32()
test_scatter_func_disordered_dynamic_int32()
test_scatter_func_disordered_dynamic_int8()
test_scatter_func_disordered_dynamic_uint8()
test_scatter_func_input_less_than_1_dynamic_float32()
test_scatter_func_dynamic_two_inputs()