forked from mindspore-Ecosystem/mindspore
!905 add topk op for aicpu
Merge pull request !905 from yanzhenxiang2020/add_topkop_for_aicpu
This commit is contained in:
commit
88215d0007
|
@ -111,6 +111,9 @@ bool AicpuOpKernelMod::Launch(const std::vector<AddressPtr> &inputs, const std::
|
|||
|
||||
CreateCpuKernelInfo(inputs, outputs);
|
||||
auto *stream = reinterpret_cast<rtStream_t *>(stream_ptr);
|
||||
if (node_name_ == "TopK") {
|
||||
node_name_ = "TopKV2";
|
||||
}
|
||||
MS_LOG(INFO) << "Aicpu launch, node_so_:" << node_so_ << ", node name:" << node_name_
|
||||
<< ", args_size:" << args_.length();
|
||||
if (rtCpuKernelLaunch(reinterpret_cast<const void *>(node_so_.c_str()),
|
||||
|
@ -137,6 +140,9 @@ vector<TaskInfoPtr> AicpuOpKernelMod::GenTask(const std::vector<AddressPtr> &inp
|
|||
(void)std::transform(std::begin(outputs), std::end(outputs), std::back_inserter(output_data_addrs),
|
||||
[](const AddressPtr &output) -> void * { return output->addr; });
|
||||
|
||||
if (node_name_ == "TopK") {
|
||||
node_name_ = "TopKV2";
|
||||
}
|
||||
AicpuTaskInfoPtr task_info_ptr = make_shared<ge::model_runner::AicpuTaskInfo>(
|
||||
stream_id, node_so_, node_name_, node_def_str_, input_data_addrs, output_data_addrs);
|
||||
|
||||
|
|
|
@ -568,6 +568,12 @@ void TbeMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<Ke
|
|||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
MS_EXCEPTION_IF_NULL(kernel_info_list);
|
||||
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> parse_info_list;
|
||||
|
||||
if (AnfAlgo::GetCNodeName(kernel_node) == kTopKOpName && AnfAlgo::GetNodeAttr<bool>(kernel_node, "sorted") == false) {
|
||||
MS_LOG(INFO) << "will select aicpu topk.";
|
||||
return;
|
||||
}
|
||||
|
||||
std::string op_name = AnfAlgo::GetCNodeName(kernel_node);
|
||||
auto op_info_ptr = mindspore::kernel::OpLib::FindOp(op_name, OpImplyType::kTBE);
|
||||
if (op_info_ptr == nullptr) {
|
||||
|
|
|
@ -17,3 +17,4 @@ from .init_data_set_queue import _init_data_set_queue_aicpu
|
|||
from .dropout_genmask import _dropout_genmask_aicpu
|
||||
from .get_next import _get_next_aicpu
|
||||
from .print_tensor import _print_aicpu
|
||||
from .topk import _top_k_aicpu
|
||||
|
|
|
@ -0,0 +1,32 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""TopK op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
|
||||
|
||||
top_k_op_info = AiCPURegOp("TopK") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.attr("sorted", "bool")\
|
||||
.input(0, "intput", "required") \
|
||||
.input(1, "k", "required") \
|
||||
.output(0, "values", "required") \
|
||||
.output(1, "indices", "required") \
|
||||
.dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default, DataType.I32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
@op_info_register(top_k_op_info)
|
||||
def _top_k_aicpu():
|
||||
"""TopK aicpu register"""
|
||||
return
|
Loading…
Reference in New Issue