forked from mindspore-Ecosystem/mindspore
scatter_min: add c++ implementation to adapt to dynamic shape
This commit is contained in:
parent
244e078ad3
commit
9cc94628fa
mindspore
tests/st/ops/gpu
|
@ -0,0 +1,89 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "ops/scatter_min.h"
|
||||
#include <set>
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include "abstract/ops/primitive_infer_map.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
abstract::ShapePtr ScatterMinInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
BaseShapePtr input_x_shape_ptr = input_args[kInputIndex0]->BuildShape();
|
||||
MS_EXCEPTION_IF_NULL(input_x_shape_ptr);
|
||||
BaseShapePtr indices_shape_ptr = input_args[kInputIndex1]->BuildShape();
|
||||
MS_EXCEPTION_IF_NULL(indices_shape_ptr);
|
||||
BaseShapePtr updates_shape_ptr = input_args[kInputIndex2]->BuildShape();
|
||||
MS_EXCEPTION_IF_NULL(updates_shape_ptr);
|
||||
|
||||
if (input_x_shape_ptr->IsDynamic() || indices_shape_ptr->IsDynamic() || updates_shape_ptr->IsDynamic()) {
|
||||
return input_args[kInputIndex0]->BuildShape()->cast<abstract::ShapePtr>();
|
||||
}
|
||||
|
||||
std::vector<int64_t> input_x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_x_shape_ptr)[kShape];
|
||||
std::vector<int64_t> indices_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(indices_shape_ptr)[kShape];
|
||||
std::vector<int64_t> updates_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(updates_shape_ptr)[kShape];
|
||||
std::vector<int64_t> check_update_shape(indices_shape);
|
||||
for (int64_t i = 1; i < SizeToLong(input_x_shape.size()); ++i) {
|
||||
check_update_shape.push_back(input_x_shape[i]);
|
||||
}
|
||||
if (updates_shape != check_update_shape) {
|
||||
MS_EXCEPTION(ValueError) << "For " << primitive->name() << ", "
|
||||
<< "updates_shape = indices_shape + x_shape[1:], but got x_shape: "
|
||||
<< input_x_shape_ptr->ToString() << ", indices_shape: " << indices_shape_ptr->ToString()
|
||||
<< ", updates_shape: " << updates_shape_ptr->ToString() << ".";
|
||||
}
|
||||
|
||||
auto output_shape = input_args[kInputIndex0]->BuildShape()->cast<abstract::ShapePtr>();
|
||||
return output_shape;
|
||||
}
|
||||
|
||||
TypePtr ScatterMinInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto input_x_type_ptr = input_args[kInputIndex0]->BuildType();
|
||||
auto indiecs_type_ptr = input_args[kInputIndex1]->BuildType();
|
||||
auto updates_type_ptr = input_args[kInputIndex2]->BuildType();
|
||||
auto prim_name = primitive->name();
|
||||
std::set<TypePtr> type_set = {kInt32};
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("indices type", indiecs_type_ptr, type_set, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("input_x type", input_x_type_ptr, common_valid_types_with_complex,
|
||||
prim_name);
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("updates type", updates_type_ptr, common_valid_types_with_complex,
|
||||
prim_name);
|
||||
|
||||
std::map<std::string, TypePtr> type_dict;
|
||||
type_dict.emplace("input_x", input_x_type_ptr);
|
||||
type_dict.emplace("updates", updates_type_ptr);
|
||||
return CheckAndConvertUtils::CheckTensorTypeSame(type_dict, common_valid_types, prim_name);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
MIND_API_OPERATOR_IMPL(ScatterMin, BaseOperator);
|
||||
AbstractBasePtr ScatterMinInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
const int64_t input_num = 3;
|
||||
(void)CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, input_num, primitive->name());
|
||||
auto infer_type = ScatterMinInferType(primitive, input_args);
|
||||
auto infer_shape = ScatterMinInferShape(primitive, input_args);
|
||||
return abstract::MakeAbstract(infer_shape, infer_type);
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(ScatterMin, prim::kPrimScatterMin, ScatterMinInfer, nullptr, true);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,39 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_OPS_SCATTER_MIN_H_
|
||||
#define MINDSPORE_CORE_OPS_SCATTER_MIN_H_
|
||||
|
||||
#include <vector>
|
||||
#include "ops/base_operator.h"
|
||||
#include "mindapi/base/types.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameScatterMin = "ScatterMin";
|
||||
class MIND_API ScatterMin : public BaseOperator {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(ScatterMin);
|
||||
/// \brief Constructor.
|
||||
ScatterMin() : BaseOperator(kNameScatterMin) { InitIOName({"input_x", "indices", "updates"}, {"output"}); }
|
||||
};
|
||||
|
||||
abstract::AbstractBasePtr ScatterMinInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_OPS_SCATTER_MIN_H_
|
|
@ -626,18 +626,31 @@ def scatter_max(input_x, indices, updates):
|
|||
return scatter_max_(input_x, indices, updates)
|
||||
|
||||
|
||||
def scatter_min(input_x, indices, updates, use_locking=False):
|
||||
r"""
|
||||
scatter_min_ = P.ScatterMin()
|
||||
def scatter_min(input_x, indices, updates):
|
||||
"""
|
||||
Updates the value of the input tensor through the minimum operation.
|
||||
|
||||
Using given values to update tensor value through the min operation, along with the input indices.
|
||||
This operation outputs the `input_x` after the update is done, which makes it convenient to use the updated value.
|
||||
|
||||
for each `i, ..., j` in `indices.shape`:
|
||||
|
||||
.. math::
|
||||
|
||||
\text{input_x}[\text{indices}[i, ..., j], :]
|
||||
= min(\text{input_x}[\text{indices}[i, ..., j], :], \text{updates}[i, ..., j, :])
|
||||
|
||||
Inputs of `input_x` and `updates` comply with the implicit type conversion rules to make the data types consistent.
|
||||
If they have different data types, the lower priority data type will be converted to
|
||||
the relatively highest priority data type.
|
||||
|
||||
Args:
|
||||
- **input_x** (Parameter) - The target tensor, with data type of Parameter.
|
||||
The shape is :math:`(N,*)` where :math:`*` means,any number of additional dimensions.
|
||||
- **indices** (Tensor) - The index to do min operation whose data type must be mindspore.int32.
|
||||
- **updates** (Tensor) - The tensor doing the min operation with `input_x`,
|
||||
the data type is same as `input_x`, the shape is `indices.shape + x.shape[1:]`.
|
||||
- use_locking (bool): Whether to protect the assignment by a lock. Default: False.
|
||||
|
||||
Outputs:
|
||||
Tensor, the updated `input_x`, has the same shape and type as `input_x`.
|
||||
|
@ -666,7 +679,7 @@ def scatter_min(input_x, indices, updates, use_locking=False):
|
|||
[[0. 0. 0.]
|
||||
[0. 0. 0.]]
|
||||
"""
|
||||
return P.ScatterMin(use_locking)(input_x, indices, updates)
|
||||
return scatter_min_(input_x, indices, updates)
|
||||
|
||||
|
||||
scatter_nd_ = P.ScatterNd()
|
||||
|
|
|
@ -722,6 +722,11 @@ def test_scatter_func_disordered_dynamic_int32():
|
|||
).astype(np.int32)
|
||||
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
|
||||
|
||||
# min
|
||||
output = scatter_func_d_net("min", inputx, indices, updates)
|
||||
expected = np.flip(np.arange(34, 46).reshape(3, 4).astype(np.int32))
|
||||
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
|
@ -756,6 +761,11 @@ def test_scatter_func_disordered_dynamic_int8():
|
|||
).astype(np.int8)
|
||||
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
|
||||
|
||||
# min
|
||||
output = scatter_func_d_net("min", inputx, indices, updates)
|
||||
expected = np.flip(np.arange(34, 46).reshape(3, 4).astype(np.int8))
|
||||
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
|
@ -786,20 +796,24 @@ def test_scatter_func_disordered_dynamic_uint8():
|
|||
).astype(np.uint8)
|
||||
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
|
||||
|
||||
# min
|
||||
output = scatter_func_d_net("min", inputx, indices, updates)
|
||||
expected = np.flip(np.arange(34, 46).reshape(3, 4).astype(np.uint8))
|
||||
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_scatter_func_input_less_than_1_dynamic_float32():
|
||||
inputx = Tensor(
|
||||
np.array(
|
||||
[
|
||||
[0.214141, 0.415151, 0.51516],
|
||||
[0.876542, 0.451611, 0.55112],
|
||||
[0.111244, 0.633333, 0.34444],
|
||||
]
|
||||
).astype(np.float32)
|
||||
)
|
||||
inputx_np = np.array(
|
||||
[
|
||||
[0.214141, 0.415151, 0.51516],
|
||||
[0.876542, 0.451611, 0.55112],
|
||||
[0.111244, 0.633333, 0.34444],
|
||||
]
|
||||
).astype(np.float32)
|
||||
inputx = Tensor(inputx_np)
|
||||
indices = Tensor(np.array([[[1, 0, 2], [2, 2, 0]], [[1, 0, 1], [2, 1, 2]]]).astype(np.int32))
|
||||
updates = Tensor(np.arange(34, 70).reshape((2, 2, 3, 3)).astype(np.float32))
|
||||
|
||||
|
@ -834,6 +848,11 @@ def test_scatter_func_input_less_than_1_dynamic_float32():
|
|||
)
|
||||
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
|
||||
|
||||
# min
|
||||
output = scatter_func_d_net("min", inputx, indices, updates)
|
||||
expected = inputx_np
|
||||
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
|
@ -872,6 +891,16 @@ def test_scatter_func_dynamic_two_inputs():
|
|||
np.testing.assert_array_almost_equal(output_1.asnumpy(), expected_1)
|
||||
np.testing.assert_array_almost_equal(output_2.asnumpy(), expected_2)
|
||||
|
||||
# min
|
||||
output_1, output_2 = scatter_func_d2_net(
|
||||
"min", inputx, indices_1, updates_1, indices_2, updates_2
|
||||
)
|
||||
expected_1 = np.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
|
||||
expected_2 = expected_1
|
||||
np.testing.assert_array_almost_equal(output_1.asnumpy(), expected_1)
|
||||
np.testing.assert_array_almost_equal(output_2.asnumpy(), expected_2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_scatter_func_small_float32()
|
||||
test_scatter_func_input_updated()
|
||||
|
|
Loading…
Reference in New Issue