forked from mindspore-Ecosystem/mindspore
!44428 refactor ScatterAddWithAxis and SearchSorted operators' KernelMod.
Merge pull request !44428 from yangshuo/br_scatter_add_with_axis
Commit: 9d365c662b
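The change migrates both CPU kernels from the deprecated InitKernel(const CNodePtr &) entry point to the Init/Resize/Launch KernelMod interface: Init reads static operator attributes once and validates dtype support, Resize recomputes shape-dependent state whenever the (possibly dynamic) input shapes become known and reports failure instead of throwing, and Launch performs the computation. The stand-alone sketch below only illustrates that split of responsibilities; ToyKernel, ShapeVector, and the simplified argument types are illustrative stand-ins, not the actual MindSpore KernelMod API.

// Minimal sketch of the Init/Resize/Launch lifecycle adopted by this refactor.
#include <cstdint>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

using ShapeVector = std::vector<int64_t>;

class ToyKernel {
 public:
  // Init: runs once, captures static operator attributes and validates what it can
  // without knowing concrete shapes.
  bool Init(int64_t axis) {
    axis_ = axis;
    return true;
  }

  // Resize: runs whenever input shapes become known (e.g. under dynamic shapes);
  // recomputes shape-dependent state and reports failure instead of throwing.
  int Resize(const ShapeVector &x_shape, const ShapeVector &indices_shape) {
    if (x_shape.size() != indices_shape.size()) {
      return -1;  // plays the role of KRET_RESIZE_FAILED
    }
    x_shape_ = x_shape;
    indices_shape_ = indices_shape;
    return 0;  // plays the role of KRET_OK
  }

  // Launch: performs the actual computation using the state prepared by Resize.
  bool Launch() const {
    const int64_t elems =
        std::accumulate(x_shape_.begin(), x_shape_.end(), int64_t{1}, std::multiplies<int64_t>());
    std::cout << "launch with axis " << axis_ << " over " << elems << " elements\n";
    return true;
  }

 private:
  int64_t axis_{0};
  ShapeVector x_shape_;
  ShapeVector indices_shape_;
};

int main() {
  ToyKernel kernel;
  if (!kernel.Init(/*axis=*/0)) return 1;
  if (kernel.Resize({4, 3}, {4, 3}) != 0) return 1;
  return kernel.Launch() ? 0 : 1;
}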
ScatterAddWithAxis CPU kernel source (.cc):
@@ -20,6 +20,7 @@
 #include <complex>
 
 #include "kernel/common_utils.h"
+#include "mindspore/core/ops/scatter_add_with_axis.h"
 #include "plugin/device/cpu/hal/device/cpu_device_address.h"
 
 namespace {
@@ -46,47 +47,68 @@ namespace {
   .AddOutputAttr(kNumberType##t4)
 const int32_t kInputNum = 3;
 const int32_t kOutputNum = 1;
-const uint32_t kInputIndex2 = 2;
 const int32_t KSplitSize = 64 * 1024;
 }  // namespace
 
-void ScatterAddWithAxisCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
-  MS_EXCEPTION_IF_NULL(kernel_node);
-  kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
-  x_type_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
-  indices_type_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 1);
-  if (indices_type_ != kNumberTypeInt32 && indices_type_ != kNumberTypeInt64) {
-    MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', the dtype of 'indices' must be int32 or int64, but got "
-                             << indices_type_;
+bool ScatterAddWithAxisCpuKernelMod::Init(const BaseOperatorPtr &base_operator,
+                                          const std::vector<KernelTensorPtr> &inputs,
+                                          const std::vector<KernelTensorPtr> &outputs) {
+  MS_ERROR_IF_NULL(base_operator);
+  kernel_name_ = base_operator->name();
+  CHECK_KERNEL_INPUTS_NUM(inputs.size(), kInputNum, kernel_name_);
+  CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kOutputNum, kernel_name_);
+  auto op_prim = std::dynamic_pointer_cast<ops::ScatterAddWithAxis>(base_operator);
+  MS_ERROR_IF_NULL(op_prim);
+  axis_ = op_prim->get_axis();
+  auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
+  auto is_match = MatchKernelAttr(kernel_attr, GetOpSupport());
+  if (!is_match.first) {
+    MS_LOG(ERROR) << "For '" << kernel_name_ << "', it does not support this kernel data type: " << kernel_attr;
+    return false;
   }
+  return true;
+}
 
-  // check parameters basic attribution are valid
-  x_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
-  indices_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
-  updates_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, kInputIndex2);
-  axis_ = common::AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "axis");
+int ScatterAddWithAxisCpuKernelMod::Resize(const BaseOperatorPtr &base_operator,
+                                           const std::vector<KernelTensorPtr> &inputs,
+                                           const std::vector<KernelTensorPtr> &outputs,
+                                           const std::map<uint32_t, tensor::TensorPtr> &) {
+  auto ret = KernelMod::Resize(base_operator, inputs, outputs);
+  if (ret != KRET_OK) {
+    return ret;
+  }
+  x_type_ = inputs[kIndex0]->GetDtype();
+  indices_type_ = inputs[kIndex1]->GetDtype();
+  x_shape_ = inputs[kIndex0]->GetDeviceShapeAdaptively();
+  indices_shape_ = inputs[kIndex1]->GetDeviceShapeAdaptively();
+  updates_shape_ = inputs[kIndex2]->GetDeviceShapeAdaptively();
 
   // Get and check 3 input dim info
   int64_t value_dim_num_x1 = static_cast<int64_t>(x_shape_.size());
   int64_t value_dim_num_x2 = static_cast<int64_t>(indices_shape_.size());
   int64_t value_dim_num_x3 = static_cast<int64_t>(updates_shape_.size());
   if (value_dim_num_x1 != value_dim_num_x2 || value_dim_num_x2 != value_dim_num_x3) {
-    MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', the dim values of three inputs must be same, but got "
+    MS_LOG(ERROR) << "For '" << kernel_name_ << "', the dim values of three inputs must be same, but got "
                   << "data: " << value_dim_num_x1 << ", indices: " << value_dim_num_x2
                   << ", update: " << value_dim_num_x3;
+    return KRET_RESIZE_FAILED;
   }
   if (axis_ < value_dim_num_x1 * -1 || axis_ >= value_dim_num_x1) {
-    MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', the value of axis is out of range!";
+    MS_LOG(ERROR) << "For '" << kernel_name_ << "', the value of axis is out of range!";
+    return KRET_RESIZE_FAILED;
   }
 
   int64_t sub_data_fix = 1;
   int64_t sub_index_fix = 1;
+  data_dim_vec_.clear();
+  index_dim_vec_.clear();
   for (int64_t i = value_dim_num_x2 - 1; i >= 0; --i) {
     size_t j = static_cast<size_t>(i);
     if (x_shape_[j] < indices_shape_[j] || indices_shape_[j] != updates_shape_[j] || updates_shape_[j] <= 0) {
-      MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', the " << j << " dimension verification failed: "
+      MS_LOG(ERROR) << "For '" << kernel_name_ << "', the " << j << " dimension verification failed: "
                     << "input0[" << x_shape_[j] << "], input1[" << indices_shape_[j] << "], input2["
                     << updates_shape_[j] << "]";
+      return KRET_RESIZE_FAILED;
     }
     if (i > 0) {
      sub_data_fix *= x_shape_[j];
@@ -95,6 +117,7 @@ void ScatterAddWithAxisCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
      index_dim_vec_.push_back(sub_index_fix);
     }
   }
+  return KRET_OK;
 }
 
 bool ScatterAddWithAxisCpuKernelMod::Launch(const std::vector<AddressPtr> &inputs,
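The Resize checks above require the three inputs to share a rank, the axis to lie in [-rank, rank), and every dimension of indices/updates to match and fit inside the data shape. The stand-alone sketch below mirrors that validation; the helper name ValidateScatterAddWithAxisShapes is chosen here for illustration and is not a MindSpore function.

// Stand-alone sketch of the shape/axis validation performed in Resize.
#include <cstdint>
#include <iostream>
#include <vector>

bool ValidateScatterAddWithAxisShapes(const std::vector<int64_t> &x, const std::vector<int64_t> &indices,
                                      const std::vector<int64_t> &updates, int64_t axis) {
  const int64_t rank = static_cast<int64_t>(x.size());
  // All three inputs must have the same rank.
  if (indices.size() != x.size() || updates.size() != x.size()) return false;
  // axis must lie in [-rank, rank).
  if (axis < -rank || axis >= rank) return false;
  // Per dimension: indices and updates must match, and neither may exceed the data shape.
  for (size_t j = 0; j < x.size(); ++j) {
    if (x[j] < indices[j] || indices[j] != updates[j] || updates[j] <= 0) return false;
  }
  return true;
}

int main() {
  std::cout << std::boolalpha;
  std::cout << ValidateScatterAddWithAxisShapes({4, 5}, {2, 5}, {2, 5}, -1) << '\n';  // true
  std::cout << ValidateScatterAddWithAxisShapes({4, 5}, {2, 5}, {2, 4}, 0) << '\n';   // false: dim 1 mismatch
  return 0;
}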
ScatterAddWithAxis CPU kernel header (.h):
@@ -17,6 +17,7 @@
 #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SCATTER_ADD_WITH_AXIS_CPU_KERNEL_H_
 #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SCATTER_ADD_WITH_AXIS_CPU_KERNEL_H_
 
+#include <map>
 #include <memory>
 #include <string>
 #include <vector>
@@ -26,12 +27,16 @@
 
 namespace mindspore {
 namespace kernel {
-class ScatterAddWithAxisCpuKernelMod : public DeprecatedNativeCpuKernelMod {
+class ScatterAddWithAxisCpuKernelMod : public NativeCpuKernelMod {
  public:
   ScatterAddWithAxisCpuKernelMod() = default;
   ~ScatterAddWithAxisCpuKernelMod() override = default;
 
-  void InitKernel(const CNodePtr &kernel_node) override;
+  bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
+            const std::vector<KernelTensorPtr> &outputs) override;
 
+  int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
+             const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;
+
   bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
               const std::vector<AddressPtr> &outputs) override;
SearchSorted CPU kernel source (.cc):
@@ -21,6 +21,7 @@
 #include <functional>
 #include <algorithm>
 #include <utility>
+#include "mindspore/core/ops/search_sorted.h"
 
 namespace mindspore {
 namespace kernel {
@@ -29,20 +30,35 @@ constexpr size_t kSearchSortedInputsNum = 2;
 constexpr size_t kSearchSortedOutputsNum = 1;
 }  // namespace
 
-void SearchSortedCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
-  MS_EXCEPTION_IF_NULL(kernel_node);
-  kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
-  right_ = common::AnfAlgo::GetNodeAttr<bool>(kernel_node, "right");
-  sequence_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
-  values_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
-  search_len = LongToSize(sequence_shape_.back());
-
-  auto kernel_attr = GetKernelAttrFromNode(kernel_node);
+bool SearchSortedCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
+                                    const std::vector<KernelTensorPtr> &outputs) {
+  MS_ERROR_IF_NULL(base_operator);
+  kernel_name_ = base_operator->name();
+  CHECK_KERNEL_INPUTS_NUM(inputs.size(), kSearchSortedInputsNum, kernel_name_);
+  CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kSearchSortedOutputsNum, kernel_name_);
+  auto op_prim = std::dynamic_pointer_cast<ops::SearchSorted>(base_operator);
+  right_ = op_prim->get_right();
+  auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
   auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
   if (!is_match) {
-    MS_LOG(EXCEPTION) << "SearchSorted does not support this kernel data type: " << kernel_attr;
+    MS_LOG(ERROR) << "SearchSorted does not support this kernel data type: " << kernel_attr;
+    return false;
   }
   kernel_func_ = func_list_[index].second;
+  return true;
+}
+
+int SearchSortedCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
+                                     const std::vector<KernelTensorPtr> &outputs,
+                                     const std::map<uint32_t, tensor::TensorPtr> &) {
+  auto ret = KernelMod::Resize(base_operator, inputs, outputs);
+  if (ret != KRET_OK) {
+    return ret;
+  }
+  sequence_shape_ = inputs[kIndex0]->GetDeviceShapeAdaptively();
+  values_shape_ = inputs[kIndex1]->GetDeviceShapeAdaptively();
+  search_len_ = LongToSize(sequence_shape_.back());
+  return KRET_OK;
 }
 
 template <typename S>
@@ -71,9 +87,9 @@ bool SearchSortedCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr
 
   auto task = [this, &sequence, &values, &output, seq_dim, search_repeat](size_t start, size_t end) {
     for (size_t i = start; i < end; i++) {
-      auto seq_start = (seq_dim == 1) ? sequence : sequence + (i / search_repeat) * search_len;
-      auto result = right_ ? std::upper_bound(seq_start, seq_start + search_len, values[i]) - seq_start
-                           : CustomizedLowerBound<S>(seq_start, seq_start + search_len, values[i]) - seq_start;
+      auto seq_start = (seq_dim == 1) ? sequence : sequence + (i / search_repeat) * search_len_;
+      auto result = right_ ? std::upper_bound(seq_start, seq_start + search_len_, values[i]) - seq_start
+                           : CustomizedLowerBound<S>(seq_start, seq_start + search_len_, values[i]) - seq_start;
       output[i] = static_cast<T>(result);
     }
   };
@@ -97,10 +113,10 @@ void SearchSortedCpuKernelMod::CheckParam(const std::vector<AddressPtr> &inputs,
   int list_count = accumulate(sequence_shape_.begin(), sequence_shape_.end() - 1, 1, std::multiplies<int>());
   auto task = [this, &sequence](size_t start, size_t end) {
     for (size_t i = start; i < end; i++) {
-      for (size_t j = 0; j < search_len - 1; j++) {
-        if (sequence[i * search_len + j] > sequence[i * search_len + j + 1]) {
+      for (size_t j = 0; j < search_len_ - 1; j++) {
+        if (sequence[i * search_len_ + j] > sequence[i * search_len_ + j + 1]) {
           MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the input sequence must be forward sequence. But got "
-                            << sequence[i * search_len + j] << '>' << sequence[i * search_len + j + 1];
+                            << sequence[i * search_len_ + j] << '>' << sequence[i * search_len_ + j + 1];
         }
       }
     }
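In LaunchKernel above, right_ == true maps to std::upper_bound and right_ == false to a lower-bound search over each sorted sequence. The snippet below only illustrates that difference with plain standard C++; it is not the kernel code itself.

// Illustration of the right_ flag: lower_bound vs upper_bound on a sorted sequence.
#include <algorithm>
#include <iostream>
#include <vector>

int main() {
  std::vector<int> seq = {1, 3, 3, 5, 7};
  const int value = 3;
  auto lower = std::lower_bound(seq.begin(), seq.end(), value) - seq.begin();  // first index not less than value
  auto upper = std::upper_bound(seq.begin(), seq.end(), value) - seq.begin();  // first index greater than value
  std::cout << "right=false -> " << lower << ", right=true -> " << upper << '\n';  // prints 1 and 3
  return 0;
}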
SearchSorted CPU kernel header (.h):
@@ -17,6 +17,7 @@
 #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SEARCHSORTED_CPU_KERNEL_H_
 #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SEARCHSORTED_CPU_KERNEL_H_
 
+#include <map>
 #include <vector>
 #include <utility>
 #include "plugin/device/cpu/kernel/cpu_kernel.h"
@@ -24,12 +25,16 @@
 
 namespace mindspore {
 namespace kernel {
-class SearchSortedCpuKernelMod : public DeprecatedNativeCpuKernelMod {
+class SearchSortedCpuKernelMod : public NativeCpuKernelMod {
  public:
   SearchSortedCpuKernelMod() = default;
   ~SearchSortedCpuKernelMod() override = default;
 
-  void InitKernel(const CNodePtr &kernel_node) override;
+  bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
+            const std::vector<KernelTensorPtr> &outputs) override;
 
+  int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
+             const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;
+
   bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
               const std::vector<AddressPtr> &outputs) override {
@@ -51,7 +56,7 @@ class SearchSortedCpuKernelMod : public DeprecatedNativeCpuKernelMod {
   SearchSortedFunc kernel_func_;
 
   bool right_{false};
-  size_t search_len{0};
+  size_t search_len_{0};
   std::vector<int64_t> sequence_shape_;
   std::vector<int64_t> values_shape_;
   std::vector<int64_t> output_shape_;
New file: ops/search_sorted.cc
@@ -0,0 +1,30 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ops/search_sorted.h"
+#include <string>
+#include "mindapi/src/helper.h"
+#include "ops/op_utils.h"
+
+namespace mindspore {
+namespace ops {
+constexpr auto kNameRight = "right";
+void SearchSorted::set_right(const bool right) { (void)AddAttr(kNameRight, api::MakeValue(right)); }
+bool SearchSorted::get_right() const { return GetValue<bool>(GetAttr(kNameRight)); }
+MIND_API_OPERATOR_IMPL(SearchSorted, BaseOperator);
+REGISTER_PRIMITIVE_C(kSearchSorted, SearchSorted);
+}  // namespace ops
+}  // namespace mindspore
New file: ops/search_sorted.h
@@ -0,0 +1,44 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CORE_OPS_SEARCH_SORTED_H_
+#define MINDSPORE_CORE_OPS_SEARCH_SORTED_H_
+#include <memory>
+#include <vector>
+
+#include "mindapi/base/types.h"
+#include "ops/base_operator.h"
+
+namespace mindspore {
+namespace ops {
+constexpr auto kSearchSorted = "SearchSorted";
|
+/// Refer to Python API @ref mindspore.ops.SearchSorted for more details.
+class MIND_API SearchSorted : public BaseOperator {
+ public:
+  MIND_API_BASE_MEMBER(SearchSorted);
+  /// \brief Constructor.
+  SearchSorted() : BaseOperator(kSearchSorted) { InitIOName({"sequence", "values"}, {"positions"}); }
+
+  /// \brief Set right.
+  void set_right(const bool right);
+  /// \brief Get right.
+  bool get_right() const;
+};
+}  // namespace ops
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CORE_OPS_SEARCH_SORTED_H_
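The new primitive stores its only attribute through AddAttr and reads it back with GetAttr, which is what the kernel's Init consumes via get_right(). Below is a simplified, self-contained stand-in for that attribute round trip; ToyOperator, ToySearchSorted, and the variant-based attribute map are illustrative only, not MindSpore's BaseOperator machinery.

// Simplified sketch of the set_right/get_right attribute round trip.
#include <cstdint>
#include <iostream>
#include <map>
#include <string>
#include <variant>

class ToyOperator {
 public:
  void AddAttr(const std::string &name, bool value) { attrs_[name] = value; }
  bool GetBoolAttr(const std::string &name) const { return std::get<bool>(attrs_.at(name)); }

 private:
  std::map<std::string, std::variant<bool, int64_t>> attrs_;  // named operator attributes
};

class ToySearchSorted : public ToyOperator {
 public:
  void set_right(bool right) { AddAttr("right", right); }
  bool get_right() const { return GetBoolAttr("right"); }
};

int main() {
  ToySearchSorted op;
  op.set_right(true);
  std::cout << std::boolalpha << op.get_right() << '\n';  // prints true
  return 0;
}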