[GraphKernel]transform the conv weight in converter
This commit is contained in:
parent
d9e30fce83
commit
b81878efc6
|
@ -27,15 +27,17 @@
|
|||
#include "common/graph_kernel/model/lite_graph.h"
|
||||
#include "common/graph_kernel/core/graph_kernel_callback.h"
|
||||
|
||||
namespace mindspore::prim {
|
||||
GVAR_DEF(PrimitivePtr, kPrimFloatStatus, std::make_shared<Primitive>("FloatStatus"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimElemAny, std::make_shared<Primitive>("ElemAny"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimLayoutTransform, std::make_shared<Primitive>("LayoutTransform"));
|
||||
} // namespace mindspore::prim
|
||||
|
||||
namespace mindspore::graphkernel {
|
||||
constexpr auto kGraphKernelDumpPath = "graph_kernel_dump";
|
||||
constexpr auto kAllTarget = "ALL";
|
||||
constexpr auto kOutputsFormat = "outputs_format";
|
||||
|
||||
GVAR_DEF(PrimitivePtr, kPrimFloatStatus, std::make_shared<Primitive>("FloatStatus"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimElemAny, std::make_shared<Primitive>("ElemAny"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimLayoutTransform, std::make_shared<Primitive>("LayoutTransform"));
|
||||
|
||||
using OpWithLevel = std::tuple<std::string, unsigned int, PrimitivePtr>;
|
||||
|
||||
class GkUtils {
|
||||
|
|
|
@ -50,7 +50,7 @@ InplaceAssignerInfo SubGraphSignleOutput(const AnfNodePtr &anf_node) {
|
|||
InplaceAssignerInfo new_op_info;
|
||||
auto sub_graph = common::AnfAlgo::GetCNodeFuncGraphPtr(anf_node);
|
||||
auto output = sub_graph->output();
|
||||
if (IsPrimitiveCNode(output, kPrimElemAny)) {
|
||||
if (IsPrimitiveCNode(output, prim::kPrimElemAny)) {
|
||||
new_op_info.op_node = output->cast<CNodePtr>();
|
||||
}
|
||||
return new_op_info;
|
||||
|
@ -128,7 +128,7 @@ bool FloatStatusAddNFusion::Run(const FuncGraphPtr &func_graph) {
|
|||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
bool pattern_match =
|
||||
std::all_of(cnode->inputs().begin() + 1, cnode->inputs().end(),
|
||||
[](const AnfNodePtr &anf_node) { return IsPrimitiveCNode(anf_node, kPrimFloatStatus); });
|
||||
[](const AnfNodePtr &anf_node) { return IsPrimitiveCNode(anf_node, prim::kPrimFloatStatus); });
|
||||
if (!pattern_match) continue;
|
||||
ProcessFloatStatusAddN(func_graph, cnode, mng);
|
||||
changed = true;
|
||||
|
|
|
@ -22,28 +22,12 @@
|
|||
#include "ir/dtype.h"
|
||||
|
||||
namespace mindspore::graphkernel::expanders {
|
||||
int64_t CalInnerAxisLen(const int64_t channel_outer, const int64_t channel_inner) {
|
||||
if (channel_inner != -1) {
|
||||
return channel_inner;
|
||||
}
|
||||
int64_t simd_size = 8;
|
||||
auto channel = channel_outer;
|
||||
auto inner_len = 1;
|
||||
for (auto inner = simd_size; inner > 0; inner--) {
|
||||
if (channel % inner == 0) {
|
||||
inner_len = inner;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
return inner_len;
|
||||
}
|
||||
|
||||
class Conv2DFusion : public OpDesc {
|
||||
public:
|
||||
Conv2DFusion() {
|
||||
std::initializer_list<std::string> attrs{"kernel_size", "out_channel", "stride", "dilation",
|
||||
"in_channel", "pad_list", "pad_mode"};
|
||||
std::initializer_list<std::string> attrs{"kernel_size", "out_channel", "stride", "dilation",
|
||||
"in_channel", "pad_list", "pad_mode", "weight_coo",
|
||||
"weight_coi", "weight_cio", "weight_cii"};
|
||||
(void)validators_.emplace_back(std::make_unique<CheckAttr>(attrs));
|
||||
}
|
||||
~Conv2DFusion() = default;
|
||||
|
@ -54,38 +38,21 @@ class Conv2DFusion : public OpDesc {
|
|||
const auto &weight = inputs[1];
|
||||
auto data_shape = data->shape;
|
||||
auto data_format = data->format;
|
||||
auto weight_shape = weight->shape;
|
||||
|
||||
// pad_top, pad_bottom, pad_left, pad_right
|
||||
std::vector<int64_t> pads = GetValue<std::vector<int64_t>>(attrs_["pad_list"]);
|
||||
auto n = data_shape[0];
|
||||
auto h = data_shape[1];
|
||||
auto w = data_shape[2];
|
||||
auto c_in = data_shape[3];
|
||||
auto c_out = weight_shape[0];
|
||||
auto k_h = weight_shape[1];
|
||||
auto k_w = weight_shape[2];
|
||||
auto c_i_i = GetValue<int64_t>(attrs_["weight_cii"]);
|
||||
auto c_i_o = GetValue<int64_t>(attrs_["weight_cio"]);
|
||||
auto c_o_i = GetValue<int64_t>(attrs_["weight_coi"]);
|
||||
|
||||
auto c_i_i = CalInnerAxisLen(c_in, -1);
|
||||
if (c_i_i == 0) {
|
||||
MS_LOG(EXCEPTION) << "Calculation of Conv2DFusion is wrong, please check.";
|
||||
}
|
||||
auto c_i_o = c_in / c_i_i;
|
||||
ShapeVector data_rs_shape{n, h, w, c_i_o, c_i_i};
|
||||
std::string conv_format = "NCHW" + std::to_string(c_i_i) + "c";
|
||||
auto data_tp = gb.Emit("LayoutTransform", {data},
|
||||
{{"src_format", MakeValue(data_format)}, {"dst_format", MakeValue(conv_format)}});
|
||||
|
||||
auto c_o_i = CalInnerAxisLen(c_out, -1);
|
||||
if (c_o_i == 0) {
|
||||
MS_LOG(EXCEPTION) << "Calculation of Conv2DFusion is wrong, please check.";
|
||||
}
|
||||
auto c_o_o = c_out / c_o_i;
|
||||
ShapeVector weight_rs_shape{c_o_o, c_o_i, k_h, k_w, c_i_o, c_i_i};
|
||||
auto weight_rs = gb.Reshape(weight, weight_rs_shape);
|
||||
ShapeVector weight_perm{0, 4, 2, 3, 5, 1};
|
||||
auto weight_tp = gb.Transpose(weight_rs, weight_perm);
|
||||
|
||||
// PAD: NCHWc->NCHWc
|
||||
auto pad_n = pads[0];
|
||||
auto pad_h = pads[1];
|
||||
|
@ -111,7 +78,7 @@ class Conv2DFusion : public OpDesc {
|
|||
updated_attrs["data_format"] = MakeValue(kOpFormat_NC1HWC0);
|
||||
std::string conv_out_format = "NCHW" + std::to_string(c_o_i) + "c";
|
||||
updated_attrs["conv_out_format"] = MakeValue(conv_out_format);
|
||||
auto result_nchwc = gb.Emit("Conv2D", {data_pad, weight_tp}, updated_attrs);
|
||||
auto result_nchwc = gb.Emit("Conv2D", {data_pad, weight}, updated_attrs);
|
||||
|
||||
inner::NodePtr result_nchwc_act;
|
||||
if (attrs_.find("activation_type") != attrs_.end()) {
|
||||
|
|
|
@ -26,6 +26,7 @@
|
|||
#include "common/graph_kernel/core/graph_kernel_utils.h"
|
||||
#include "common/graph_kernel/graph_kernel_flags.h"
|
||||
#include "utils/ms_context.h"
|
||||
#include "tools/graph_kernel/converter/substitute_conv2d.h"
|
||||
|
||||
namespace mindspore::graphkernel {
|
||||
AnfNodePtr ParaToValueDeco::Run(const AnfNodePtr &node) {
|
||||
|
@ -43,6 +44,26 @@ AnfNodePtr ParaToValueDeco::Run(const AnfNodePtr &node) {
|
|||
return decorated_->Run(cnode);
|
||||
}
|
||||
|
||||
AnfNodePtr ParaToTensorDeco::Run(const AnfNodePtr &node) {
|
||||
auto cnode = QuickCloneCNode(node);
|
||||
for (const auto &idx : input_idx_) {
|
||||
if (cnode->input(idx + 1)->isa<Parameter>()) {
|
||||
auto default_param = cnode->input(idx + 1)->cast<ParameterPtr>()->default_param();
|
||||
if (default_param == nullptr) {
|
||||
continue;
|
||||
}
|
||||
auto param_value = default_param->cast<tensor::TensorPtr>();
|
||||
if (param_value == nullptr) {
|
||||
continue;
|
||||
}
|
||||
auto value = NewValueNode(param_value);
|
||||
value->set_abstract(param_value->ToAbstract());
|
||||
cnode->set_input(idx + 1, value);
|
||||
}
|
||||
}
|
||||
return decorated_->Run(cnode);
|
||||
}
|
||||
|
||||
AnfNodePtr FixFormatDeco::Run(const AnfNodePtr &node) {
|
||||
auto cnode = QuickCloneCNode(node);
|
||||
std::vector<std::string> format = {kOpFormat_DEFAULT};
|
||||
|
@ -100,6 +121,7 @@ ExpanderPtr GraphKernelExpanderLite::InitExpander(const AnfNodePtr &node) {
|
|||
{prim::kPrimSqueeze->name(), {FixFormatDeco::Creator}},
|
||||
{prim::kPrimReshape->name(), {InputToAttrDeco::GetCreator({1}), FixFormatDeco::Creator}},
|
||||
{prim::kPrimTranspose->name(), {ParaToValueDeco::GetCreator({1}), InputToAttrDeco::GetCreator({1})}},
|
||||
{prim::kPrimConv2DFusion->name(), {ParaToTensorDeco::GetCreator({1}), SubstituteConv2D::Creator}},
|
||||
};
|
||||
auto iter = creators.find(GetCNodePrimitive(node)->name());
|
||||
if (iter != creators.end()) {
|
||||
|
|
|
@ -40,6 +40,23 @@ class ParaToValueDeco : public ExpanderDecorator {
|
|||
HashSet<size_t> input_idx_;
|
||||
};
|
||||
|
||||
class ParaToTensorDeco : public ExpanderDecorator {
|
||||
public:
|
||||
ParaToTensorDeco(const ExpanderPtr &decorated, const HashSet<size_t> &input_idx)
|
||||
: ExpanderDecorator(decorated), input_idx_(input_idx) {}
|
||||
~ParaToTensorDeco() = default;
|
||||
|
||||
static ExpanderCreatorFunc GetCreator(const HashSet<size_t> &input_idx) {
|
||||
return [input_idx](const ExpanderPtr &decorated) {
|
||||
return std::static_pointer_cast<Expander>(std::make_shared<ParaToTensorDeco>(decorated, input_idx));
|
||||
};
|
||||
}
|
||||
AnfNodePtr Run(const AnfNodePtr &node) override;
|
||||
|
||||
protected:
|
||||
HashSet<size_t> input_idx_;
|
||||
};
|
||||
|
||||
class FixFormatDeco : public ExpanderDecorator {
|
||||
public:
|
||||
explicit FixFormatDeco(const ExpanderPtr &decorated) : ExpanderDecorator(decorated) {}
|
||||
|
|
|
@ -0,0 +1,136 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tools/graph_kernel/converter/substitute_conv2d.h"
|
||||
#include <utility>
|
||||
#include "utils/anf_utils.h"
|
||||
#include "common/graph_kernel/core/graph_kernel_callback.h"
|
||||
|
||||
namespace mindspore::graphkernel {
|
||||
constexpr size_t kConv2dDataIndex = 1;
|
||||
constexpr size_t kConv2dWeightIndex = 2;
|
||||
constexpr size_t kWeightChannelOutAxis = 0;
|
||||
constexpr size_t kWeightHeightAxis = 1;
|
||||
constexpr size_t kWeightWidthAxis = 2;
|
||||
constexpr size_t kWeightChannelInAxis = 3;
|
||||
constexpr size_t kShapeRank = 4;
|
||||
|
||||
std::pair<int64_t, int64_t> TilingChannel(int64_t channel) {
|
||||
const int64_t simd_size = 8LL;
|
||||
for (auto inner = simd_size; inner > 0; inner--) {
|
||||
if (channel % inner == 0) {
|
||||
return std::make_pair(channel / inner, inner);
|
||||
}
|
||||
}
|
||||
return {channel, 1LL};
|
||||
}
|
||||
|
||||
class IndexCalc {
|
||||
public:
|
||||
explicit IndexCalc(const ShapeVector &shape) : shape_(shape) {}
|
||||
int64_t GetFlatIndex(const ShapeVector &index) {
|
||||
if (index.size() != shape_.size()) {
|
||||
MS_LOG(EXCEPTION) << "The index's size should be equal to shape's size, but got " << index.size() << " vs "
|
||||
<< shape_.size();
|
||||
}
|
||||
int64_t prod = 1LL;
|
||||
int64_t result = 0LL;
|
||||
for (int i = SizeToInt(shape_.size()) - 1; i >= 0; i--) {
|
||||
result += index[i] * prod;
|
||||
prod *= shape_[i];
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
private:
|
||||
ShapeVector shape_;
|
||||
};
|
||||
|
||||
AnfNodePtr SubstituteConv2D::InferWeightValue(const AnfNodePtr &node) {
|
||||
auto cnode = QuickCloneCNode(node);
|
||||
auto prim = GetCNodePrimitive(cnode)->Clone();
|
||||
cnode->set_input(0, NewValueNode(prim));
|
||||
auto cb = Callback::Instance();
|
||||
// the weight should be a 4D tensor of format OHWI
|
||||
auto weight_shape = cb->GetInputShape(cnode, kConv2dWeightIndex - 1);
|
||||
if (weight_shape.size() != kShapeRank) {
|
||||
return nullptr;
|
||||
}
|
||||
auto c_out = weight_shape[kWeightChannelOutAxis];
|
||||
auto c_in = weight_shape[kWeightChannelInAxis];
|
||||
int64_t c_out_o, c_out_i, c_in_o, c_in_i;
|
||||
std::tie(c_out_o, c_out_i) = TilingChannel(c_out);
|
||||
std::tie(c_in_o, c_in_i) = TilingChannel(c_in);
|
||||
prim->AddAttr("weight_coo", MakeValue(c_out_o));
|
||||
prim->AddAttr("weight_coi", MakeValue(c_out_i));
|
||||
prim->AddAttr("weight_cio", MakeValue(c_in_o));
|
||||
prim->AddAttr("weight_cii", MakeValue(c_in_i));
|
||||
|
||||
auto weight_node = cnode->input(kConv2dWeightIndex)->cast<ValueNodePtr>();
|
||||
if (weight_node == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
auto tensor = weight_node->value()->cast<tensor::TensorPtr>();
|
||||
if (tensor == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
if (tensor->data().const_data() == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
if (tensor->data_type() != kNumberTypeFloat32) {
|
||||
return nullptr;
|
||||
}
|
||||
auto h_len = weight_shape[kWeightHeightAxis];
|
||||
auto w_len = weight_shape[kWeightWidthAxis];
|
||||
|
||||
// step 1, reshape the weight, [O,H,W,I] -> [Oo,Oi,H,W,Io,Ii]
|
||||
// step 2, transpose it to [Oo,Io,H,W,Ii,Oi]
|
||||
IndexCalc old_shape_calc({c_out_o, c_out_i, h_len, w_len, c_in_o, c_in_i});
|
||||
ShapeVector new_shape = {c_out_o, c_in_o, h_len, w_len, c_in_i, c_out_i};
|
||||
IndexCalc new_shape_calc(new_shape);
|
||||
auto new_tensor = std::make_shared<tensor::Tensor>(tensor->data_type(), new_shape);
|
||||
auto new_data = new_tensor->data_c();
|
||||
auto old_data = tensor->data_c();
|
||||
for (int64_t coo = 0; coo < c_out_o; coo++) {
|
||||
for (int64_t cio = 0; cio < c_in_o; cio++) {
|
||||
for (int64_t h = 0; h < h_len; h++) {
|
||||
for (int64_t w = 0; w < w_len; w++) {
|
||||
for (int64_t cii = 0; cii < c_in_i; cii++) {
|
||||
for (int64_t coi = 0; coi < c_out_i; coi++) {
|
||||
auto old_val = static_cast<float *>(old_data)[old_shape_calc.GetFlatIndex({coo, coi, h, w, cio, cii})];
|
||||
static_cast<float *>(new_data)[new_shape_calc.GetFlatIndex({coo, cio, h, w, cii, coi})] = old_val;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
auto v = NewValueNode(new_tensor);
|
||||
v->set_abstract(new_tensor->ToAbstract());
|
||||
v->set_kernel_info(weight_node->kernel_info_ptr());
|
||||
cnode->set_input(kConv2dWeightIndex, v);
|
||||
return cnode;
|
||||
}
|
||||
|
||||
AnfNodePtr SubstituteConv2D::Run(const AnfNodePtr &node) {
|
||||
auto new_node = InferWeightValue(node);
|
||||
if (new_node == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
return ExpanderDecorator::Run(new_node);
|
||||
}
|
||||
} // namespace mindspore::graphkernel
|
|
@ -0,0 +1,35 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_LITE_TOOLS_GRAPH_KERNEL_CONVERTER_SUBSTITUTE_CONV2D_H_
|
||||
#define MINDSPORE_LITE_TOOLS_GRAPH_KERNEL_CONVERTER_SUBSTITUTE_CONV2D_H_
|
||||
#include <memory>
|
||||
|
||||
#include "common/graph_kernel/core/graph_kernel_expander.h"
|
||||
|
||||
namespace mindspore::graphkernel {
|
||||
class SubstituteConv2D : public ExpanderDecorator {
|
||||
public:
|
||||
using ExpanderDecorator::ExpanderDecorator;
|
||||
static ExpanderPtr Creator(const ExpanderPtr &decorated) {
|
||||
return std::static_pointer_cast<Expander>(std::make_shared<SubstituteConv2D>(decorated));
|
||||
}
|
||||
AnfNodePtr Run(const AnfNodePtr &node) override;
|
||||
|
||||
protected:
|
||||
AnfNodePtr InferWeightValue(const AnfNodePtr &node);
|
||||
};
|
||||
} // namespace mindspore::graphkernel
|
||||
#endif // MINDSPORE_LITE_TOOLS_GRAPH_KERNEL_CONVERTER_SUBSTITUTE_CONV2D_H_
|
|
@ -17,6 +17,7 @@
|
|||
#include "tools/graph_kernel/runtime/akg_kernel.h"
|
||||
#include <dlfcn.h>
|
||||
#include <algorithm>
|
||||
#include <utility>
|
||||
#include <numeric>
|
||||
#include <functional>
|
||||
#include "src/tensor.h"
|
||||
|
@ -110,13 +111,28 @@ int AkgKernel::Prepare() {
|
|||
MS_LOG(ERROR) << "Undefined symbol [" << kernel_name_ << "] in [" << kAkgKernelSo << "]";
|
||||
return RET_ERROR;
|
||||
}
|
||||
for (size_t i = 0; i < in_tensors_.size(); i++) {
|
||||
auto &input = in_tensors_[i];
|
||||
if (input->IsConst() && (reinterpret_cast<size_t>(input->data()) & 0xf) != 0) {
|
||||
auto buffer = static_cast<float *>(input->data());
|
||||
int data_num = input->ElementsNum();
|
||||
std::vector<float> input_align(buffer, buffer + data_num);
|
||||
const_inputs_.emplace(i, input_align);
|
||||
const size_t kAddrAlign = 32;
|
||||
const size_t kAddrAlignMask = 0x1f;
|
||||
const_inputs_.reserve(in_tensors_.size());
|
||||
for (auto &input : in_tensors_) {
|
||||
// the data address should align in 32 bytes.
|
||||
if (input->IsConst() && (reinterpret_cast<size_t>(input->data()) & kAddrAlignMask) != 0) {
|
||||
auto buffer = static_cast<int8_t *>(input->data());
|
||||
int tensor_size = input->Size();
|
||||
if (tensor_size == 0) {
|
||||
MS_LOG(ERROR) << "The tensor \'" << input->tensor_name() << "\' size is 0. kernel: " << kernel_name_;
|
||||
return RET_ERROR;
|
||||
}
|
||||
std::vector<int8_t> input_align(tensor_size + kAddrAlign);
|
||||
auto p = input_align.data();
|
||||
while ((reinterpret_cast<size_t>(p) & kAddrAlignMask) != 0) {
|
||||
p++;
|
||||
}
|
||||
(void)std::copy(buffer, buffer + tensor_size, p);
|
||||
(void)const_inputs_.emplace_back(static_cast<void *>(p));
|
||||
(void)const_data_align_cache_.emplace_back(std::move(input_align));
|
||||
} else {
|
||||
(void)const_inputs_.emplace_back(nullptr);
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
|
@ -130,12 +146,12 @@ int AkgKernel::Run() {
|
|||
nthread_ = op_parameter_->thread_num_;
|
||||
std::vector<void *> runtimeargs;
|
||||
runtimeargs.reserve(in_tensors_.size() + out_tensors_.size() + 1);
|
||||
AkgCallBack akg_callback;
|
||||
static AkgCallBack akg_callback;
|
||||
akg_callback.extend_data = static_cast<void *>(this);
|
||||
(void)runtimeargs.emplace_back(static_cast<void *>(&akg_callback));
|
||||
for (size_t i = 0; i < in_tensors_.size(); i++) {
|
||||
if (const_inputs_.find(i) != const_inputs_.end()) {
|
||||
(void)runtimeargs.emplace_back(reinterpret_cast<void *>(const_inputs_[i].data()));
|
||||
if (const_inputs_[i] != nullptr) {
|
||||
(void)runtimeargs.emplace_back(const_inputs_[i]);
|
||||
} else {
|
||||
(void)runtimeargs.emplace_back(in_tensors_[i]->data());
|
||||
}
|
||||
|
|
|
@ -16,7 +16,6 @@
|
|||
|
||||
#ifndef MINDSPORE_LITE_TOOLS_GRAPH_KERNEL_RUNTIME_AKG_KERNEL_H_
|
||||
#define MINDSPORE_LITE_TOOLS_GRAPH_KERNEL_RUNTIME_AKG_KERNEL_H_
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include "src/runtime/lite_kernel.h"
|
||||
|
@ -56,7 +55,8 @@ class AkgKernel : public LiteKernel {
|
|||
void *kernel_func_{nullptr};
|
||||
std::string kernel_name_;
|
||||
int nthread_{0};
|
||||
std::map<size_t, std::vector<float>> const_inputs_;
|
||||
std::vector<std::vector<int8_t>> const_data_align_cache_;
|
||||
std::vector<void *> const_inputs_;
|
||||
AkgParallelLambda cached_akg_lambda_ = nullptr;
|
||||
void *cached_runtimeargs_ = nullptr;
|
||||
};
|
||||
|
|
Loading…
Reference in New Issue