!44428 refactor ScatterAddWithAxis and SearchSorted operators' KernelMod.
Merge pull request !44428 from yangshuo/br_scatter_add_with_axis
commit 9d365c662b
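This change migrates both CPU kernels from the deprecated InitKernel(CNodePtr) entry point to the Init/Resize/Launch interface of NativeCpuKernelMod: Init performs the shape-independent setup once (operator attributes, dtype matching), while Resize recomputes shape-dependent state every time the input shapes change, which is what dynamic-shape execution needs. The standalone sketch below mirrors that two-phase split; ScatterAddLikeKernelStub, its members, and the KRET_* enum are invented stand-ins for illustration, not MindSpore's real classes.

#include <cstdint>
#include <iostream>
#include <vector>

// Return codes mimicking the KRET_* values used in the diff (stand-ins).
enum RetCode { KRET_OK = 0, KRET_RESIZE_FAILED = 1 };

// A stand-in for a "new style" kernel mod: attributes in Init, shapes in Resize.
class ScatterAddLikeKernelStub {
 public:
  // Init: called once; only reads static operator attributes.
  bool Init(int64_t axis) {
    axis_ = axis;
    return true;
  }

  // Resize: called whenever shapes change; validates and caches shape info.
  int Resize(const std::vector<int64_t> &x_shape, const std::vector<int64_t> &indices_shape) {
    if (x_shape.size() != indices_shape.size()) {
      std::cerr << "rank mismatch\n";
      return KRET_RESIZE_FAILED;
    }
    int64_t rank = static_cast<int64_t>(x_shape.size());
    if (axis_ < -rank || axis_ >= rank) {
      std::cerr << "axis out of range\n";
      return KRET_RESIZE_FAILED;
    }
    x_shape_ = x_shape;
    indices_shape_ = indices_shape;
    return KRET_OK;
  }

  // Launch would run the actual computation using the cached state.
  bool Launch() const {
    std::cout << "launch with axis " << axis_ << " and rank " << x_shape_.size() << "\n";
    return true;
  }

 private:
  int64_t axis_{0};
  std::vector<int64_t> x_shape_;
  std::vector<int64_t> indices_shape_;
};

int main() {
  ScatterAddLikeKernelStub kernel;
  kernel.Init(/*axis=*/0);        // once per kernel instance
  kernel.Resize({4, 3}, {2, 3});  // once per shape change
  kernel.Launch();                // once per execution
  kernel.Resize({8, 3}, {2, 3});  // dynamic shape: only Resize reruns
  kernel.Launch();
  return 0;
}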
@@ -20,6 +20,7 @@
#include <complex>
#include "kernel/common_utils.h"
#include "mindspore/core/ops/scatter_add_with_axis.h"
#include "plugin/device/cpu/hal/device/cpu_device_address.h"

namespace {
@@ -46,47 +47,68 @@ namespace {
  .AddOutputAttr(kNumberType##t4)
const int32_t kInputNum = 3;
const int32_t kOutputNum = 1;
const uint32_t kInputIndex2 = 2;
const int32_t KSplitSize = 64 * 1024;
} // namespace

void ScatterAddWithAxisCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
  MS_EXCEPTION_IF_NULL(kernel_node);
  kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
  x_type_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
  indices_type_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 1);
  if (indices_type_ != kNumberTypeInt32 && indices_type_ != kNumberTypeInt64) {
    MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', the dtype of 'indices' must be int32 or int64, but got "
                             << indices_type_;
bool ScatterAddWithAxisCpuKernelMod::Init(const BaseOperatorPtr &base_operator,
                                          const std::vector<KernelTensorPtr> &inputs,
                                          const std::vector<KernelTensorPtr> &outputs) {
  MS_ERROR_IF_NULL(base_operator);
  kernel_name_ = base_operator->name();
  CHECK_KERNEL_INPUTS_NUM(inputs.size(), kInputNum, kernel_name_);
  CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kOutputNum, kernel_name_);
  auto op_prim = std::dynamic_pointer_cast<ops::ScatterAddWithAxis>(base_operator);
  MS_ERROR_IF_NULL(op_prim);
  axis_ = op_prim->get_axis();
  auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
  auto is_match = MatchKernelAttr(kernel_attr, GetOpSupport());
  if (!is_match.first) {
    MS_LOG(ERROR) << "For '" << kernel_name_ << "', it does not support this kernel data type: " << kernel_attr;
    return false;
  }
  return true;
}

  // check parameters basic attribution are valid
  x_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
  indices_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
  updates_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, kInputIndex2);
  axis_ = common::AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "axis");
int ScatterAddWithAxisCpuKernelMod::Resize(const BaseOperatorPtr &base_operator,
                                           const std::vector<KernelTensorPtr> &inputs,
                                           const std::vector<KernelTensorPtr> &outputs,
                                           const std::map<uint32_t, tensor::TensorPtr> &) {
  auto ret = KernelMod::Resize(base_operator, inputs, outputs);
  if (ret != KRET_OK) {
    return ret;
  }
  x_type_ = inputs[kIndex0]->GetDtype();
  indices_type_ = inputs[kIndex1]->GetDtype();
  x_shape_ = inputs[kIndex0]->GetDeviceShapeAdaptively();
  indices_shape_ = inputs[kIndex1]->GetDeviceShapeAdaptively();
  updates_shape_ = inputs[kIndex2]->GetDeviceShapeAdaptively();

  // Get and check 3 input dim info
  int64_t value_dim_num_x1 = static_cast<int64_t>(x_shape_.size());
  int64_t value_dim_num_x2 = static_cast<int64_t>(indices_shape_.size());
  int64_t value_dim_num_x3 = static_cast<int64_t>(updates_shape_.size());
  if (value_dim_num_x1 != value_dim_num_x2 || value_dim_num_x2 != value_dim_num_x3) {
    MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', the dim values of three inputs must be same, but got "
                             << "data: " << value_dim_num_x1 << ", indices: " << value_dim_num_x2
                             << ", update: " << value_dim_num_x3;
    MS_LOG(ERROR) << "For '" << kernel_name_ << "', the dim values of three inputs must be same, but got "
                  << "data: " << value_dim_num_x1 << ", indices: " << value_dim_num_x2
                  << ", update: " << value_dim_num_x3;
    return KRET_RESIZE_FAILED;
  }
  if (axis_ < value_dim_num_x1 * -1 || axis_ >= value_dim_num_x1) {
    MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', the value of axis is out of range!";
    MS_LOG(ERROR) << "For '" << kernel_name_ << "', the value of axis is out of range!";
    return KRET_RESIZE_FAILED;
  }

  int64_t sub_data_fix = 1;
  int64_t sub_index_fix = 1;
  data_dim_vec_.clear();
  index_dim_vec_.clear();
  for (int64_t i = value_dim_num_x2 - 1; i >= 0; --i) {
    size_t j = static_cast<size_t>(i);
    if (x_shape_[j] < indices_shape_[j] || indices_shape_[j] != updates_shape_[j] || updates_shape_[j] <= 0) {
      MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', the " << j << " dimension verification failed: "
                               << "input0[" << x_shape_[j] << "], input1[" << indices_shape_[j] << "], input2["
                               << updates_shape_[j] << "]";
      MS_LOG(ERROR) << "For '" << kernel_name_ << "', the " << j << " dimension verification failed: "
                    << "input0[" << x_shape_[j] << "], input1[" << indices_shape_[j] << "], input2["
                    << updates_shape_[j] << "]";
      return KRET_RESIZE_FAILED;
    }
    if (i > 0) {
      sub_data_fix *= x_shape_[j];
@@ -95,6 +117,7 @@ void ScatterAddWithAxisCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
      index_dim_vec_.push_back(sub_index_fix);
    }
  }
  return KRET_OK;
}

bool ScatterAddWithAxisCpuKernelMod::Launch(const std::vector<AddressPtr> &inputs,
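For context, ScatterAddWithAxis follows ScatterElements-style semantics with an add reduction: each element of updates is accumulated into the output at the position of the matching indices element, with only the coordinate along axis replaced by that index value. Below is a minimal 2-D sketch of that behavior, assuming row-major storage; ScatterAddWithAxis2D is a hypothetical helper for illustration, not the kernel's Launch code.

#include <cstdint>
#include <iostream>
#include <vector>

// Scatter-add along an axis for 2-D tensors in row-major storage.
// y starts as a copy of x; then y[index][j] += updates[i][j] (axis == 0)
// or y[i][index] += updates[i][j] (axis == 1).
std::vector<float> ScatterAddWithAxis2D(const std::vector<float> &x, int64_t rows, int64_t cols,
                                        const std::vector<int64_t> &indices,
                                        const std::vector<float> &updates, int64_t idx_rows,
                                        int64_t idx_cols, int64_t axis) {
  if (axis < 0) axis += 2;  // normalize a negative axis, as the Resize range check allows
  std::vector<float> y = x;
  for (int64_t i = 0; i < idx_rows; ++i) {
    for (int64_t j = 0; j < idx_cols; ++j) {
      int64_t idx = indices[i * idx_cols + j];
      int64_t r = (axis == 0) ? idx : i;
      int64_t c = (axis == 0) ? j : idx;
      y[r * cols + c] += updates[i * idx_cols + j];
    }
  }
  return y;
}

int main() {
  // x is 3x3 zeros; scatter a 2x3 block of ones along axis 0.
  std::vector<float> x(9, 0.0f);
  std::vector<int64_t> indices = {0, 1, 2, 2, 0, 1};
  std::vector<float> updates(6, 1.0f);
  auto y = ScatterAddWithAxis2D(x, 3, 3, indices, updates, 2, 3, 0);
  for (int64_t r = 0; r < 3; ++r) {
    for (int64_t c = 0; c < 3; ++c) std::cout << y[r * 3 + c] << ' ';
    std::cout << '\n';
  }
  return 0;
}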
@@ -17,6 +17,7 @@
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SCATTER_ADD_WITH_AXIS_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SCATTER_ADD_WITH_AXIS_CPU_KERNEL_H_

#include <map>
#include <memory>
#include <string>
#include <vector>
@@ -26,12 +27,16 @@

namespace mindspore {
namespace kernel {
class ScatterAddWithAxisCpuKernelMod : public DeprecatedNativeCpuKernelMod {
class ScatterAddWithAxisCpuKernelMod : public NativeCpuKernelMod {
 public:
  ScatterAddWithAxisCpuKernelMod() = default;
  ~ScatterAddWithAxisCpuKernelMod() override = default;

  void InitKernel(const CNodePtr &kernel_node) override;
  bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
            const std::vector<KernelTensorPtr> &outputs) override;

  int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
             const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs) override;
@@ -21,6 +21,7 @@
#include <functional>
#include <algorithm>
#include <utility>
#include "mindspore/core/ops/search_sorted.h"

namespace mindspore {
namespace kernel {
@@ -29,20 +30,35 @@ constexpr size_t kSearchSortedInputsNum = 2;
constexpr size_t kSearchSortedOutputsNum = 1;
} // namespace

void SearchSortedCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
  MS_EXCEPTION_IF_NULL(kernel_node);
  kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
  right_ = common::AnfAlgo::GetNodeAttr<bool>(kernel_node, "right");
  sequence_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
  values_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
  search_len = LongToSize(sequence_shape_.back());

  auto kernel_attr = GetKernelAttrFromNode(kernel_node);
bool SearchSortedCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
                                    const std::vector<KernelTensorPtr> &outputs) {
  MS_ERROR_IF_NULL(base_operator);
  kernel_name_ = base_operator->name();
  CHECK_KERNEL_INPUTS_NUM(inputs.size(), kSearchSortedInputsNum, kernel_name_);
  CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kSearchSortedOutputsNum, kernel_name_);
  auto op_prim = std::dynamic_pointer_cast<ops::SearchSorted>(base_operator);
  right_ = op_prim->get_right();
  auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
  auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
  if (!is_match) {
    MS_LOG(EXCEPTION) << "SearchSorted does not support this kernel data type: " << kernel_attr;
    MS_LOG(ERROR) << "SearchSorted does not support this kernel data type: " << kernel_attr;
    return true;
  }
  kernel_func_ = func_list_[index].second;
  return true;
}

int SearchSortedCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
                                     const std::vector<KernelTensorPtr> &outputs,
                                     const std::map<uint32_t, tensor::TensorPtr> &) {
  auto ret = KernelMod::Resize(base_operator, inputs, outputs);
  if (ret != KRET_OK) {
    return ret;
  }
  sequence_shape_ = inputs[kIndex0]->GetDeviceShapeAdaptively();
  values_shape_ = inputs[kIndex1]->GetDeviceShapeAdaptively();
  search_len_ = LongToSize(sequence_shape_.back());
  return KRET_OK;
}

template <typename S>
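In the Init above, kernel_func_ is selected from func_list_ by the index that MatchKernelAttr returns, i.e. a table mapping each supported dtype combination to the corresponding templated LaunchKernel instantiation. A simplified standalone version of that lookup pattern is sketched below; the string labels, LaunchKernelStub, and KernelFunc are stand-ins rather than the real KernelAttr machinery.

#include <cstdint>
#include <functional>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

// A templated "launch" body, analogous to SearchSortedCpuKernelMod::LaunchKernel<S, T>.
template <typename S, typename T>
bool LaunchKernelStub(const std::vector<S> &sequence, const std::vector<S> &values,
                      std::vector<T> *output) {
  output->clear();
  for (const S &v : values) {
    int64_t pos = 0;
    while (pos < static_cast<int64_t>(sequence.size()) && sequence[pos] < v) ++pos;
    output->push_back(static_cast<T>(pos));
  }
  return true;
}

using KernelFunc = std::function<bool(const std::vector<float> &, const std::vector<float> &,
                                      std::vector<int32_t> *)>;

int main() {
  // func_list_-style table: a dtype label paired with the matching instantiation.
  std::vector<std::pair<std::string, KernelFunc>> func_list = {
      {"float32->int32", &LaunchKernelStub<float, int32_t>},
  };
  // "MatchKernelAttr" boiled down to a lookup by label.
  const std::string requested = "float32->int32";
  KernelFunc kernel_func;
  for (const auto &entry : func_list) {
    if (entry.first == requested) kernel_func = entry.second;
  }
  if (!kernel_func) {
    std::cerr << "unsupported kernel data type: " << requested << '\n';
    return 1;
  }
  std::vector<int32_t> out;
  kernel_func({1.0f, 3.0f, 5.0f}, {0.5f, 3.0f, 6.0f}, &out);
  for (int32_t p : out) std::cout << p << ' ';  // prints: 0 1 3
  std::cout << '\n';
  return 0;
}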
@@ -71,9 +87,9 @@ bool SearchSortedCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr

  auto task = [this, &sequence, &values, &output, seq_dim, search_repeat](size_t start, size_t end) {
    for (size_t i = start; i < end; i++) {
      auto seq_start = (seq_dim == 1) ? sequence : sequence + (i / search_repeat) * search_len;
      auto result = right_ ? std::upper_bound(seq_start, seq_start + search_len, values[i]) - seq_start
                           : CustomizedLowerBound<S>(seq_start, seq_start + search_len, values[i]) - seq_start;
      auto seq_start = (seq_dim == 1) ? sequence : sequence + (i / search_repeat) * search_len_;
      auto result = right_ ? std::upper_bound(seq_start, seq_start + search_len_, values[i]) - seq_start
                           : CustomizedLowerBound<S>(seq_start, seq_start + search_len_, values[i]) - seq_start;
      output[i] = static_cast<T>(result);
    }
  };
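Functionally, right_ only changes which bound is reported for values already present in the sequence: std::upper_bound (right=true) gives the position one past the last equal element, while the right=false branch goes through CustomizedLowerBound, which by its name plays the role of std::lower_bound and gives the position of the first equal element. A small standalone comparison using the standard algorithms:

#include <algorithm>
#include <iostream>
#include <vector>

int main() {
  const std::vector<int> sequence = {1, 3, 3, 3, 7};  // sorted, with duplicates
  const std::vector<int> values = {3, 4, 0, 9};

  for (int v : values) {
    auto left = std::lower_bound(sequence.begin(), sequence.end(), v) - sequence.begin();
    auto right = std::upper_bound(sequence.begin(), sequence.end(), v) - sequence.begin();
    std::cout << "value " << v << ": left=" << left << " right=" << right << '\n';
  }
  // value 3: left=1 right=4   (duplicates: first vs. one-past-last equal element)
  // value 4: left=4 right=4   (not present: both give the insertion point)
  // value 0: left=0 right=0
  // value 9: left=5 right=5
  return 0;
}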
@@ -97,10 +113,10 @@ void SearchSortedCpuKernelMod::CheckParam(const std::vector<AddressPtr> &inputs,
  int list_count = accumulate(sequence_shape_.begin(), sequence_shape_.end() - 1, 1, std::multiplies<int>());
  auto task = [this, &sequence](size_t start, size_t end) {
    for (size_t i = start; i < end; i++) {
      for (size_t j = 0; j < search_len - 1; j++) {
        if (sequence[i * search_len + j] > sequence[i * search_len + j + 1]) {
      for (size_t j = 0; j < search_len_ - 1; j++) {
        if (sequence[i * search_len_ + j] > sequence[i * search_len_ + j + 1]) {
          MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the input sequence must be forward sequence. But got "
                            << sequence[i * search_len + j] << '>' << sequence[i * search_len + j + 1];
                            << sequence[i * search_len_ + j] << '>' << sequence[i * search_len_ + j + 1];
        }
      }
    }
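CheckParam walks each innermost row of sequence and raises if any adjacent pair decreases. The same per-row monotonicity test can be written with std::is_sorted over a flat row-major buffer, mirroring the i * search_len_ + j indexing above; a short standalone sketch with made-up sample data:

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <vector>

int main() {
  const size_t search_len = 4;                       // innermost dimension of the sequence
  const std::vector<float> sequence = {1, 2, 2, 5,   // row 0: non-decreasing, ok
                                       3, 1, 4, 6};  // row 1: 3 > 1, not a forward sequence
  const size_t rows = sequence.size() / search_len;
  for (size_t i = 0; i < rows; ++i) {
    auto first = sequence.begin() + i * search_len;
    if (!std::is_sorted(first, first + search_len)) {
      std::cout << "row " << i << " is not a forward (non-decreasing) sequence\n";
    }
  }
  return 0;
}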
@@ -17,6 +17,7 @@
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SEARCHSORTED_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SEARCHSORTED_CPU_KERNEL_H_

#include <map>
#include <vector>
#include <utility>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
@@ -24,12 +25,16 @@

namespace mindspore {
namespace kernel {
class SearchSortedCpuKernelMod : public DeprecatedNativeCpuKernelMod {
class SearchSortedCpuKernelMod : public NativeCpuKernelMod {
 public:
  SearchSortedCpuKernelMod() = default;
  ~SearchSortedCpuKernelMod() override = default;

  void InitKernel(const CNodePtr &kernel_node) override;
  bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
            const std::vector<KernelTensorPtr> &outputs) override;

  int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
             const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
              const std::vector<AddressPtr> &outputs) override {
@@ -51,7 +56,7 @@ class SearchSortedCpuKernelMod : public DeprecatedNativeCpuKernelMod {
  SearchSortedFunc kernel_func_;

  bool right_{false};
  size_t search_len{0};
  size_t search_len_{0};
  std::vector<int64_t> sequence_shape_;
  std::vector<int64_t> values_shape_;
  std::vector<int64_t> output_shape_;
@@ -0,0 +1,30 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "ops/search_sorted.h"
#include <string>
#include "mindapi/src/helper.h"
#include "ops/op_utils.h"

namespace mindspore {
namespace ops {
constexpr auto kNameRight = "right";
void SearchSorted::set_right(const bool right) { (void)AddAttr(kNameRight, api::MakeValue(right)); }
bool SearchSorted::get_right() const { return GetValue<bool>(GetAttr(kNameRight)); }
MIND_API_OPERATOR_IMPL(SearchSorted, BaseOperator);
REGISTER_PRIMITIVE_C(kSearchSorted, SearchSorted);
} // namespace ops
} // namespace mindspore
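set_right and get_right above simply round-trip the boolean through the primitive's attribute table under the key "right", which is what the CPU kernel's Init reads back via get_right(). The stand-in below illustrates that name-to-value pattern with a plain std::map; SearchSortedLikeOp is illustrative only, not the real BaseOperator/AddAttr/GetAttr API.

#include <iostream>
#include <map>
#include <string>

// Stand-in for a BaseOperator-style attribute holder.
class SearchSortedLikeOp {
 public:
  void set_right(bool right) { attrs_["right"] = right; }
  bool get_right() const {
    auto it = attrs_.find("right");
    return it != attrs_.end() ? it->second : false;  // default matches right_{false}
  }

 private:
  std::map<std::string, bool> attrs_;  // the real op stores generic Values, not just bool
};

int main() {
  SearchSortedLikeOp op;
  op.set_right(true);
  std::cout << std::boolalpha << op.get_right() << '\n';  // true
  return 0;
}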
@@ -0,0 +1,44 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CORE_OPS_SEARCH_SORTED_H_
#define MINDSPORE_CORE_OPS_SEARCH_SORTED_H_
#include <memory>
#include <vector>

#include "mindapi/base/types.h"
#include "ops/base_operator.h"

namespace mindspore {
namespace ops {
constexpr auto kSearchSorted = "SearchSorted";
/// \brief Updates tensor values by using input indices and value.
/// Refer to Python API @ref mindspore.ops.SearchSorted for more details.
class MIND_API SearchSorted : public BaseOperator {
 public:
  MIND_API_BASE_MEMBER(SearchSorted);
  /// \brief Constructor.
  SearchSorted() : BaseOperator(kSearchSorted) { InitIOName({"sequence", "values"}, {"positions"}); }

  /// \brief Set right.
  void set_right(const bool right);
  /// \brief Get right.
  bool get_right() const;
};
} // namespace ops
} // namespace mindspore

#endif // MINDSPORE_CORE_OPS_SEARCH_SORTED_H_