forked from mindspore-Ecosystem/mindspore
!30929 Add AkgKernel for lite_adapter
Merge pull request !30929 from DeshiChen/0307_akgkernel
Commit f3b753807b
@@ -20,10 +20,12 @@
#include <memory>
#include <string>
#include <vector>
#include "common/graph_kernel/core/graph_kernel_callback.h"
#include "common/graph_kernel/lite_adapter/akg_build.h"
#include "ir/anf.h"
#include "ir/func_graph.h"
#include "ops/custom.h"
#include "utils/anf_utils.h"
#include "utils/log_adapter.h"

namespace mindspore::graphkernel {

@@ -57,6 +59,29 @@ AnfNodePtr KernelBuilder::CreateCustomOp(const FuncGraphPtr &func_graph, const C
  auto kernel_name = GetValue<std::string>(fg->get_attr("kernel_name"));
  std::vector<uint8_t> kernel_name_str(kernel_name.begin(), kernel_name.end());
  custom_attrs["kernel_name"] = kernel_name_str;
  std::string output_shape_str;
  std::string output_format_str;
  std::string output_type_str;
  auto output_num = AnfUtils::GetOutputTensorNum(cnode);
  auto cb = Callback::Instance();
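  // Serialize each output's shape as "rank,d0,d1,...,"; formats and dtypes go into their own
  // comma-separated lists so that InferShape on the runtime side can parse them back.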
  for (size_t i = 0; i < output_num; i++) {
    auto output_shape = cb->GetOutputShape(cnode, i);
    output_shape_str += std::to_string(output_shape.size()) + ",";
    for (auto &v : output_shape) {
      output_shape_str += std::to_string(v) + ",";
    }
    auto output_format = cb->GetOutputFormat(cnode, i);
    if (output_format == kOpFormat_NHWC) {
      output_format_str += "1,";
    } else {  // default, NCHW
      output_format_str += "0,";
    }
    auto output_type = cb->GetOutputType(cnode, i);
    output_type_str += std::to_string(static_cast<int>(output_type)) + ",";
  }
  custom_attrs["outputs_shape"] = std::vector<uint8_t>(output_shape_str.begin(), output_shape_str.end());
  custom_attrs["outputs_format"] = std::vector<uint8_t>(output_format_str.begin(), output_format_str.end());
  custom_attrs["outputs_type"] = std::vector<uint8_t>(output_type_str.begin(), output_type_str.end());
  primc->set_attr(custom_attrs);
  auto inputs = cnode->inputs();
  inputs.erase(inputs.begin());

@@ -0,0 +1,26 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_COMMON_GRAPH_KERNEL_LITE_ADAPTER_COMMON_GRAPH_KERNEL_OP_PARAMETER_H_
#define MINDSPORE_CCSRC_COMMON_GRAPH_KERNEL_LITE_ADAPTER_COMMON_GRAPH_KERNEL_OP_PARAMETER_H_

#include "nnacl/op_base.h"

typedef struct GraphKernelParameter {
  OpParameter op_parameter_;
  void *prim_;
} GraphKernelParameter;
#endif  // MINDSPORE_CCSRC_COMMON_GRAPH_KERNEL_LITE_ADAPTER_COMMON_GRAPH_KERNEL_OP_PARAMETER_H_

@@ -0,0 +1,153 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <vector>
#include <map>
#include <memory>
#include <string>
#include "schema/model_generated.h"
#include "src/tensor.h"
#include "src/common/utils.h"
#include "nnacl/infer/common_infer.h"
#include "nnacl/infer/infer_register.h"
#include "common/graph_kernel/lite_adapter/common/graph_kernel_op_parameter.h"

namespace mindspore::graphkernel {
namespace {
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;

std::vector<std::string> SplitString(const std::string &raw_str, char delimiter) {
  std::vector<std::string> res;
  std::string::size_type last_pos = 0;
  auto cur_pos = raw_str.find(delimiter);
  while (cur_pos != std::string::npos) {
    (void)res.emplace_back(raw_str.substr(last_pos, cur_pos - last_pos));
    cur_pos++;
    last_pos = cur_pos;
    cur_pos = raw_str.find(delimiter, cur_pos);
  }
  if (last_pos < raw_str.size()) {
    (void)res.emplace_back(raw_str.substr(last_pos, raw_str.size() - last_pos + 1));
  }
  return res;
}
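
// Parses the flattened shape string written by KernelBuilder::CreateCustomOp.
// Illustrative example: outputs with shapes {2, 3} and {4} arrive as "2,2,3,1,4," (rank, then each dim).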
int GetCustomShape(const std::string &attr, std::vector<std::vector<int>> *shapes) {
  auto split_shape_str = SplitString(attr, ',');
  for (size_t i = 0; i < split_shape_str.size(); i++) {
    size_t dim = std::stoul(split_shape_str[i]);
    if (i + dim >= split_shape_str.size()) {
      MS_LOG(ERROR) << "Shape string is invalid. The shape dim is " << dim << ", but only "
                    << split_shape_str.size() - i << " values follow.";
      return RET_ERROR;
    }
    std::vector<int> shape;
    for (size_t j = i + 1; j <= i + dim; j++) {
      shape.push_back(std::stoi(split_shape_str[j]));
    }
    i += dim;
    shapes->push_back(shape);
  }
  return RET_OK;
}

int SetOutputsShape(TensorC **outputs, size_t outputs_size, const std::string &outputs_shape_str) {
  std::vector<std::vector<int>> shapes;
  if (GetCustomShape(outputs_shape_str, &shapes) != RET_OK) {
    return RET_ERROR;
  }
  if (shapes.size() != outputs_size) {
    MS_LOG(ERROR) << "The number of saved output shapes is not equal to outputs_size: " << shapes.size() << " vs "
                  << outputs_size;
    return RET_ERROR;
  }
  for (size_t i = 0; i < outputs_size; i++) {
    if (shapes[i].size() > MAX_SHAPE_SIZE) {
      MS_LOG(ERROR) << "The output shape size " << shapes[i].size() << " is greater than max size " << MAX_SHAPE_SIZE;
      return RET_ERROR;
    }
    for (size_t j = 0; j < shapes[i].size(); j++) {
      outputs[i]->shape_[j] = shapes[i][j];
    }
    outputs[i]->shape_size_ = shapes[i].size();
  }
  return RET_OK;
}

int SetOutputsFormat(TensorC **outputs, size_t outputs_size, const std::string &output_format_str) {
  auto formats = SplitString(output_format_str, ',');
  if (formats.size() != outputs_size) {
    MS_LOG(ERROR) << "The number of saved output formats is not equal to outputs_size: " << formats.size() << " vs "
                  << outputs_size;
    return RET_ERROR;
  }
  for (size_t i = 0; i < formats.size(); i++) {
    outputs[i]->format_ = std::stoi(formats[i]);
  }
  return RET_OK;
}

int SetOutputsType(TensorC **outputs, size_t outputs_size, const std::string &output_type_str) {
  auto types = SplitString(output_type_str, ',');
  if (types.size() != outputs_size) {
    MS_LOG(ERROR) << "The number of saved output types is not equal to outputs_size: " << types.size() << " vs "
                  << outputs_size;
    return RET_ERROR;
  }
  for (size_t i = 0; i < types.size(); i++) {
    outputs[i]->data_type_ = std::stoi(types[i]);
  }
  return RET_OK;
}
}  // namespace

int InferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
               OpParameter *parameter) {
  auto param = reinterpret_cast<GraphKernelParameter *>(parameter);
  auto prim = static_cast<schema::Primitive *>(param->prim_)->value_as_Custom();
  for (size_t i = 0; i < prim->attr()->size(); i++) {
    auto attr = prim->attr()->Get(i);
    if (attr->name()->str() == "outputs_shape") {
      std::string data(reinterpret_cast<const char *>(attr->data()->Data()), attr->data()->size());
      if (SetOutputsShape(outputs, outputs_size, data) != RET_OK) {
        return RET_ERROR;
      }
    } else if (attr->name()->str() == "outputs_format") {
      std::string data(reinterpret_cast<const char *>(attr->data()->Data()), attr->data()->size());
      if (SetOutputsFormat(outputs, outputs_size, data) != RET_OK) {
        return RET_ERROR;
      }
    } else if (attr->name()->str() == "outputs_type") {
      std::string data(reinterpret_cast<const char *>(attr->data()->Data()), attr->data()->size());
      if (SetOutputsType(outputs, outputs_size, data) != RET_OK) {
        return RET_ERROR;
      }
    }
  }
  return RET_OK;
}
}  // namespace mindspore::graphkernel

#ifdef __cplusplus
extern "C" {
#endif
int GraphKernelInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
                          OpParameter *parameter) {
  return mindspore::graphkernel::InferShape(inputs, inputs_size, outputs, outputs_size, parameter);
}
#ifdef __cplusplus
}
#endif

REG_INFER(GraphKernel, PrimType_Inner_GraphKernel, GraphKernelInferShape)

@@ -0,0 +1,135 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "common/graph_kernel/lite_adapter/runtime/akg_kernel.h"
#include <dlfcn.h>
#include <algorithm>
#include "src/tensor.h"
#include "src/common/utils.h"
#include "src/kernel_registry.h"
#include "schema/model_generated.h"

namespace mindspore::kernel {
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
constexpr auto kAkgKernelSo = "akgkernels.so";
namespace {
int TmpAkgParallelLaunchFunc(AkgParallelLambda flambda, void *cdata, int num_task) {
  /*
    `cdata` is a second-level pointer whose first element points to a structure object.
    The structure contains the original AkgCallBack's elements except the first one, `parallel_launch_func`.
    It looks like `{malloc_func, free_func, extend_data}`; all of its elements are also pointers.
    So, to get `extend_data`, we can treat `cdata` as a third-level pointer
    and then offset TWO elements into the first structure object.
    `extend_data` was set to the `this` pointer of the `AkgKernel` object.
  */
  const auto kExtendDataOffset = 2;
  void *extend_data = static_cast<void ***>(cdata)[0][kExtendDataOffset];
  static_cast<AkgKernel *>(extend_data)->AkgParallelLaunchFunc(flambda, cdata, num_task);
  return 0;
}

class AkgCallBack {
 public:
  void *parallel_launch_func = nullptr;
  void *(*malloc_func)(size_t) = nullptr;
  void (*free_func)(void *) = nullptr;
  void *extend_data = nullptr;

  AkgCallBack() {
    parallel_launch_func = reinterpret_cast<void *>(TmpAkgParallelLaunchFunc);
    malloc_func = &malloc;
    free_func = &free;
  }
  ~AkgCallBack() = default;
};
}  // namespace

void AkgKernel::ExtractKernelName() {
  auto prim = static_cast<schema::Primitive *>(params_->prim_)->value_as_Custom();
  for (size_t i = 0; i < prim->attr()->size(); i++) {
    auto attr = prim->attr()->Get(i);
    if (attr->name()->str() == "kernel_name") {
      auto data = attr->data();
      kernel_name_ = std::string(reinterpret_cast<const char *>(data->Data()), data->size());
      break;
    }
  }
}

AkgKernel::~AkgKernel() {
  if (handle_ != nullptr) {
    (void)dlclose(handle_);
  }
}

int TmpDoTask(void *obj, int task_id, float lhs_scale, float rhs_scale) {
  return static_cast<AkgKernel *>(obj)->DoTask(task_id, lhs_scale, rhs_scale);
}

int AkgKernel::DoTask(int task_id, float, float) {
  (void)cached_akg_lambda_(task_id, nthread_, cached_runtimeargs_);
  return RET_OK;
}
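
// Called back (via TmpAkgParallelLaunchFunc) from inside the generated AKG kernel:
// cache the AKG lambda and its arguments, then fan the tasks out through the lite
// thread pool (ParallelLaunch -> TmpDoTask -> DoTask).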
void AkgKernel::AkgParallelLaunchFunc(AkgParallelLambda flambda, void *cdata, int) {
  cached_akg_lambda_ = flambda;
  cached_runtimeargs_ = cdata;
  (void)ParallelLaunch(this->ms_context_, TmpDoTask, this, this->nthread_);
  cached_akg_lambda_ = nullptr;
  cached_runtimeargs_ = nullptr;
}

int AkgKernel::Prepare() {
  if (handle_ != nullptr || kernel_func_ != nullptr) {
    return RET_OK;
  }
  handle_ = dlopen(kAkgKernelSo, RTLD_LAZY | RTLD_LOCAL);
  if (handle_ == nullptr) {
    MS_LOG(ERROR) << "Load [" << kAkgKernelSo << "] failed. kernel: [" << kernel_name_ << "]";
    return RET_ERROR;
  }
  kernel_func_ = dlsym(handle_, kernel_name_.c_str());
  if (kernel_func_ == nullptr) {
    MS_LOG(ERROR) << "Undefined symbol [" << kernel_name_ << "] in [" << kAkgKernelSo << "]";
    return RET_ERROR;
  }
  return RET_OK;
}

int AkgKernel::Run() {
  if (kernel_func_ == nullptr) {
    MS_LOG(ERROR) << "Kernel function [" << kernel_name_ << "] is nullptr.";
    return RET_ERROR;
  }
  nthread_ = op_parameter_->thread_num_;
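  // The argument vector handed to the AKG kernel: the callback struct first,
  // then the input data pointers, then the output data pointers.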
  std::vector<void *> runtimeargs;
  runtimeargs.reserve(in_tensors_.size() + out_tensors_.size() + 1);
  AkgCallBack akg_callback;
  akg_callback.extend_data = static_cast<void *>(this);
  (void)runtimeargs.emplace_back(static_cast<void *>(&akg_callback));
  (void)std::transform(std::begin(in_tensors_), std::end(in_tensors_), std::back_inserter(runtimeargs),
                       [](lite::Tensor *input) { return input->data(); });
  (void)std::transform(std::begin(out_tensors_), std::end(out_tensors_), std::back_inserter(runtimeargs),
                       [](lite::Tensor *output) { return output->MutableData(); });
  using AkgCpuKernelFunction = void (*)(void *);
  reinterpret_cast<AkgCpuKernelFunction>(kernel_func_)(static_cast<void *>(runtimeargs.data()));
  return RET_OK;
}
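
// Register the AkgKernel creator so that CPU float32 GraphKernel custom nodes are handled by this kernel.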
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimType_Inner_GraphKernel, LiteKernelCreator<AkgKernel>)
}  // namespace mindspore::kernel

@@ -0,0 +1,61 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_COMMON_GRAPH_KERNEL_LITE_ADAPTER_RUNTIME_AKG_KERNEL_H_
#define MINDSPORE_CCSRC_COMMON_GRAPH_KERNEL_LITE_ADAPTER_RUNTIME_AKG_KERNEL_H_
#include <vector>
#include <string>
#include "src/inner_kernel.h"
#include "common/graph_kernel/lite_adapter/common/graph_kernel_op_parameter.h"

namespace mindspore::kernel {
using AkgParallelLambda = int (*)(int task_id, int num_task, void *cdata);

class AkgKernel : public InnerKernel {
 public:
  AkgKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
            const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx)
      : InnerKernel(parameter, inputs, outputs, ctx) {
    params_ = reinterpret_cast<GraphKernelParameter *>(op_parameter_);
    ExtractKernelName();
  }
  ~AkgKernel() override;

  int Prepare() override;
  int Run() override;
  int ReSize() override {
    // ReSize is not supported yet.
    return mindspore::lite::RET_ERROR;
  }

  // the real callback function that is passed to akg
  void AkgParallelLaunchFunc(AkgParallelLambda flambda, void *cdata, int);
  // the callback function that is passed to the thread pool
  int DoTask(int task_id, float, float);

 protected:
  void ExtractKernelName();

  GraphKernelParameter *params_{nullptr};
  void *handle_{nullptr};
  void *kernel_func_{nullptr};
  std::string kernel_name_;
  int nthread_{0};
  AkgParallelLambda cached_akg_lambda_ = nullptr;
  void *cached_runtimeargs_ = nullptr;
};
}  // namespace mindspore::kernel
#endif  // MINDSPORE_CCSRC_COMMON_GRAPH_KERNEL_LITE_ADAPTER_RUNTIME_AKG_KERNEL_H_

@@ -225,6 +225,8 @@ list(REMOVE_ITEM MINDSPORE_SRC_LIST
    "../../../mindspore/ccsrc/common/graph_kernel/lite_adapter/graph_kernel_optimization.cc")
list(REMOVE_ITEM MINDSPORE_SRC_LIST
    "../../../mindspore/ccsrc/common/graph_kernel/lite_adapter/graph_kernel_pass_manager.cc")
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/common/graph_kernel/lite_adapter/runtime/akg_kernel.cc")
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/common/graph_kernel/lite_adapter/common/infer_shape.cc")
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/plugin/device/ascend/kernel/tbe/tbe_kernel_compile.cc")
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/plugin/device/cpu/kernel/akg/akg_cpu_kernel_mod.cc")
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/plugin/device/cpu/kernel/akg/akg_cpu_kernel_build.cc")