forked from mindspore-Ecosystem/mindspore
[MSLITE] runtime pass bug
This commit is contained in:
parent a34d737858
commit 877ea4a394
@@ -69,6 +69,17 @@ bool VectorErase(std::vector<T> *vec, T element) {
   return ret;
 }
 
+template <typename T>
+bool VectorSetNull(std::vector<T> *vec, T element) {
+  bool ret = false;
+  for (size_t i = 0; i < vec->size(); i++) {
+    if (vec->at(i) == element) {
+      vec->at(i) = nullptr;
+    }
+  }
+  return ret;
+}
+
 template <typename T>
 bool VectorReplace(std::vector<T> *vec, T srcElement, T dstElement) {
   bool ret = false;
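Note (not part of the patch): the new VectorSetNull helper differs from VectorErase in that it keeps the vector's size and only nulls out the matching slot, which is why the LiteSession destructor in the next hunk starts skipping nullptr entries instead of asserting on them. A minimal standalone sketch; the main() here is made up purely for illustration:

#include <cassert>
#include <cstddef>
#include <vector>

template <typename T>
bool VectorSetNull(std::vector<T> *vec, T element) {
  bool ret = false;  // as in the patch, ret is never set to true
  for (size_t i = 0; i < vec->size(); i++) {
    if (vec->at(i) == element) {
      vec->at(i) = nullptr;
    }
  }
  return ret;
}

int main() {
  int a = 1, b = 2;
  std::vector<int *> v = {&a, &b};
  VectorSetNull(&v, &a);
  assert(v.size() == 2);    // size unchanged, unlike VectorErase
  assert(v[0] == nullptr);  // the cleared slot stays as a null placeholder
  return 0;
}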
@@ -738,7 +738,9 @@ LiteSession::~LiteSession() {
     kernel = nullptr;
   }
   for (auto tensor : tensors_) {
-    MS_ASSERT(tensor != nullptr);
+    if (tensor == nullptr) {
+      continue;
+    }
     // Data of const tensor which doesn't own data will not freed.
     // Such as const data from meta_graph which will be freed when freeing meta_graph.
     if (tensor->IsConst() && !tensor->own_data()) {
@@ -21,30 +21,60 @@
 namespace mindspore::lite {
 void Nc4hw4PassReplace(std::vector<kernel::LiteKernel *> *kernels, std::vector<Tensor *> *tensors, size_t index) {
   kernel::LiteKernel *conv_kernel = kernels->at(index);
-  kernel::LiteKernel *traspose_kernel = conv_kernel->out_kernels().front();
-  kernel::LiteKernel *c4_kernel = traspose_kernel->out_kernels().front();
+  kernel::LiteKernel *transpose_kernel = conv_kernel->out_kernels().front();
+  kernel::LiteKernel *c4_kernel = transpose_kernel->out_kernels().front();
+  kernel::LiteKernel *transpose2_kernel = c4_kernel->out_kernels().front();
+  std::vector<kernel::LiteKernel *> end_kernels = transpose2_kernel->out_kernels();
 
   /* tensor */
-  Tensor *transpose_param_tensor = traspose_kernel->in_tensors().at(1);
-  VectorErase(tensors, transpose_param_tensor);
-  delete transpose_param_tensor;
-  transpose_param_tensor = nullptr;
+  {
+    /* transpose_kernel */
+    Tensor *transpose_param_tensor = transpose_kernel->in_tensors().at(1);
+    VectorSetNull(tensors, transpose_param_tensor);
+    delete transpose_param_tensor;
+    transpose_param_tensor = nullptr;
 
-  Tensor *conv_out_tensor = conv_kernel->out_tensors().front();
-  conv_out_tensor->set_format(NC4HW4);
-  Tensor *c4_input_tensor = c4_kernel->in_tensors().front();
-  c4_kernel->set_in_tensor(conv_out_tensor, 0);
-  VectorErase(tensors, c4_input_tensor);
-  delete c4_input_tensor;
-  c4_input_tensor = nullptr;
+    Tensor *conv_out_tensor = conv_kernel->out_tensors().front();
+    conv_out_tensor->set_format(NC4HW4);
+    Tensor *c4_input_tensor = c4_kernel->in_tensors().front();
+    c4_kernel->set_in_tensor(conv_out_tensor, 0);
+    VectorSetNull(tensors, c4_input_tensor);
+    delete c4_input_tensor;
+    c4_input_tensor = nullptr;
+  }
+  {
+    /* transpose2_kernel */
+    Tensor *transpose_param_tensor = transpose2_kernel->in_tensors().at(1);
+    VectorSetNull(tensors, transpose_param_tensor);
+    delete transpose_param_tensor;
+    transpose_param_tensor = nullptr;
+
+    Tensor *nwhc_tensor = c4_kernel->out_tensors().front();
+    nwhc_tensor->set_format(NHWC);
+    for (auto end : end_kernels) {
+      end->set_in_tensor(nwhc_tensor, 0);
+    }
+    Tensor *trans_out = transpose2_kernel->out_tensors().front();
+    VectorSetNull(tensors, trans_out);
+    delete trans_out;
+    trans_out = nullptr;
+  }
 
   /* kernel */
-  VectorErase(kernels, traspose_kernel);
-  delete traspose_kernel;
-  traspose_kernel = nullptr;
+  VectorErase(kernels, transpose_kernel);
+  delete transpose_kernel;
+  transpose_kernel = nullptr;
   conv_kernel->set_out_kernels({c4_kernel});
   c4_kernel->set_in_kernels({conv_kernel});
 
+  c4_kernel->set_out_kernels(transpose2_kernel->out_kernels());
+  for (auto end : end_kernels) {
+    end->set_in_kernels({c4_kernel});
+  }
+  VectorErase(kernels, transpose2_kernel);
+  delete transpose2_kernel;
+  transpose2_kernel = nullptr;
+
   return;
 }
 
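Note (illustration only, not MindSpore code): the rewiring that Nc4hw4PassReplace performs on the kernel chain can be pictured with a toy node type. The names conv/transpose/in/transpose2/end mirror the kernels in the hunk above; conv is linked straight to the middle op and both transposes drop out, just as set_in_kernels/set_out_kernels do after the transpose kernels and their tensors are deleted.

#include <cassert>
#include <string>
#include <vector>

struct Node {
  std::string name;
  std::vector<Node *> in_nodes;
  std::vector<Node *> out_nodes;
};

int main() {
  // before: conv -> transpose -> in -> transpose2 -> end
  Node conv{"conv"}, transpose{"transpose"}, in{"in"}, transpose2{"transpose2"}, end{"end"};
  conv.out_nodes = {&transpose};
  transpose.in_nodes = {&conv};
  transpose.out_nodes = {&in};
  in.in_nodes = {&transpose};
  in.out_nodes = {&transpose2};
  transpose2.in_nodes = {&in};
  transpose2.out_nodes = {&end};
  end.in_nodes = {&transpose2};

  // relink around both transposes, mirroring the pass
  conv.out_nodes = {&in};
  in.in_nodes = {&conv};
  in.out_nodes = transpose2.out_nodes;
  for (Node *e : in.out_nodes) {
    e->in_nodes = {&in};
  }

  // after: conv -> in -> end
  assert(conv.out_nodes.front() == &in);
  assert(end.in_nodes.front() == &in);
  return 0;
}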
@@ -61,27 +91,38 @@ bool Nc4hw4PassMatch(std::vector<kernel::LiteKernel *> *kernels, size_t index) {
     return false;
   }
 
-  kernel::LiteKernel *traspose_kernel = start_kernel->out_kernels().front();
-  if (start_kernel->type() != Nc4hw4FormatTransposeOp) {
+  kernel::LiteKernel *traspose_nhwc2nchw_kernel = start_kernel->out_kernels().front();
+  if (traspose_nhwc2nchw_kernel->type() != Nc4hw4FormatTransposeOp) {
     return false;
   }
-  if (traspose_kernel->out_kernels().size() != 1) {
+  if (traspose_nhwc2nchw_kernel->out_kernels().size() != 1) {
     return false;
   }
 
-  kernel::LiteKernel *end_kernel = traspose_kernel->out_kernels().front();
+  kernel::LiteKernel *end_kernel = traspose_nhwc2nchw_kernel->out_kernels().front();
   if (IsContain(Nc4hw4FormatInOpList, end_kernel->type()) == false) {
     return false;
   }
+  if (end_kernel->out_kernels().size() != 1) {
+    return false;
+  }
 
+  kernel::LiteKernel *transpose_nchw2nhwc_kernel = end_kernel->out_kernels().front();
+  if (transpose_nchw2nhwc_kernel->type() != Nc4hw4FormatTransposeOp) {
+    return false;
+  }
+
   /* double check ops topological sorted in kernel-list */
   auto start_iter = find(kernels->begin(), kernels->end(), start_kernel);
   auto start_index = std::distance(kernels->begin(), start_iter);
-  auto transpose_iter = find(kernels->begin(), kernels->end(), traspose_kernel);
-  auto transpose_index = std::distance(kernels->begin(), transpose_iter);
+  auto traspose_nhwc2nchw_iter = find(kernels->begin(), kernels->end(), traspose_nhwc2nchw_kernel);
+  auto traspose_nhwc2nchw_index = std::distance(kernels->begin(), traspose_nhwc2nchw_iter);
   auto end_iter = find(kernels->begin(), kernels->end(), end_kernel);
   auto end_index = std::distance(kernels->begin(), end_iter);
-  if (start_index > transpose_index || transpose_index > end_index) {
+  auto transpose_nchw2nhwc_iter = find(kernels->begin(), kernels->end(), transpose_nchw2nhwc_kernel);
+  auto transpose_nchw2nhwc_index = std::distance(kernels->begin(), transpose_nchw2nhwc_iter);
+  if (start_index > traspose_nhwc2nchw_index || traspose_nhwc2nchw_index > end_index ||
+      end_index > transpose_nchw2nhwc_index) {
     return false;
   }
 
@@ -89,8 +130,6 @@ bool Nc4hw4PassMatch(std::vector<kernel::LiteKernel *> *kernels, size_t index) {
 }
 
 bool Nc4hw4PassValid(const InnerContext *context, std::vector<kernel::LiteKernel *> *kernels) {
-  return false;
-
   if (context->IsGpuEnabled() || context->IsNpuEnabled()) {
     return false;
   }
@@ -103,19 +142,19 @@ bool Nc4hw4PassValid(const InnerContext *context, std::vector<kernel::LiteKernel
       }
     }
   }
-  return true;
+  return false;
 }
 
-void Nc4hw4Pass(std::vector<kernel::LiteKernel *> *kernels, std::vector<Tensor *> *tensors) {
+void Nc4hw4PassAct(std::vector<kernel::LiteKernel *> *kernels, std::vector<Tensor *> *tensors) {
   size_t kernel_size = kernels->size();
   size_t index = 0;
-  for (; index + 2 < kernel_size; index++) {
+  for (; index + 3 < kernel_size; index++) {
     kernel::LiteKernel *kernel = kernels->at(index);
 
     if (kernel->subgraph_type() != kernel::kNotSubGraph) {
       kernel::SubGraphKernel *subgraph = reinterpret_cast<kernel::SubGraphKernel *>(kernel);
       std::vector<kernel::LiteKernel *> &particial_nodes = subgraph->nodes();
-      Nc4hw4Pass(&particial_nodes, tensors);
+      Nc4hw4PassAct(&particial_nodes, tensors);
     }
 
     if (Nc4hw4PassMatch(kernels, index)) {
@@ -126,5 +165,12 @@ void Nc4hw4Pass(std::vector<kernel::LiteKernel *> *kernels, std::vector<Tensor *
   }
   return;
 }
+
+void Nc4hw4Pass(const InnerContext *context, std::vector<kernel::LiteKernel *> *kernels,
+                std::vector<Tensor *> *tensors) {
+  if (Nc4hw4PassValid(context, kernels)) {
+    Nc4hw4PassAct(kernels, tensors);
+  }
+}
 }  // namespace mindspore::lite
 #endif
@@ -27,15 +27,14 @@
 namespace mindspore::lite {
 
 /* Nc4hw4 PASS
- * before : CONV --(nhwc)-- TRANSPOSE --(nhwc)-- OP
- * after : CONV --(nc4hw4)-- OP
+ * before : --(nhwc)-- CONV --(nhwc)-- TRANSPOSE --(nchw)-- IN --(nchw)-- TRANSPOSE --(nhwc)--
+ * after : --(nhwc)-- CONV --(nc4hw4)-- IN --(nhwc)--
 * */
 static const schema::PrimitiveType Nc4hw4FormatTransposeOp = schema::PrimitiveType_Transpose;
 static const std::vector<schema::PrimitiveType> Nc4hw4FormatOutOpList = {schema::PrimitiveType_Conv2DFusion};
-static const std::vector<schema::PrimitiveType> Nc4hw4FormatInOpList = {schema::PrimitiveType_InstanceNorm,
-                                                                        schema::PrimitiveType_PadFusion};
-bool Nc4hw4PassValid(const InnerContext *context, std::vector<kernel::LiteKernel *> *kernels);
-void Nc4hw4Pass(std::vector<kernel::LiteKernel *> *kernels, std::vector<Tensor *> *tensors);
+static const std::vector<schema::PrimitiveType> Nc4hw4FormatInOpList = {schema::PrimitiveType_InstanceNorm};
+void Nc4hw4Pass(const InnerContext *context, std::vector<kernel::LiteKernel *> *kernels,
+                std::vector<Tensor *> *tensors);
 
 }  // namespace mindspore::lite
 #endif
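Note (general background, not from the patch): the --(nc4hw4)-- annotation in the comment above refers to a channel-tiled layout, commonly stored as [N][ceil(C/4)][H][W][4]. A rough sketch of the index mapping; the function and parameter names are assumptions of this illustration, not MindSpore API:

#include <cstddef>

// Flat offset of element (n, h, w, c) in an NHWC buffer.
size_t OffsetNHWC(size_t H, size_t W, size_t C, size_t n, size_t h, size_t w, size_t c) {
  return ((n * H + h) * W + w) * C + c;
}

// Flat offset of the same element assuming an NC4HW4 buffer laid out as
// [N][ceil(C/4)][H][W][4]; channel c lands in block c / 4 at inner slot c % 4.
size_t OffsetNC4HW4(size_t H, size_t W, size_t C, size_t n, size_t h, size_t w, size_t c) {
  const size_t C4 = (C + 3) / 4;  // number of 4-channel blocks, tail padded
  return (((n * C4 + c / 4) * H + h) * W + w) * 4 + c % 4;
}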
@@ -150,13 +150,12 @@ int Scheduler::Schedule(std::vector<kernel::LiteKernel *> *dst_kernels) {
 #endif
   }
 
+  FindAllInoutKernels(*dst_kernels);
+
 #ifdef ENABLE_RUNTIME_PASS
-  if (Nc4hw4PassValid(context_, dst_kernels)) {
-    Nc4hw4Pass(dst_kernels, src_tensors_);
-  }
+  Nc4hw4Pass(context_, dst_kernels, src_tensors_);
 #endif
 
-  FindAllInoutKernels(*dst_kernels);
 #ifdef ENABLE_CONTROLFLOW_TENSORLIST
   if (IsControlFlowParttern(*dst_kernels)) {
     ret = ConstructControlFlowMainGraph(dst_kernels);
@@ -59,23 +59,52 @@ void Nc4hw4PassConstruct(std::vector<kernel::LiteKernel *> *kernels, std::vector
                                                  transpose_param, &transpose_kernel, nullptr);
   kernels->push_back(transpose_kernel);
 
-  lite::Tensor *pad_param_tensor = new lite::Tensor();
-  tensors->push_back(pad_param_tensor);
-  lite::Tensor *pad_out_tensor = new lite::Tensor();
-  tensors->push_back(pad_out_tensor);
-  OpParameter *pad_param = new OpParameter();
-  kernel::KernelKey pad_desc{kernel::kCPU, kNumberTypeFloat32, schema::PrimitiveType_PadFusion};
-  kernel::LiteKernel *pad_kernel = nullptr;
-  std::vector<lite::Tensor *> pad_in = {transpose_out_tensor, pad_param_tensor};
-  std::vector<lite::Tensor *> pad_out = {pad_out_tensor};
-  lite::KernelRegistry::GetInstance()->GetKernel(pad_in, pad_out, ctx, nullptr, pad_desc, pad_param, &pad_kernel,
-                                                 nullptr);
-  kernels->push_back(pad_kernel);
+  lite::Tensor *in_param_tensor = new lite::Tensor();
+  tensors->push_back(in_param_tensor);
+  lite::Tensor *in_out_tensor = new lite::Tensor();
+  tensors->push_back(in_out_tensor);
+  OpParameter *in_param = new OpParameter();
+  kernel::KernelKey in_desc{kernel::kCPU, kNumberTypeFloat32, schema::PrimitiveType_InstanceNorm};
+  kernel::LiteKernel *in_kernel = nullptr;
+  std::vector<lite::Tensor *> in_in = {transpose_out_tensor, in_param_tensor};
+  std::vector<lite::Tensor *> in_out = {in_out_tensor};
+  lite::KernelRegistry::GetInstance()->GetKernel(in_in, in_out, ctx, nullptr, in_desc, in_param, &in_kernel, nullptr);
+  kernels->push_back(in_kernel);
+
+  lite::Tensor *transpose2_param_tensor = new lite::Tensor();
+  tensors->push_back(transpose_param_tensor);
+  lite::Tensor *transpose2_out_tensor = new lite::Tensor();
+  tensors->push_back(transpose_param_tensor);
+  OpParameter *transpose2_param = new OpParameter();
+  kernel::KernelKey transpose2_desc{kernel::kCPU, kNumberTypeFloat32, schema::PrimitiveType_Transpose};
+  kernel::LiteKernel *transpose2_kernel = nullptr;
+  std::vector<lite::Tensor *> transpose2_in = {in_out_tensor, transpose2_param_tensor};
+  std::vector<lite::Tensor *> transpose2_out = {transpose2_out_tensor};
+  lite::KernelRegistry::GetInstance()->GetKernel(transpose2_in, transpose2_out, ctx, nullptr, transpose2_desc,
+                                                 transpose2_param, &transpose2_kernel, nullptr);
+  kernels->push_back(transpose2_kernel);
+
+  lite::Tensor *conv2_weight = new lite::Tensor();
+  tensors->push_back(conv2_weight);
+  lite::Tensor *conv2_out_tensor = new lite::Tensor();
+  tensors->push_back(conv2_out_tensor);
+  std::vector<lite::Tensor *> conv2_in = {transpose2_out_tensor, conv_weight};
+  std::vector<lite::Tensor *> conv2_out = {conv2_out_tensor};
+  OpParameter *conv2_param = new OpParameter();
+  kernel::KernelKey conv2_desc{kernel::kCPU, kNumberTypeFloat32, schema::PrimitiveType_Conv2DFusion};
+  kernel::LiteKernel *conv2_kernel = nullptr;
+  lite::KernelRegistry::GetInstance()->GetKernel(conv2_in, conv2_out, ctx, nullptr, conv2_desc, conv2_param,
+                                                 &conv2_kernel, nullptr);
+  kernels->push_back(conv2_kernel);
 
   conv_kernel->set_out_kernels({transpose_kernel});
   transpose_kernel->set_in_kernels({conv_kernel});
-  transpose_kernel->set_out_kernels({pad_kernel});
-  pad_kernel->set_in_kernels({transpose_kernel});
+  transpose_kernel->set_out_kernels({in_kernel});
+  in_kernel->set_in_kernels({transpose_kernel});
+  in_kernel->set_out_kernels({transpose2_kernel});
+  transpose2_kernel->set_in_kernels({in_kernel});
+  transpose2_kernel->set_out_kernels({conv2_kernel});
+  conv2_kernel->set_in_kernels({transpose2_kernel});
   return;
 }
 
@@ -85,11 +114,12 @@ TEST_F(RuntimePass, Nc4hw4Pass1) {
   std::vector<lite::Tensor *> tensors;
   Nc4hw4PassConstruct(&kernels, &tensors, ctx.get());
 
+  ASSERT_EQ(kernels.size(), 5);
+
   /* runtime pass */
   lite::Nc4hw4PassReplace(&kernels, &tensors, 0);
 
-  ASSERT_EQ(kernels.size(), 2);
-  ASSERT_EQ(tensors.size(), 5);
+  ASSERT_EQ(kernels.size(), 3);
 
   for (auto tensor : tensors) {
     delete tensor;