forked from mindspore-Ecosystem/mindspore
ScatterMax support dynamic shape
This commit is contained in:
parent
974ee35a9e
commit
b177a68bfb
|
@ -0,0 +1,95 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include "ops/scatter_max.h"
|
||||||
|
#include <set>
|
||||||
|
#include <map>
|
||||||
|
#include <string>
|
||||||
|
#include "abstract/ops/primitive_infer_map.h"
|
||||||
|
#include "ops/op_utils.h"
|
||||||
|
#include "utils/check_convert_utils.h"
|
||||||
|
#include "mindapi/src/helper.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace ops {
|
||||||
|
namespace {
|
||||||
|
abstract::ShapePtr ScatterMaxInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||||
|
BaseShapePtr input_x_shape_ptr = input_args[kInputIndex0]->BuildShape();
|
||||||
|
MS_EXCEPTION_IF_NULL(input_x_shape_ptr);
|
||||||
|
BaseShapePtr indices_shape_ptr = input_args[kInputIndex1]->BuildShape();
|
||||||
|
MS_EXCEPTION_IF_NULL(indices_shape_ptr);
|
||||||
|
BaseShapePtr updates_shape_ptr = input_args[kInputIndex2]->BuildShape();
|
||||||
|
MS_EXCEPTION_IF_NULL(updates_shape_ptr);
|
||||||
|
|
||||||
|
if (input_x_shape_ptr->IsDynamic()) {
|
||||||
|
MS_EXCEPTION(ValueError) << "For " << primitive->name() << ", "
|
||||||
|
<< "the 'input_x' does not support dynamic shape, but got the shape of 'input_x' is "
|
||||||
|
<< input_x_shape_ptr->ToString();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (indices_shape_ptr->IsDynamic() || updates_shape_ptr->IsDynamic()) {
|
||||||
|
return input_args[kInputIndex0]->BuildShape()->cast<abstract::ShapePtr>();
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<int64_t> input_x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_x_shape_ptr)[kShape];
|
||||||
|
std::vector<int64_t> indices_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(indices_shape_ptr)[kShape];
|
||||||
|
std::vector<int64_t> updates_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(updates_shape_ptr)[kShape];
|
||||||
|
std::vector<int64_t> check_update_shape(indices_shape);
|
||||||
|
for (int64_t i = 1; i < SizeToLong(input_x_shape.size()); ++i) {
|
||||||
|
check_update_shape.push_back(input_x_shape[i]);
|
||||||
|
}
|
||||||
|
if (updates_shape != check_update_shape) {
|
||||||
|
MS_EXCEPTION(ValueError) << "For " << primitive->name() << ", "
|
||||||
|
<< "updates_shape = indices_shape + x_shape[1:], but got x_shape: "
|
||||||
|
<< input_x_shape_ptr->ToString() << ", indices_shape: " << indices_shape_ptr->ToString()
|
||||||
|
<< ", updates_shape: " << updates_shape_ptr->ToString() << ".";
|
||||||
|
}
|
||||||
|
|
||||||
|
auto output_shape = input_args[kInputIndex0]->BuildShape()->cast<abstract::ShapePtr>();
|
||||||
|
return output_shape;
|
||||||
|
}
|
||||||
|
|
||||||
|
TypePtr ScatterMaxInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||||
|
auto input_x_type_ptr = input_args[kInputIndex0]->BuildType();
|
||||||
|
auto indiecs_type_ptr = input_args[kInputIndex1]->BuildType();
|
||||||
|
auto updates_type_ptr = input_args[kInputIndex2]->BuildType();
|
||||||
|
auto prim_name = primitive->name();
|
||||||
|
std::set<TypePtr> type_set = {kInt32};
|
||||||
|
(void)CheckAndConvertUtils::CheckTensorTypeValid("indices type", indiecs_type_ptr, type_set, prim_name);
|
||||||
|
(void)CheckAndConvertUtils::CheckTensorTypeValid("input_x type", input_x_type_ptr, common_valid_types_with_complex,
|
||||||
|
prim_name);
|
||||||
|
(void)CheckAndConvertUtils::CheckTensorTypeValid("updates type", updates_type_ptr, common_valid_types_with_complex,
|
||||||
|
prim_name);
|
||||||
|
|
||||||
|
std::map<std::string, TypePtr> type_dict;
|
||||||
|
type_dict.emplace("input_x", input_x_type_ptr);
|
||||||
|
type_dict.emplace("updates", updates_type_ptr);
|
||||||
|
return CheckAndConvertUtils::CheckTensorTypeSame(type_dict, common_valid_types, prim_name);
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
MIND_API_OPERATOR_IMPL(ScatterMax, BaseOperator);
|
||||||
|
AbstractBasePtr ScatterMaxInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
|
const std::vector<AbstractBasePtr> &input_args) {
|
||||||
|
MS_EXCEPTION_IF_NULL(primitive);
|
||||||
|
const int64_t input_num = 3;
|
||||||
|
(void)CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, input_num, primitive->name());
|
||||||
|
auto infer_type = ScatterMaxInferType(primitive, input_args);
|
||||||
|
auto infer_shape = ScatterMaxInferShape(primitive, input_args);
|
||||||
|
return abstract::MakeAbstract(infer_shape, infer_type);
|
||||||
|
}
|
||||||
|
REGISTER_PRIMITIVE_EVAL_IMPL(ScatterMax, prim::kPrimScatterMax, ScatterMaxInfer, nullptr, true);
|
||||||
|
} // namespace ops
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,39 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_CORE_OPS_SCATTER_MAX_H_
|
||||||
|
#define MINDSPORE_CORE_OPS_SCATTER_MAX_H_
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
#include "ops/base_operator.h"
|
||||||
|
#include "mindapi/base/types.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace ops {
|
||||||
|
constexpr auto kNameScatterMax = "ScatterMax";
|
||||||
|
class MIND_API ScatterMax : public BaseOperator {
|
||||||
|
public:
|
||||||
|
MIND_API_BASE_MEMBER(ScatterMax);
|
||||||
|
/// \brief Constructor.
|
||||||
|
ScatterMax() : BaseOperator(kNameScatterMax) { InitIOName({"input_x", "indices", "updates"}, {"output"}); }
|
||||||
|
};
|
||||||
|
|
||||||
|
abstract::AbstractBasePtr ScatterMaxInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
|
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||||
|
} // namespace ops
|
||||||
|
} // namespace mindspore
|
||||||
|
|
||||||
|
#endif // MINDSPORE_CORE_OPS_SCATTER_MAX_H_
|
|
@ -33,7 +33,13 @@ abstract::ShapePtr ScatterMinInferShape(const PrimitivePtr &primitive, const std
|
||||||
BaseShapePtr updates_shape_ptr = input_args[kInputIndex2]->BuildShape();
|
BaseShapePtr updates_shape_ptr = input_args[kInputIndex2]->BuildShape();
|
||||||
MS_EXCEPTION_IF_NULL(updates_shape_ptr);
|
MS_EXCEPTION_IF_NULL(updates_shape_ptr);
|
||||||
|
|
||||||
if (input_x_shape_ptr->IsDynamic() || indices_shape_ptr->IsDynamic() || updates_shape_ptr->IsDynamic()) {
|
if (input_x_shape_ptr->IsDynamic()) {
|
||||||
|
MS_EXCEPTION(ValueError) << "For " << primitive->name() << ", "
|
||||||
|
<< "the 'input_x' does not support dynamic shape, but got the shape of 'input_x' is "
|
||||||
|
<< input_x_shape_ptr->ToString();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (indices_shape_ptr->IsDynamic() || updates_shape_ptr->IsDynamic()) {
|
||||||
return input_args[kInputIndex0]->BuildShape()->cast<abstract::ShapePtr>();
|
return input_args[kInputIndex0]->BuildShape()->cast<abstract::ShapePtr>();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -4230,7 +4230,7 @@ class ScatterNdUpdate(Primitive):
|
||||||
self.add_prim_attr('side_effect_mem', True)
|
self.add_prim_attr('side_effect_mem', True)
|
||||||
|
|
||||||
|
|
||||||
class ScatterMax(_ScatterOp):
|
class ScatterMax(_ScatterOpDynamic):
|
||||||
r"""
|
r"""
|
||||||
Updates the value of the input tensor through the maximum operation.
|
Updates the value of the input tensor through the maximum operation.
|
||||||
|
|
||||||
|
@ -4284,7 +4284,7 @@ class ScatterMax(_ScatterOp):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
class ScatterMin(_ScatterOp):
|
class ScatterMin(_ScatterOpDynamic):
|
||||||
r"""
|
r"""
|
||||||
Updates the value of the input tensor through the minimum operation.
|
Updates the value of the input tensor through the minimum operation.
|
||||||
|
|
||||||
|
|
|
@ -722,6 +722,13 @@ def test_scatter_func_disordered_dynamic_int32():
|
||||||
).astype(np.int32)
|
).astype(np.int32)
|
||||||
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
|
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
|
||||||
|
|
||||||
|
# max
|
||||||
|
output = scatter_func_d_net("max", inputx, indices, updates)
|
||||||
|
expected = np.array(
|
||||||
|
[[95.0, 96.0, 97.0, 98.0], [67.0, 68.0, 69.0, 70.0], [99.0, 100.0, 101.0, 102.0]]
|
||||||
|
).astype(np.int32)
|
||||||
|
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
|
||||||
|
|
||||||
# min
|
# min
|
||||||
output = scatter_func_d_net("min", inputx, indices, updates)
|
output = scatter_func_d_net("min", inputx, indices, updates)
|
||||||
expected = np.flip(np.arange(34, 46).reshape(3, 4).astype(np.int32))
|
expected = np.flip(np.arange(34, 46).reshape(3, 4).astype(np.int32))
|
||||||
|
@ -761,6 +768,13 @@ def test_scatter_func_disordered_dynamic_int8():
|
||||||
).astype(np.int8)
|
).astype(np.int8)
|
||||||
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
|
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
|
||||||
|
|
||||||
|
# max
|
||||||
|
output = scatter_func_d_net("max", inputx, indices, updates)
|
||||||
|
expected = np.array(
|
||||||
|
[[95.0, 96.0, 97.0, 98.0], [67.0, 68.0, 69.0, 70.0], [99.0, 100.0, 101.0, 102.0]]
|
||||||
|
).astype(np.int8)
|
||||||
|
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
|
||||||
|
|
||||||
# min
|
# min
|
||||||
output = scatter_func_d_net("min", inputx, indices, updates)
|
output = scatter_func_d_net("min", inputx, indices, updates)
|
||||||
expected = np.flip(np.arange(34, 46).reshape(3, 4).astype(np.int8))
|
expected = np.flip(np.arange(34, 46).reshape(3, 4).astype(np.int8))
|
||||||
|
@ -796,6 +810,13 @@ def test_scatter_func_disordered_dynamic_uint8():
|
||||||
).astype(np.uint8)
|
).astype(np.uint8)
|
||||||
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
|
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
|
||||||
|
|
||||||
|
# max
|
||||||
|
output = scatter_func_d_net("max", inputx, indices, updates)
|
||||||
|
expected = np.array(
|
||||||
|
[[95.0, 96.0, 97.0, 98.0], [67.0, 68.0, 69.0, 70.0], [99.0, 100.0, 101.0, 102.0]]
|
||||||
|
).astype(np.uint8)
|
||||||
|
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
|
||||||
|
|
||||||
# min
|
# min
|
||||||
output = scatter_func_d_net("min", inputx, indices, updates)
|
output = scatter_func_d_net("min", inputx, indices, updates)
|
||||||
expected = np.flip(np.arange(34, 46).reshape(3, 4).astype(np.uint8))
|
expected = np.flip(np.arange(34, 46).reshape(3, 4).astype(np.uint8))
|
||||||
|
@ -848,6 +869,13 @@ def test_scatter_func_input_less_than_1_dynamic_float32():
|
||||||
)
|
)
|
||||||
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
|
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
|
||||||
|
|
||||||
|
# max
|
||||||
|
output = scatter_func_d_net("max", inputx, indices, updates)
|
||||||
|
expected = np.array(
|
||||||
|
[[37.0, 38.0, 39.0], [34.0, 35.0, 66.0], [67.0, 68.0, 69.0],], dtype=np.float32,
|
||||||
|
)
|
||||||
|
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
|
||||||
|
|
||||||
# min
|
# min
|
||||||
output = scatter_func_d_net("min", inputx, indices, updates)
|
output = scatter_func_d_net("min", inputx, indices, updates)
|
||||||
expected = inputx_np
|
expected = inputx_np
|
||||||
|
@ -891,6 +919,15 @@ def test_scatter_func_dynamic_two_inputs():
|
||||||
np.testing.assert_array_almost_equal(output_1.asnumpy(), expected_1)
|
np.testing.assert_array_almost_equal(output_1.asnumpy(), expected_1)
|
||||||
np.testing.assert_array_almost_equal(output_2.asnumpy(), expected_2)
|
np.testing.assert_array_almost_equal(output_2.asnumpy(), expected_2)
|
||||||
|
|
||||||
|
# max
|
||||||
|
output_1, output_2 = scatter_func_d2_net(
|
||||||
|
"max", inputx, indices_1, updates_1, indices_2, updates_2
|
||||||
|
)
|
||||||
|
expected_1 = np.array([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]])
|
||||||
|
expected_2 = np.array([[17.0, 16.0, 15.0], [11.0, 10.0, 9.0]])
|
||||||
|
np.testing.assert_array_almost_equal(output_1.asnumpy(), expected_1)
|
||||||
|
np.testing.assert_array_almost_equal(output_2.asnumpy(), expected_2)
|
||||||
|
|
||||||
# min
|
# min
|
||||||
output_1, output_2 = scatter_func_d2_net(
|
output_1, output_2 = scatter_func_d2_net(
|
||||||
"min", inputx, indices_1, updates_1, indices_2, updates_2
|
"min", inputx, indices_1, updates_1, indices_2, updates_2
|
||||||
|
@ -902,13 +939,8 @@ def test_scatter_func_dynamic_two_inputs():
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test_scatter_func_small_float32()
|
test_scatter_func_disordered_dynamic_int32()
|
||||||
test_scatter_func_input_updated()
|
test_scatter_func_disordered_dynamic_int8()
|
||||||
test_scatter_func_large_shape_float32()
|
test_scatter_func_disordered_dynamic_uint8()
|
||||||
test_scatter_func_small_float32_use_locking_false()
|
test_scatter_func_input_less_than_1_dynamic_float32()
|
||||||
test_scatter_func_input_less_than_1_float32()
|
test_scatter_func_dynamic_two_inputs()
|
||||||
test_scatter_func_float16()
|
|
||||||
test_scatter_func_large_float16()
|
|
||||||
test_scatter_func_disordered_float16()
|
|
||||||
test_scatter_func_large_int32()
|
|
||||||
test_scatter_func_disordered_int32()
|
|
||||||
|
|
Loading…
Reference in New Issue