!44428 refactor ScatterAddWithAxis and SearchSorted operators kernelmod.

Merge pull request !44428 from yangshuo/br_scatter_add_with_axis
This commit is contained in:
i-robot 2022-10-25 13:23:49 +00:00 committed by Gitee
commit 9d365c662b
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
6 changed files with 165 additions and 42 deletions

View File

@ -20,6 +20,7 @@
#include <complex>
#include "kernel/common_utils.h"
#include "mindspore/core/ops/scatter_add_with_axis.h"
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
namespace {
@ -46,47 +47,68 @@ namespace {
.AddOutputAttr(kNumberType##t4)
const int32_t kInputNum = 3;
const int32_t kOutputNum = 1;
const uint32_t kInputIndex2 = 2;
const int32_t KSplitSize = 64 * 1024;
} // namespace
void ScatterAddWithAxisCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
x_type_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
indices_type_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 1);
if (indices_type_ != kNumberTypeInt32 && indices_type_ != kNumberTypeInt64) {
MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', the dtype of 'indices' must be int32 or int64, but got "
<< indices_type_;
bool ScatterAddWithAxisCpuKernelMod::Init(const BaseOperatorPtr &base_operator,
const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
MS_ERROR_IF_NULL(base_operator);
kernel_name_ = base_operator->name();
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kInputNum, kernel_name_);
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kOutputNum, kernel_name_);
auto op_prim = std::dynamic_pointer_cast<ops::ScatterAddWithAxis>(base_operator);
MS_ERROR_IF_NULL(op_prim);
axis_ = op_prim->get_axis();
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
auto is_match = MatchKernelAttr(kernel_attr, GetOpSupport());
if (!is_match.first) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', it does not support this kernel data type: " << kernel_attr;
return false;
}
return true;
}
// check parameters basic attribution are valid
x_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
indices_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
updates_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, kInputIndex2);
axis_ = common::AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "axis");
int ScatterAddWithAxisCpuKernelMod::Resize(const BaseOperatorPtr &base_operator,
const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &) {
auto ret = KernelMod::Resize(base_operator, inputs, outputs);
if (ret != KRET_OK) {
return ret;
}
x_type_ = inputs[kIndex0]->GetDtype();
indices_type_ = inputs[kIndex1]->GetDtype();
x_shape_ = inputs[kIndex0]->GetDeviceShapeAdaptively();
indices_shape_ = inputs[kIndex1]->GetDeviceShapeAdaptively();
updates_shape_ = inputs[kIndex2]->GetDeviceShapeAdaptively();
// Get and check 3 input dim info
int64_t value_dim_num_x1 = static_cast<int64_t>(x_shape_.size());
int64_t value_dim_num_x2 = static_cast<int64_t>(indices_shape_.size());
int64_t value_dim_num_x3 = static_cast<int64_t>(updates_shape_.size());
if (value_dim_num_x1 != value_dim_num_x2 || value_dim_num_x2 != value_dim_num_x3) {
MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', the dim values of three inputs must be same, but got "
MS_LOG(ERROR) << "For '" << kernel_name_ << "', the dim values of three inputs must be same, but got "
<< "data: " << value_dim_num_x1 << ", indices: " << value_dim_num_x2
<< ", update: " << value_dim_num_x3;
return KRET_RESIZE_FAILED;
}
if (axis_ < value_dim_num_x1 * -1 || axis_ >= value_dim_num_x1) {
MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', the value of axis is out of range!";
MS_LOG(ERROR) << "For '" << kernel_name_ << "', the value of axis is out of range!";
return KRET_RESIZE_FAILED;
}
int64_t sub_data_fix = 1;
int64_t sub_index_fix = 1;
data_dim_vec_.clear();
index_dim_vec_.clear();
for (int64_t i = value_dim_num_x2 - 1; i >= 0; --i) {
size_t j = static_cast<size_t>(i);
if (x_shape_[j] < indices_shape_[j] || indices_shape_[j] != updates_shape_[j] || updates_shape_[j] <= 0) {
MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', the " << j << " dimension verification failed: "
MS_LOG(ERROR) << "For '" << kernel_name_ << "', the " << j << " dimension verification failed: "
<< "input0[" << x_shape_[j] << "], input1[" << indices_shape_[j] << "], input2["
<< updates_shape_[j] << "]";
return KRET_RESIZE_FAILED;
}
if (i > 0) {
sub_data_fix *= x_shape_[j];
@ -95,6 +117,7 @@ void ScatterAddWithAxisCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
index_dim_vec_.push_back(sub_index_fix);
}
}
return KRET_OK;
}
bool ScatterAddWithAxisCpuKernelMod::Launch(const std::vector<AddressPtr> &inputs,

View File

@ -17,6 +17,7 @@
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SCATTER_ADD_WITH_AXIS_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SCATTER_ADD_WITH_AXIS_CPU_KERNEL_H_
#include <map>
#include <memory>
#include <string>
#include <vector>
@ -26,12 +27,16 @@
namespace mindspore {
namespace kernel {
class ScatterAddWithAxisCpuKernelMod : public DeprecatedNativeCpuKernelMod {
class ScatterAddWithAxisCpuKernelMod : public NativeCpuKernelMod {
public:
ScatterAddWithAxisCpuKernelMod() = default;
~ScatterAddWithAxisCpuKernelMod() override = default;
void InitKernel(const CNodePtr &kernel_node) override;
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override;
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;

View File

@ -21,6 +21,7 @@
#include <functional>
#include <algorithm>
#include <utility>
#include "mindspore/core/ops/search_sorted.h"
namespace mindspore {
namespace kernel {
@ -29,20 +30,35 @@ constexpr size_t kSearchSortedInputsNum = 2;
constexpr size_t kSearchSortedOutputsNum = 1;
} // namespace
void SearchSortedCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
right_ = common::AnfAlgo::GetNodeAttr<bool>(kernel_node, "right");
sequence_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
values_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
search_len = LongToSize(sequence_shape_.back());
auto kernel_attr = GetKernelAttrFromNode(kernel_node);
// Initializes the kernel: validates the operator and IO counts, caches the
// `right` attribute, and selects the dtype-specialized launch function.
bool SearchSortedCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
                                    const std::vector<KernelTensorPtr> &outputs) {
  MS_ERROR_IF_NULL(base_operator);
  kernel_name_ = base_operator->name();
  CHECK_KERNEL_INPUTS_NUM(inputs.size(), kSearchSortedInputsNum, kernel_name_);
  CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kSearchSortedOutputsNum, kernel_name_);
  auto op_prim = std::dynamic_pointer_cast<ops::SearchSorted>(base_operator);
  // Guard the downcast: dynamic_pointer_cast returns nullptr on a type
  // mismatch (mirrors the check done in ScatterAddWithAxisCpuKernelMod::Init).
  MS_ERROR_IF_NULL(op_prim);
  right_ = op_prim->get_right();
  auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
  auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
  if (!is_match) {
    MS_LOG(ERROR) << "SearchSorted does not support this kernel data type: " << kernel_attr;
    // Must report failure here: returning true would leave kernel_func_ unset
    // and crash at launch time.
    return false;
  }
  kernel_func_ = func_list_[index].second;
  return true;
}
int SearchSortedCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &) {
auto ret = KernelMod::Resize(base_operator, inputs, outputs);
if (ret != KRET_OK) {
return ret;
}
sequence_shape_ = inputs[kIndex0]->GetDeviceShapeAdaptively();
values_shape_ = inputs[kIndex1]->GetDeviceShapeAdaptively();
search_len_ = LongToSize(sequence_shape_.back());
return KRET_OK;
}
template <typename S>
@ -71,9 +87,9 @@ bool SearchSortedCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr
auto task = [this, &sequence, &values, &output, seq_dim, search_repeat](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
auto seq_start = (seq_dim == 1) ? sequence : sequence + (i / search_repeat) * search_len;
auto result = right_ ? std::upper_bound(seq_start, seq_start + search_len, values[i]) - seq_start
: CustomizedLowerBound<S>(seq_start, seq_start + search_len, values[i]) - seq_start;
auto seq_start = (seq_dim == 1) ? sequence : sequence + (i / search_repeat) * search_len_;
auto result = right_ ? std::upper_bound(seq_start, seq_start + search_len_, values[i]) - seq_start
: CustomizedLowerBound<S>(seq_start, seq_start + search_len_, values[i]) - seq_start;
output[i] = static_cast<T>(result);
}
};
@ -97,10 +113,10 @@ void SearchSortedCpuKernelMod::CheckParam(const std::vector<AddressPtr> &inputs,
int list_count = accumulate(sequence_shape_.begin(), sequence_shape_.end() - 1, 1, std::multiplies<int>());
auto task = [this, &sequence](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
for (size_t j = 0; j < search_len - 1; j++) {
if (sequence[i * search_len + j] > sequence[i * search_len + j + 1]) {
for (size_t j = 0; j < search_len_ - 1; j++) {
if (sequence[i * search_len_ + j] > sequence[i * search_len_ + j + 1]) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the input sequence must be forward sequence. But got "
<< sequence[i * search_len + j] << '>' << sequence[i * search_len + j + 1];
<< sequence[i * search_len_ + j] << '>' << sequence[i * search_len_ + j + 1];
}
}
}

View File

@ -17,6 +17,7 @@
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SEARCHSORTED_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SEARCHSORTED_CPU_KERNEL_H_
#include <map>
#include <vector>
#include <utility>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
@ -24,12 +25,16 @@
namespace mindspore {
namespace kernel {
class SearchSortedCpuKernelMod : public DeprecatedNativeCpuKernelMod {
class SearchSortedCpuKernelMod : public NativeCpuKernelMod {
public:
SearchSortedCpuKernelMod() = default;
~SearchSortedCpuKernelMod() override = default;
void InitKernel(const CNodePtr &kernel_node) override;
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override;
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) override {
@ -51,7 +56,7 @@ class SearchSortedCpuKernelMod : public DeprecatedNativeCpuKernelMod {
SearchSortedFunc kernel_func_;
bool right_{false};
size_t search_len{0};
size_t search_len_{0};
std::vector<int64_t> sequence_shape_;
std::vector<int64_t> values_shape_;
std::vector<int64_t> output_shape_;

View File

@ -0,0 +1,30 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/search_sorted.h"
#include <string>
#include "mindapi/src/helper.h"
#include "ops/op_utils.h"
namespace mindspore {
namespace ops {
MIND_API_OPERATOR_IMPL(SearchSorted, BaseOperator);

// Attribute key under which the `right` flag is stored on the primitive.
constexpr auto kNameRight = "right";

// Stores the `right` flag as a primitive attribute.
void SearchSorted::set_right(const bool right) {
  (void)AddAttr(kNameRight, api::MakeValue(right));
}

// Reads the `right` flag back from the primitive attributes.
// NOTE(review): presumably the attribute is always set before this getter is
// called; GetAttr on a missing key would not yield a valid bool — confirm.
bool SearchSorted::get_right() const {
  return GetValue<bool>(GetAttr(kNameRight));
}

REGISTER_PRIMITIVE_C(kSearchSorted, SearchSorted);
}  // namespace ops
}  // namespace mindspore

View File

@ -0,0 +1,44 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_SEARCH_SORTED_H_
#define MINDSPORE_CORE_OPS_SEARCH_SORTED_H_
#include <memory>
#include <vector>
#include "mindapi/base/types.h"
#include "ops/base_operator.h"
namespace mindspore {
namespace ops {
constexpr auto kSearchSorted = "SearchSorted";
/// \brief Finds the positions in a sorted `sequence` at which the given
/// `values` could be inserted while keeping the sequence sorted.
/// (Previous brief described a scatter-update op — copy-paste error.)
/// Refer to Python API @ref mindspore.ops.SearchSorted for more details.
class MIND_API SearchSorted : public BaseOperator {
 public:
  MIND_API_BASE_MEMBER(SearchSorted);
  /// \brief Constructor. Declares inputs {"sequence", "values"} and output {"positions"}.
  SearchSorted() : BaseOperator(kSearchSorted) { InitIOName({"sequence", "values"}, {"positions"}); }
  /// \brief Set the `right` attribute (when true, the CPU kernel uses the
  /// upper-bound insertion index instead of the lower bound).
  void set_right(const bool right);
  /// \brief Get the `right` attribute.
  bool get_right() const;
};
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_SEARCH_SORTED_H_