forked from mindspore-Ecosystem/mindspore
ScatterMax support dynamic shape
This commit is contained in:
parent
974ee35a9e
commit
b177a68bfb
|
@ -0,0 +1,95 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "ops/scatter_max.h"
|
||||
#include <set>
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include "abstract/ops/primitive_infer_map.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
abstract::ShapePtr ScatterMaxInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
BaseShapePtr input_x_shape_ptr = input_args[kInputIndex0]->BuildShape();
|
||||
MS_EXCEPTION_IF_NULL(input_x_shape_ptr);
|
||||
BaseShapePtr indices_shape_ptr = input_args[kInputIndex1]->BuildShape();
|
||||
MS_EXCEPTION_IF_NULL(indices_shape_ptr);
|
||||
BaseShapePtr updates_shape_ptr = input_args[kInputIndex2]->BuildShape();
|
||||
MS_EXCEPTION_IF_NULL(updates_shape_ptr);
|
||||
|
||||
if (input_x_shape_ptr->IsDynamic()) {
|
||||
MS_EXCEPTION(ValueError) << "For " << primitive->name() << ", "
|
||||
<< "the 'input_x' does not support dynamic shape, but got the shape of 'input_x' is "
|
||||
<< input_x_shape_ptr->ToString();
|
||||
}
|
||||
|
||||
if (indices_shape_ptr->IsDynamic() || updates_shape_ptr->IsDynamic()) {
|
||||
return input_args[kInputIndex0]->BuildShape()->cast<abstract::ShapePtr>();
|
||||
}
|
||||
|
||||
std::vector<int64_t> input_x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_x_shape_ptr)[kShape];
|
||||
std::vector<int64_t> indices_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(indices_shape_ptr)[kShape];
|
||||
std::vector<int64_t> updates_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(updates_shape_ptr)[kShape];
|
||||
std::vector<int64_t> check_update_shape(indices_shape);
|
||||
for (int64_t i = 1; i < SizeToLong(input_x_shape.size()); ++i) {
|
||||
check_update_shape.push_back(input_x_shape[i]);
|
||||
}
|
||||
if (updates_shape != check_update_shape) {
|
||||
MS_EXCEPTION(ValueError) << "For " << primitive->name() << ", "
|
||||
<< "updates_shape = indices_shape + x_shape[1:], but got x_shape: "
|
||||
<< input_x_shape_ptr->ToString() << ", indices_shape: " << indices_shape_ptr->ToString()
|
||||
<< ", updates_shape: " << updates_shape_ptr->ToString() << ".";
|
||||
}
|
||||
|
||||
auto output_shape = input_args[kInputIndex0]->BuildShape()->cast<abstract::ShapePtr>();
|
||||
return output_shape;
|
||||
}
|
||||
|
||||
TypePtr ScatterMaxInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto input_x_type_ptr = input_args[kInputIndex0]->BuildType();
|
||||
auto indiecs_type_ptr = input_args[kInputIndex1]->BuildType();
|
||||
auto updates_type_ptr = input_args[kInputIndex2]->BuildType();
|
||||
auto prim_name = primitive->name();
|
||||
std::set<TypePtr> type_set = {kInt32};
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("indices type", indiecs_type_ptr, type_set, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("input_x type", input_x_type_ptr, common_valid_types_with_complex,
|
||||
prim_name);
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("updates type", updates_type_ptr, common_valid_types_with_complex,
|
||||
prim_name);
|
||||
|
||||
std::map<std::string, TypePtr> type_dict;
|
||||
type_dict.emplace("input_x", input_x_type_ptr);
|
||||
type_dict.emplace("updates", updates_type_ptr);
|
||||
return CheckAndConvertUtils::CheckTensorTypeSame(type_dict, common_valid_types, prim_name);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
MIND_API_OPERATOR_IMPL(ScatterMax, BaseOperator);
|
||||
AbstractBasePtr ScatterMaxInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
const int64_t input_num = 3;
|
||||
(void)CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, input_num, primitive->name());
|
||||
auto infer_type = ScatterMaxInferType(primitive, input_args);
|
||||
auto infer_shape = ScatterMaxInferShape(primitive, input_args);
|
||||
return abstract::MakeAbstract(infer_shape, infer_type);
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(ScatterMax, prim::kPrimScatterMax, ScatterMaxInfer, nullptr, true);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,39 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_OPS_SCATTER_MAX_H_
|
||||
#define MINDSPORE_CORE_OPS_SCATTER_MAX_H_
|
||||
|
||||
#include <vector>
|
||||
#include "ops/base_operator.h"
|
||||
#include "mindapi/base/types.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameScatterMax = "ScatterMax";
|
||||
class MIND_API ScatterMax : public BaseOperator {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(ScatterMax);
|
||||
/// \brief Constructor.
|
||||
ScatterMax() : BaseOperator(kNameScatterMax) { InitIOName({"input_x", "indices", "updates"}, {"output"}); }
|
||||
};
|
||||
|
||||
abstract::AbstractBasePtr ScatterMaxInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_OPS_SCATTER_MAX_H_
|
|
@ -33,7 +33,13 @@ abstract::ShapePtr ScatterMinInferShape(const PrimitivePtr &primitive, const std
|
|||
BaseShapePtr updates_shape_ptr = input_args[kInputIndex2]->BuildShape();
|
||||
MS_EXCEPTION_IF_NULL(updates_shape_ptr);
|
||||
|
||||
if (input_x_shape_ptr->IsDynamic() || indices_shape_ptr->IsDynamic() || updates_shape_ptr->IsDynamic()) {
|
||||
if (input_x_shape_ptr->IsDynamic()) {
|
||||
MS_EXCEPTION(ValueError) << "For " << primitive->name() << ", "
|
||||
<< "the 'input_x' does not support dynamic shape, but got the shape of 'input_x' is "
|
||||
<< input_x_shape_ptr->ToString();
|
||||
}
|
||||
|
||||
if (indices_shape_ptr->IsDynamic() || updates_shape_ptr->IsDynamic()) {
|
||||
return input_args[kInputIndex0]->BuildShape()->cast<abstract::ShapePtr>();
|
||||
}
|
||||
|
||||
|
|
|
@ -4230,7 +4230,7 @@ class ScatterNdUpdate(Primitive):
|
|||
self.add_prim_attr('side_effect_mem', True)
|
||||
|
||||
|
||||
class ScatterMax(_ScatterOp):
|
||||
class ScatterMax(_ScatterOpDynamic):
|
||||
r"""
|
||||
Updates the value of the input tensor through the maximum operation.
|
||||
|
||||
|
@ -4284,7 +4284,7 @@ class ScatterMax(_ScatterOp):
|
|||
"""
|
||||
|
||||
|
||||
class ScatterMin(_ScatterOp):
|
||||
class ScatterMin(_ScatterOpDynamic):
|
||||
r"""
|
||||
Updates the value of the input tensor through the minimum operation.
|
||||
|
||||
|
|
|
@ -722,6 +722,13 @@ def test_scatter_func_disordered_dynamic_int32():
|
|||
).astype(np.int32)
|
||||
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
|
||||
|
||||
# max
|
||||
output = scatter_func_d_net("max", inputx, indices, updates)
|
||||
expected = np.array(
|
||||
[[95.0, 96.0, 97.0, 98.0], [67.0, 68.0, 69.0, 70.0], [99.0, 100.0, 101.0, 102.0]]
|
||||
).astype(np.int32)
|
||||
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
|
||||
|
||||
# min
|
||||
output = scatter_func_d_net("min", inputx, indices, updates)
|
||||
expected = np.flip(np.arange(34, 46).reshape(3, 4).astype(np.int32))
|
||||
|
@ -761,6 +768,13 @@ def test_scatter_func_disordered_dynamic_int8():
|
|||
).astype(np.int8)
|
||||
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
|
||||
|
||||
# max
|
||||
output = scatter_func_d_net("max", inputx, indices, updates)
|
||||
expected = np.array(
|
||||
[[95.0, 96.0, 97.0, 98.0], [67.0, 68.0, 69.0, 70.0], [99.0, 100.0, 101.0, 102.0]]
|
||||
).astype(np.int8)
|
||||
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
|
||||
|
||||
# min
|
||||
output = scatter_func_d_net("min", inputx, indices, updates)
|
||||
expected = np.flip(np.arange(34, 46).reshape(3, 4).astype(np.int8))
|
||||
|
@ -796,6 +810,13 @@ def test_scatter_func_disordered_dynamic_uint8():
|
|||
).astype(np.uint8)
|
||||
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
|
||||
|
||||
# max
|
||||
output = scatter_func_d_net("max", inputx, indices, updates)
|
||||
expected = np.array(
|
||||
[[95.0, 96.0, 97.0, 98.0], [67.0, 68.0, 69.0, 70.0], [99.0, 100.0, 101.0, 102.0]]
|
||||
).astype(np.uint8)
|
||||
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
|
||||
|
||||
# min
|
||||
output = scatter_func_d_net("min", inputx, indices, updates)
|
||||
expected = np.flip(np.arange(34, 46).reshape(3, 4).astype(np.uint8))
|
||||
|
@ -848,6 +869,13 @@ def test_scatter_func_input_less_than_1_dynamic_float32():
|
|||
)
|
||||
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
|
||||
|
||||
# max
|
||||
output = scatter_func_d_net("max", inputx, indices, updates)
|
||||
expected = np.array(
|
||||
[[37.0, 38.0, 39.0], [34.0, 35.0, 66.0], [67.0, 68.0, 69.0],], dtype=np.float32,
|
||||
)
|
||||
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
|
||||
|
||||
# min
|
||||
output = scatter_func_d_net("min", inputx, indices, updates)
|
||||
expected = inputx_np
|
||||
|
@ -891,6 +919,15 @@ def test_scatter_func_dynamic_two_inputs():
|
|||
np.testing.assert_array_almost_equal(output_1.asnumpy(), expected_1)
|
||||
np.testing.assert_array_almost_equal(output_2.asnumpy(), expected_2)
|
||||
|
||||
# max
|
||||
output_1, output_2 = scatter_func_d2_net(
|
||||
"max", inputx, indices_1, updates_1, indices_2, updates_2
|
||||
)
|
||||
expected_1 = np.array([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]])
|
||||
expected_2 = np.array([[17.0, 16.0, 15.0], [11.0, 10.0, 9.0]])
|
||||
np.testing.assert_array_almost_equal(output_1.asnumpy(), expected_1)
|
||||
np.testing.assert_array_almost_equal(output_2.asnumpy(), expected_2)
|
||||
|
||||
# min
|
||||
output_1, output_2 = scatter_func_d2_net(
|
||||
"min", inputx, indices_1, updates_1, indices_2, updates_2
|
||||
|
@ -902,13 +939,8 @@ def test_scatter_func_dynamic_two_inputs():
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_scatter_func_small_float32()
|
||||
test_scatter_func_input_updated()
|
||||
test_scatter_func_large_shape_float32()
|
||||
test_scatter_func_small_float32_use_locking_false()
|
||||
test_scatter_func_input_less_than_1_float32()
|
||||
test_scatter_func_float16()
|
||||
test_scatter_func_large_float16()
|
||||
test_scatter_func_disordered_float16()
|
||||
test_scatter_func_large_int32()
|
||||
test_scatter_func_disordered_int32()
|
||||
test_scatter_func_disordered_dynamic_int32()
|
||||
test_scatter_func_disordered_dynamic_int8()
|
||||
test_scatter_func_disordered_dynamic_uint8()
|
||||
test_scatter_func_input_less_than_1_dynamic_float32()
|
||||
test_scatter_func_dynamic_two_inputs()
|
||||
|
|
Loading…
Reference in New Issue