[MSLITE] runtime pass bug

This commit is contained in:
ling 2021-08-17 11:37:11 +08:00
parent a34d737858
commit 877ea4a394
6 changed files with 143 additions and 56 deletions

View File

@ -69,6 +69,17 @@ bool VectorErase(std::vector<T> *vec, T element) {
return ret;
}
// Overwrite every occurrence of `element` in `vec` with nullptr, keeping the
// vector's size unchanged (unlike VectorErase, which removes the slot).
// T must be a pointer type, since the matched slots are assigned nullptr.
// Returns true if at least one occurrence was found and nulled out.
template <typename T>
bool VectorSetNull(std::vector<T> *vec, T element) {
  bool ret = false;
  for (size_t i = 0; i < vec->size(); i++) {
    if (vec->at(i) == element) {
      vec->at(i) = nullptr;
      // Bug fix: `ret` was never updated, so the function unconditionally
      // returned false even on a successful replacement — inconsistent with
      // the sibling helpers VectorErase/VectorReplace, which report success.
      ret = true;
    }
  }
  return ret;
}
template <typename T>
bool VectorReplace(std::vector<T> *vec, T srcElement, T dstElement) {
bool ret = false;

View File

@ -738,7 +738,9 @@ LiteSession::~LiteSession() {
kernel = nullptr;
}
for (auto tensor : tensors_) {
MS_ASSERT(tensor != nullptr);
if (tensor == nullptr) {
continue;
}
// Data of const tensor which doesn't own data will not freed.
// Such as const data from meta_graph which will be freed when freeing meta_graph.
if (tensor->IsConst() && !tensor->own_data()) {

View File

@ -21,30 +21,60 @@
namespace mindspore::lite {
void Nc4hw4PassReplace(std::vector<kernel::LiteKernel *> *kernels, std::vector<Tensor *> *tensors, size_t index) {
kernel::LiteKernel *conv_kernel = kernels->at(index);
kernel::LiteKernel *traspose_kernel = conv_kernel->out_kernels().front();
kernel::LiteKernel *c4_kernel = traspose_kernel->out_kernels().front();
kernel::LiteKernel *transpose_kernel = conv_kernel->out_kernels().front();
kernel::LiteKernel *c4_kernel = transpose_kernel->out_kernels().front();
kernel::LiteKernel *transpose2_kernel = c4_kernel->out_kernels().front();
std::vector<kernel::LiteKernel *> end_kernels = transpose2_kernel->out_kernels();
/* tensor */
Tensor *transpose_param_tensor = traspose_kernel->in_tensors().at(1);
VectorErase(tensors, transpose_param_tensor);
delete transpose_param_tensor;
transpose_param_tensor = nullptr;
{
/* transpose_kernel */
Tensor *transpose_param_tensor = transpose_kernel->in_tensors().at(1);
VectorSetNull(tensors, transpose_param_tensor);
delete transpose_param_tensor;
transpose_param_tensor = nullptr;
Tensor *conv_out_tensor = conv_kernel->out_tensors().front();
conv_out_tensor->set_format(NC4HW4);
Tensor *c4_input_tensor = c4_kernel->in_tensors().front();
c4_kernel->set_in_tensor(conv_out_tensor, 0);
VectorErase(tensors, c4_input_tensor);
delete c4_input_tensor;
c4_input_tensor = nullptr;
Tensor *conv_out_tensor = conv_kernel->out_tensors().front();
conv_out_tensor->set_format(NC4HW4);
Tensor *c4_input_tensor = c4_kernel->in_tensors().front();
c4_kernel->set_in_tensor(conv_out_tensor, 0);
VectorSetNull(tensors, c4_input_tensor);
delete c4_input_tensor;
c4_input_tensor = nullptr;
}
{
/* transpose2_kernel */
Tensor *transpose_param_tensor = transpose2_kernel->in_tensors().at(1);
VectorSetNull(tensors, transpose_param_tensor);
delete transpose_param_tensor;
transpose_param_tensor = nullptr;
Tensor *nwhc_tensor = c4_kernel->out_tensors().front();
nwhc_tensor->set_format(NHWC);
for (auto end : end_kernels) {
end->set_in_tensor(nwhc_tensor, 0);
}
Tensor *trans_out = transpose2_kernel->out_tensors().front();
VectorSetNull(tensors, trans_out);
delete trans_out;
trans_out = nullptr;
}
/* kernel */
VectorErase(kernels, traspose_kernel);
delete traspose_kernel;
traspose_kernel = nullptr;
VectorErase(kernels, transpose_kernel);
delete transpose_kernel;
transpose_kernel = nullptr;
conv_kernel->set_out_kernels({c4_kernel});
c4_kernel->set_in_kernels({conv_kernel});
c4_kernel->set_out_kernels(transpose2_kernel->out_kernels());
for (auto end : end_kernels) {
end->set_in_kernels({c4_kernel});
}
VectorErase(kernels, transpose2_kernel);
delete transpose2_kernel;
transpose2_kernel = nullptr;
return;
}
@ -61,27 +91,38 @@ bool Nc4hw4PassMatch(std::vector<kernel::LiteKernel *> *kernels, size_t index) {
return false;
}
kernel::LiteKernel *traspose_kernel = start_kernel->out_kernels().front();
if (start_kernel->type() != Nc4hw4FormatTransposeOp) {
kernel::LiteKernel *traspose_nhwc2nchw_kernel = start_kernel->out_kernels().front();
if (traspose_nhwc2nchw_kernel->type() != Nc4hw4FormatTransposeOp) {
return false;
}
if (traspose_kernel->out_kernels().size() != 1) {
if (traspose_nhwc2nchw_kernel->out_kernels().size() != 1) {
return false;
}
kernel::LiteKernel *end_kernel = traspose_kernel->out_kernels().front();
kernel::LiteKernel *end_kernel = traspose_nhwc2nchw_kernel->out_kernels().front();
if (IsContain(Nc4hw4FormatInOpList, end_kernel->type()) == false) {
return false;
}
if (end_kernel->out_kernels().size() != 1) {
return false;
}
kernel::LiteKernel *transpose_nchw2nhwc_kernel = end_kernel->out_kernels().front();
if (transpose_nchw2nhwc_kernel->type() != Nc4hw4FormatTransposeOp) {
return false;
}
/* double check ops topological sorted in kernel-list */
auto start_iter = find(kernels->begin(), kernels->end(), start_kernel);
auto start_index = std::distance(kernels->begin(), start_iter);
auto transpose_iter = find(kernels->begin(), kernels->end(), traspose_kernel);
auto transpose_index = std::distance(kernels->begin(), transpose_iter);
auto traspose_nhwc2nchw_iter = find(kernels->begin(), kernels->end(), traspose_nhwc2nchw_kernel);
auto traspose_nhwc2nchw_index = std::distance(kernels->begin(), traspose_nhwc2nchw_iter);
auto end_iter = find(kernels->begin(), kernels->end(), end_kernel);
auto end_index = std::distance(kernels->begin(), end_iter);
if (start_index > transpose_index || transpose_index > end_index) {
auto transpose_nchw2nhwc_iter = find(kernels->begin(), kernels->end(), transpose_nchw2nhwc_kernel);
auto transpose_nchw2nhwc_index = std::distance(kernels->begin(), transpose_nchw2nhwc_iter);
if (start_index > traspose_nhwc2nchw_index || traspose_nhwc2nchw_index > end_index ||
end_index > transpose_nchw2nhwc_index) {
return false;
}
@ -89,8 +130,6 @@ bool Nc4hw4PassMatch(std::vector<kernel::LiteKernel *> *kernels, size_t index) {
}
bool Nc4hw4PassValid(const InnerContext *context, std::vector<kernel::LiteKernel *> *kernels) {
return false;
if (context->IsGpuEnabled() || context->IsNpuEnabled()) {
return false;
}
@ -103,19 +142,19 @@ bool Nc4hw4PassValid(const InnerContext *context, std::vector<kernel::LiteKernel
}
}
}
return true;
return false;
}
void Nc4hw4Pass(std::vector<kernel::LiteKernel *> *kernels, std::vector<Tensor *> *tensors) {
void Nc4hw4PassAct(std::vector<kernel::LiteKernel *> *kernels, std::vector<Tensor *> *tensors) {
size_t kernel_size = kernels->size();
size_t index = 0;
for (; index + 2 < kernel_size; index++) {
for (; index + 3 < kernel_size; index++) {
kernel::LiteKernel *kernel = kernels->at(index);
if (kernel->subgraph_type() != kernel::kNotSubGraph) {
kernel::SubGraphKernel *subgraph = reinterpret_cast<kernel::SubGraphKernel *>(kernel);
std::vector<kernel::LiteKernel *> &particial_nodes = subgraph->nodes();
Nc4hw4Pass(&particial_nodes, tensors);
Nc4hw4PassAct(&particial_nodes, tensors);
}
if (Nc4hw4PassMatch(kernels, index)) {
@ -126,5 +165,12 @@ void Nc4hw4Pass(std::vector<kernel::LiteKernel *> *kernels, std::vector<Tensor *
}
return;
}
void Nc4hw4Pass(const InnerContext *context, std::vector<kernel::LiteKernel *> *kernels,
std::vector<Tensor *> *tensors) {
if (Nc4hw4PassValid(context, kernels)) {
Nc4hw4PassAct(kernels, tensors);
}
}
} // namespace mindspore::lite
#endif

View File

@ -27,15 +27,14 @@
namespace mindspore::lite {
/* Nc4hw4 PASS
* before : CONV --(nhwc)-- TRANSPOSE --(nhwc)-- OP
* after : CONV --(nc4hw4)-- OP
* before : --(nhwc)-- CONV --(nhwc)-- TRANSPOSE --(nchw)-- IN --(nchw)-- TRANSPOSE --(nhwc)--
* after : --(nhwc)-- CONV --(nc4hw4)-- IN --(nhwc)--
* */
static const schema::PrimitiveType Nc4hw4FormatTransposeOp = schema::PrimitiveType_Transpose;
static const std::vector<schema::PrimitiveType> Nc4hw4FormatOutOpList = {schema::PrimitiveType_Conv2DFusion};
static const std::vector<schema::PrimitiveType> Nc4hw4FormatInOpList = {schema::PrimitiveType_InstanceNorm,
schema::PrimitiveType_PadFusion};
bool Nc4hw4PassValid(const InnerContext *context, std::vector<kernel::LiteKernel *> *kernels);
void Nc4hw4Pass(std::vector<kernel::LiteKernel *> *kernels, std::vector<Tensor *> *tensors);
static const std::vector<schema::PrimitiveType> Nc4hw4FormatInOpList = {schema::PrimitiveType_InstanceNorm};
void Nc4hw4Pass(const InnerContext *context, std::vector<kernel::LiteKernel *> *kernels,
std::vector<Tensor *> *tensors);
} // namespace mindspore::lite
#endif

View File

@ -150,13 +150,12 @@ int Scheduler::Schedule(std::vector<kernel::LiteKernel *> *dst_kernels) {
#endif
}
FindAllInoutKernels(*dst_kernels);
#ifdef ENABLE_RUNTIME_PASS
if (Nc4hw4PassValid(context_, dst_kernels)) {
Nc4hw4Pass(dst_kernels, src_tensors_);
}
Nc4hw4Pass(context_, dst_kernels, src_tensors_);
#endif
FindAllInoutKernels(*dst_kernels);
#ifdef ENABLE_CONTROLFLOW_TENSORLIST
if (IsControlFlowParttern(*dst_kernels)) {
ret = ConstructControlFlowMainGraph(dst_kernels);

View File

@ -59,23 +59,52 @@ void Nc4hw4PassConstruct(std::vector<kernel::LiteKernel *> *kernels, std::vector
transpose_param, &transpose_kernel, nullptr);
kernels->push_back(transpose_kernel);
lite::Tensor *pad_param_tensor = new lite::Tensor();
tensors->push_back(pad_param_tensor);
lite::Tensor *pad_out_tensor = new lite::Tensor();
tensors->push_back(pad_out_tensor);
OpParameter *pad_param = new OpParameter();
kernel::KernelKey pad_desc{kernel::kCPU, kNumberTypeFloat32, schema::PrimitiveType_PadFusion};
kernel::LiteKernel *pad_kernel = nullptr;
std::vector<lite::Tensor *> pad_in = {transpose_out_tensor, pad_param_tensor};
std::vector<lite::Tensor *> pad_out = {pad_out_tensor};
lite::KernelRegistry::GetInstance()->GetKernel(pad_in, pad_out, ctx, nullptr, pad_desc, pad_param, &pad_kernel,
nullptr);
kernels->push_back(pad_kernel);
lite::Tensor *in_param_tensor = new lite::Tensor();
tensors->push_back(in_param_tensor);
lite::Tensor *in_out_tensor = new lite::Tensor();
tensors->push_back(in_out_tensor);
OpParameter *in_param = new OpParameter();
kernel::KernelKey in_desc{kernel::kCPU, kNumberTypeFloat32, schema::PrimitiveType_InstanceNorm};
kernel::LiteKernel *in_kernel = nullptr;
std::vector<lite::Tensor *> in_in = {transpose_out_tensor, in_param_tensor};
std::vector<lite::Tensor *> in_out = {in_out_tensor};
lite::KernelRegistry::GetInstance()->GetKernel(in_in, in_out, ctx, nullptr, in_desc, in_param, &in_kernel, nullptr);
kernels->push_back(in_kernel);
lite::Tensor *transpose2_param_tensor = new lite::Tensor();
tensors->push_back(transpose_param_tensor);
lite::Tensor *transpose2_out_tensor = new lite::Tensor();
tensors->push_back(transpose_param_tensor);
OpParameter *transpose2_param = new OpParameter();
kernel::KernelKey transpose2_desc{kernel::kCPU, kNumberTypeFloat32, schema::PrimitiveType_Transpose};
kernel::LiteKernel *transpose2_kernel = nullptr;
std::vector<lite::Tensor *> transpose2_in = {in_out_tensor, transpose2_param_tensor};
std::vector<lite::Tensor *> transpose2_out = {transpose2_out_tensor};
lite::KernelRegistry::GetInstance()->GetKernel(transpose2_in, transpose2_out, ctx, nullptr, transpose2_desc,
transpose2_param, &transpose2_kernel, nullptr);
kernels->push_back(transpose2_kernel);
lite::Tensor *conv2_weight = new lite::Tensor();
tensors->push_back(conv2_weight);
lite::Tensor *conv2_out_tensor = new lite::Tensor();
tensors->push_back(conv2_out_tensor);
std::vector<lite::Tensor *> conv2_in = {transpose2_out_tensor, conv_weight};
std::vector<lite::Tensor *> conv2_out = {conv2_out_tensor};
OpParameter *conv2_param = new OpParameter();
kernel::KernelKey conv2_desc{kernel::kCPU, kNumberTypeFloat32, schema::PrimitiveType_Conv2DFusion};
kernel::LiteKernel *conv2_kernel = nullptr;
lite::KernelRegistry::GetInstance()->GetKernel(conv2_in, conv2_out, ctx, nullptr, conv2_desc, conv2_param,
&conv2_kernel, nullptr);
kernels->push_back(conv2_kernel);
conv_kernel->set_out_kernels({transpose_kernel});
transpose_kernel->set_in_kernels({conv_kernel});
transpose_kernel->set_out_kernels({pad_kernel});
pad_kernel->set_in_kernels({transpose_kernel});
transpose_kernel->set_out_kernels({in_kernel});
in_kernel->set_in_kernels({transpose_kernel});
in_kernel->set_out_kernels({transpose2_kernel});
transpose2_kernel->set_in_kernels({in_kernel});
transpose2_kernel->set_out_kernels({conv2_kernel});
conv2_kernel->set_in_kernels({transpose2_kernel});
return;
}
@ -85,11 +114,12 @@ TEST_F(RuntimePass, Nc4hw4Pass1) {
std::vector<lite::Tensor *> tensors;
Nc4hw4PassConstruct(&kernels, &tensors, ctx.get());
ASSERT_EQ(kernels.size(), 5);
/* runtime pass */
lite::Nc4hw4PassReplace(&kernels, &tensors, 0);
ASSERT_EQ(kernels.size(), 2);
ASSERT_EQ(tensors.size(), 5);
ASSERT_EQ(kernels.size(), 3);
for (auto tensor : tensors) {
delete tensor;