forked from mindspore-Ecosystem/mindspore
add onnx loop if support
This commit is contained in:
parent
2645ed3c90
commit
401d42a103
|
@ -264,6 +264,7 @@ union PrimitiveType {
|
|||
If,
|
||||
GeLU,
|
||||
Gru,
|
||||
NonZero,
|
||||
}
|
||||
|
||||
enum QuantType: int {
|
||||
|
|
|
@ -236,7 +236,8 @@ union PrimitiveType {
|
|||
LpNormalization,
|
||||
DropoutGrad,
|
||||
MaximumGrad,
|
||||
MinimumGrad
|
||||
MinimumGrad,
|
||||
NonZero,
|
||||
}
|
||||
|
||||
enum QuantType: int {
|
||||
|
|
|
@ -1241,3 +1241,6 @@ table Merge {
|
|||
table GeLU {
|
||||
approximate : bool = false;
|
||||
}
|
||||
|
||||
table NonZero {
|
||||
}
|
||||
|
|
|
@ -1143,3 +1143,6 @@ table LpNormalization {
|
|||
axis : int;
|
||||
p : int;
|
||||
}
|
||||
|
||||
table NonZero {
|
||||
}
|
||||
|
|
|
@ -0,0 +1,124 @@
|
|||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/ops/nonzero.h"
|
||||
#include <algorithm>
|
||||
#include "include/errorcode.h"
|
||||
#include "src/common/log_adapter.h"
|
||||
#include "src/tensor.h"
|
||||
#ifndef PRIMITIVE_WRITEABLE
|
||||
#include "src/ops/ops_register.h"
|
||||
#endif
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
int NonZero::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
|
||||
if (this->primitive_ == nullptr) {
|
||||
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
|
||||
if (this->primitive_ == nullptr) {
|
||||
MS_LOG(ERROR) << "new primitiveT failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
this->primitive_->value.type = schema::PrimitiveType_NonZero;
|
||||
}
|
||||
if (this->primitive_->value.type != schema::PrimitiveType_NonZero) {
|
||||
MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (this->primitive_->value.value == nullptr) {
|
||||
this->primitive_->value.value = new (std::nothrow) schema::NonZeroT();
|
||||
if (this->primitive_->value.value == nullptr) {
|
||||
MS_LOG(ERROR) << "new primitiveT value failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
PopulaterQuantParam(prim, inputs);
|
||||
return RET_OK;
|
||||
}
|
||||
#else
|
||||
int NonZero::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
|
||||
MS_ASSERT(nullptr != primitive);
|
||||
MS_ASSERT(nullptr != fbb);
|
||||
auto attr = primitive->value_as_NonZero();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "value_as_NonZero return nullptr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto val_offset = schema::CreateNonZero(*fbb);
|
||||
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_NonZero, val_offset.o);
|
||||
fbb->Finish(prim_offset);
|
||||
return RET_OK;
|
||||
}
|
||||
PrimitiveC *NonZeroCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<NonZero>(primitive); }
|
||||
Registry NonZeroRegistry(schema::PrimitiveType_NonZero, NonZeroCreator);
|
||||
#endif
|
||||
template <typename T>
|
||||
void CalShape(const T *data, const std::vector<Tensor *> &inputs, std::vector<int> *out_shape) {
|
||||
int input_count = inputs[0]->ElementsNum();
|
||||
int input_dim_size = inputs[0]->shape().empty() ? 1 : inputs[0]->shape().size();
|
||||
(*out_shape)[0] = input_dim_size;
|
||||
int nonzero_size = 0;
|
||||
for (int i = 0; i < input_count; i++) {
|
||||
if (static_cast<int>(data[i]) != 0) {
|
||||
nonzero_size++;
|
||||
}
|
||||
}
|
||||
if (nonzero_size == 0) {
|
||||
*out_shape = {};
|
||||
} else {
|
||||
(*out_shape)[1] = nonzero_size / input_dim_size;
|
||||
}
|
||||
}
|
||||
int NonZero::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
|
||||
MS_ASSERT(this->primitive_ != nullptr);
|
||||
MS_ASSERT(inputs_.size() == 1);
|
||||
auto input = inputs_.front();
|
||||
MS_ASSERT(input != nullptr);
|
||||
auto output = outputs_.front();
|
||||
MS_ASSERT(output != nullptr);
|
||||
output->set_data_type(input->data_type());
|
||||
output->set_format(input->format());
|
||||
if (!infer_flag()) {
|
||||
return RET_INFER_INVALID;
|
||||
}
|
||||
|
||||
std::vector<int> out_shape;
|
||||
if (inputs_.size() == kSingleNum) {
|
||||
auto input_tensor = inputs_.at(0);
|
||||
if (input_tensor->data_c() == nullptr) {
|
||||
MS_LOG(INFO) << "Do infer shape in runtime.";
|
||||
return RET_INFER_INVALID;
|
||||
}
|
||||
switch (input_tensor->data_type()) {
|
||||
case kNumberTypeFloat: {
|
||||
auto data = reinterpret_cast<float *>(input_tensor->MutableData());
|
||||
CalShape<float>(data, inputs_, &out_shape);
|
||||
} break;
|
||||
default: {
|
||||
MS_LOG(ERROR) << "NonZero weight tensor has unsupported dataType: " << input_tensor->data_type();
|
||||
return RET_INFER_ERR;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
MS_LOG(ERROR) << "inputs tensor size invalid.";
|
||||
return RET_INFER_ERR;
|
||||
}
|
||||
output->set_shape(out_shape);
|
||||
return RET_OK;
|
||||
}
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,45 @@
|
|||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_SRC_OPS_NONZERO_H_
|
||||
#define MINDSPORE_LITE_SRC_OPS_NONZERO_H_
|
||||
|
||||
#include <cmath>
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include <vector>
|
||||
|
||||
#include "src/ops/primitive_c.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class NonZero : public PrimitiveC {
|
||||
public:
|
||||
NonZero() = default;
|
||||
~NonZero() = default;
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
MS_DECLARE_PARENT(NonZero, PrimitiveC);
|
||||
explicit NonZero(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
|
||||
#else
|
||||
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
|
||||
#endif
|
||||
int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_LITE_SRC_OPS_NONZERO_H_
|
|
@ -0,0 +1,105 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "src/runtime/kernel/arm/fp32/nonzero_fp32.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "nnacl/op_base.h"
|
||||
#include "schema/model_generated.h"
|
||||
#include "src/kernel_registry.h"
|
||||
#include "src/runtime/runtime_api.h"
|
||||
#include "src/tensor.h"
|
||||
|
||||
using mindspore::kernel::KERNEL_ARCH::kCPU;
|
||||
using mindspore::lite::KernelRegistrar;
|
||||
using mindspore::lite::RET_ERROR;
|
||||
using mindspore::lite::RET_OK;
|
||||
using mindspore::schema::PrimitiveType_NonZero;
|
||||
|
||||
namespace mindspore::kernel {
|
||||
int NonZeroCPUKernel::Init() {
|
||||
if (!InferShapeDone()) {
|
||||
return RET_OK;
|
||||
}
|
||||
return ReSize();
|
||||
}
|
||||
|
||||
int NonZeroCPUKernel::ReSize() { return RET_OK; }
|
||||
int NonZeroCPUKernel::Run() {
|
||||
auto in_tensor = in_tensors_.front();
|
||||
auto out_tensor = out_tensors_.front();
|
||||
auto input_data = reinterpret_cast<float *>(in_tensor->MutableData());
|
||||
auto output_data = reinterpret_cast<int *>(out_tensor->MutableData());
|
||||
auto input_dim_size = in_tensor->shape().size();
|
||||
if (out_tensor->shape().size() != 2) {
|
||||
MS_LOG(ERROR) << "out tensor shape size must be equal to 2!";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto non_zero_nums = out_tensor->shape()[1];
|
||||
int non_zero_count = 0;
|
||||
std::vector coordiate_values(in_tensor->shape().size(), 0);
|
||||
for (int i = 0; i < in_tensor->ElementsNum(); i += 1) {
|
||||
if (input_data[i] != 0) {
|
||||
for (size_t j = 0; j < input_dim_size; j++) {
|
||||
output_data[non_zero_count + j * non_zero_nums] = coordiate_values[j];
|
||||
}
|
||||
non_zero_count++;
|
||||
}
|
||||
for (int idx = input_dim_size - 1; idx >= 0; --idx) {
|
||||
if (coordiate_values[idx] != in_tensor->shape()[idx] - 1) {
|
||||
coordiate_values[idx] = coordiate_values[idx] + 1;
|
||||
break;
|
||||
}
|
||||
coordiate_values[idx] = 0;
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
kernel::LiteKernel *CpuNonZeroFp32KernelCreator(const std::vector<lite::Tensor *> &inputs,
|
||||
const std::vector<lite::Tensor *> &outputs, OpParameter *opParameter,
|
||||
const lite::InnerContext *ctx, const kernel::KernelKey &desc,
|
||||
const mindspore::lite::PrimitiveC *primitive) {
|
||||
if (opParameter == nullptr) {
|
||||
MS_LOG(ERROR) << "Input opParameter is nullptr!";
|
||||
return nullptr;
|
||||
}
|
||||
if (ctx == nullptr) {
|
||||
MS_LOG(ERROR) << "Input context is nullptr!";
|
||||
free(opParameter);
|
||||
return nullptr;
|
||||
}
|
||||
if (ctx->thread_num_ == 0) {
|
||||
MS_LOG(ERROR) << "context thread num is 0!";
|
||||
free(opParameter);
|
||||
return nullptr;
|
||||
}
|
||||
auto *kernel = new (std::nothrow) NonZeroCPUKernel(opParameter, inputs, outputs, ctx, primitive);
|
||||
if (kernel == nullptr) {
|
||||
MS_LOG(ERROR) << "new NonZeroCPUKernel fail!";
|
||||
free(opParameter);
|
||||
return nullptr;
|
||||
}
|
||||
auto ret = kernel->Init();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
|
||||
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
|
||||
delete kernel;
|
||||
return nullptr;
|
||||
}
|
||||
return kernel;
|
||||
}
|
||||
|
||||
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_NonZero, CpuNonZeroFp32KernelCreator)
|
||||
} // namespace mindspore::kernel
|
|
@ -0,0 +1,41 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_NONZERO_H_
|
||||
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_NONZERO_H_
|
||||
|
||||
#include <vector>
|
||||
#include "src/lite_kernel.h"
|
||||
|
||||
namespace mindspore::kernel {
|
||||
class NonZeroCPUKernel : public LiteKernel {
|
||||
public:
|
||||
NonZeroCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
|
||||
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
|
||||
const mindspore::lite::PrimitiveC *primitive)
|
||||
: LiteKernel(parameter, inputs, outputs, ctx, primitive) {}
|
||||
|
||||
~NonZeroCPUKernel() = default;
|
||||
|
||||
int Init() override;
|
||||
int ReSize() override;
|
||||
int Run() override;
|
||||
|
||||
protected:
|
||||
int thread_count_ = 1;
|
||||
};
|
||||
} // namespace mindspore::kernel
|
||||
|
||||
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_NONZERO_H_
|
|
@ -100,7 +100,8 @@ FuncGraphPtr AnfTransform::TransformSingleFuncGraph(const FuncGraphPtr &old_grap
|
|||
}
|
||||
}
|
||||
|
||||
if (config->fmk == lite::converter::FmkType_TFLITE || config->fmk == lite::converter::FmkType_TF) {
|
||||
if (config->fmk == lite::converter::FmkType_TFLITE || config->fmk == lite::converter::FmkType_TF ||
|
||||
config->fmk == lite::converter::FmkType_ONNX) {
|
||||
graph_pm->AddPass(std::make_shared<opt::WhilePass>());
|
||||
graph_pm->AddPass(std::make_shared<opt::IfPass>());
|
||||
}
|
||||
|
|
|
@ -0,0 +1,42 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tools/converter/parser/onnx/onnx_if_parser.h"
|
||||
#include <memory>
|
||||
#include "tools/converter/parser/onnx/onnx_model_parser.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
lite::PrimitiveC *OnnxIfParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
|
||||
const onnx::NodeProto &onnx_node) {
|
||||
MS_LOG(DEBUG) << "onnx IfParser";
|
||||
auto attr = std::make_unique<schema::IfT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return nullptr;
|
||||
}
|
||||
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "new primitive failed";
|
||||
return nullptr;
|
||||
}
|
||||
primitive->value.type = schema::PrimitiveType_If;
|
||||
primitive->value.value = attr.release();
|
||||
return PrimitiveC::Create(primitive.release());
|
||||
}
|
||||
OnnxNodeRegistrar g_onnxIfParser("If", new OnnxIfParser());
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,34 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_IF_PARSER_H
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_IF_PARSER_H
|
||||
|
||||
#include "tools/converter/parser/onnx/onnx_node_parser.h"
|
||||
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class OnnxIfParser : public OnnxNodeParser {
|
||||
public:
|
||||
OnnxIfParser() : OnnxNodeParser("If") {}
|
||||
~OnnxIfParser() override = default;
|
||||
|
||||
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_IF_PARSER_H
|
|
@ -0,0 +1,42 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tools/converter/parser/onnx/onnx_loop_parser.h"
|
||||
#include <memory>
|
||||
#include "tools/converter/parser/onnx/onnx_model_parser.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
lite::PrimitiveC *OnnxLoopParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
|
||||
const onnx::NodeProto &onnx_node) {
|
||||
MS_LOG(DEBUG) << "onnx LoopParser";
|
||||
auto attr = std::make_unique<schema::WhileT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return nullptr;
|
||||
}
|
||||
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "new primitive failed";
|
||||
return nullptr;
|
||||
}
|
||||
primitive->value.type = schema::PrimitiveType_While;
|
||||
primitive->value.value = attr.release();
|
||||
return PrimitiveC::Create(primitive.release());
|
||||
}
|
||||
OnnxNodeRegistrar g_onnxLoopParser("Loop", new OnnxLoopParser());
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,34 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_LOOP_PARSER_H
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_LOOP_PARSER_H
|
||||
|
||||
#include "tools/converter/parser/onnx/onnx_node_parser.h"
|
||||
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class OnnxLoopParser : public OnnxNodeParser {
|
||||
public:
|
||||
OnnxLoopParser() : OnnxNodeParser("Loop") {}
|
||||
~OnnxLoopParser() override = default;
|
||||
|
||||
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_LOOP_PARSER_H
|
File diff suppressed because it is too large
Load Diff
|
@ -54,29 +54,60 @@ class OnnxModelParser : public ModelParser {
|
|||
|
||||
private:
|
||||
STATUS InitOriginModel(const std::string &model_file);
|
||||
STATUS ConvertNodes();
|
||||
STATUS ConvertConstTensors();
|
||||
STATUS ConvertGraphInputs();
|
||||
STATUS ConvertGraphOutputs();
|
||||
STATUS BuildReturnNode(const std::vector<AnfNodePtr> &return_inputs);
|
||||
STATUS ConvertNodes(const onnx::GraphProto &onnx_graph, const FuncGraphPtr &func_graph_ptr,
|
||||
std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map, std::vector<AnfNodePtr> *graph_inputs,
|
||||
const std::string &root_node_name);
|
||||
STATUS ConvertOnnxGraph(const onnx::GraphProto &onnx_graph, const FuncGraphPtr &func_graph_ptr,
|
||||
std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map,
|
||||
std::vector<AnfNodePtr> *graph_inputs, const std::string &root_node_name);
|
||||
STATUS ConvertConstTensors(const onnx::GraphProto &onnx_graph, const FuncGraphPtr &func_graph_ptr,
|
||||
std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map);
|
||||
STATUS ConvertGraphInputs(const onnx::GraphProto &onnx_graph, const FuncGraphPtr &func_graph_ptr,
|
||||
std::unordered_map<std::string, AnfNodePtr> *nodes_map);
|
||||
STATUS ConvertGraphOutputs(const onnx::GraphProto &onnx_graph, const FuncGraphPtr &func_graph_ptr,
|
||||
const std::unordered_map<std::string, AnfNodePtr> &anf_nodes_map);
|
||||
STATUS BuildReturnNode(const FuncGraphPtr &func_graph_ptr, const std::vector<AnfNodePtr> &return_inputs);
|
||||
STATUS BuildParameterNode(const ParameterPtr ¶meter_node, const onnx::TensorProto &tensor);
|
||||
STATUS BuildParameterNodeForQuantParam(void *data, const std::string &name, TypeId type);
|
||||
STATUS BuildCNode(const onnx::NodeProto &onnx_node, lite::PrimitiveC *primitive_c);
|
||||
STATUS BuildOpOutputs(const onnx::NodeProto &onnx_node, const CNodePtr &cnode);
|
||||
STATUS ConvertSpecialOnnxNode(const onnx::NodeProto &onnx_node, lite::PrimitiveC *primitive_c);
|
||||
STATUS ConvertOnnxGemmNode(const onnx::NodeProto &onnx_node, lite::PrimitiveC *primitive_c);
|
||||
STATUS BuildCNodeForGemm(const onnx::NodeProto &onnx_node, lite::PrimitiveC *primitive_c, const std::string &name);
|
||||
STATUS BuildCNode(const onnx::NodeProto &onnx_node, const FuncGraphPtr &func_graph_ptr,
|
||||
std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map, std::vector<AnfNodePtr> *graph_inputs,
|
||||
lite::PrimitiveC *primitive_c, std::string loop_name);
|
||||
STATUS BuildOpOutputs(const onnx::NodeProto &onnx_node, const FuncGraphPtr &func_graph_ptr,
|
||||
std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map, const CNodePtr &cnode);
|
||||
STATUS ConvertSpecialOnnxNode(const onnx::NodeProto &onnx_node, const FuncGraphPtr &func_graph_ptr,
|
||||
std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map,
|
||||
lite::PrimitiveC *primitive_c);
|
||||
STATUS ConvertOnnxGemmNode(const onnx::NodeProto &onnx_node, const FuncGraphPtr &func_graph_ptr,
|
||||
std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map, lite::PrimitiveC *primitive_c);
|
||||
STATUS BuildCNodeForGemm(const onnx::NodeProto &onnx_node, const FuncGraphPtr &func_graph_ptr,
|
||||
std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map, lite::PrimitiveC *primitive_c,
|
||||
const std::string &name);
|
||||
STATUS ConvertOpQuantParams(const onnx::NodeProto &onnx_node, lite::PrimitiveC *primitive_c);
|
||||
STATUS ParseQuantParam(const onnx::NodeProto &onnx_node);
|
||||
STATUS SetTensorQuantParam(const std::string &tensor_name, std::vector<QuantParamT> *quant_params);
|
||||
STATUS SetTensorQuantParamFromNode(const std::string &tensor_name, std::vector<QuantParamT> *quant_params);
|
||||
STATUS CopyTensorQuantParam(const std::string &tensor_name, QuantParamT *quant_param, bool scale_or_not);
|
||||
bool IsSpecialOnnxNode(const onnx::NodeProto &onnx_node);
|
||||
|
||||
STATUS ConvertLoopOnnxNode(const onnx::NodeProto &onnx_node,
|
||||
std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map,
|
||||
const std::string &root_node_name);
|
||||
STATUS ConvertIfOnnxNode(const onnx::NodeProto &onnx_node, std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map,
|
||||
const std::string &root_node_name);
|
||||
STATUS AddTensorArrayEdge(const FuncGraphPtr &anf_graph, std::vector<AnfNodePtr> *return_new_inputs,
|
||||
const std::string &loop_node_name, std::vector<AnfNodePtr> *body_graph_inputs,
|
||||
int act_output_num);
|
||||
STATUS BuildCondGraph(const FuncGraphPtr &cond_graph, const AnfNodePtr &root_while_node, int inputs_num,
|
||||
const std::string &cond_graph_name);
|
||||
STATUS ConvertIfSubgraph(const onnx::GraphProto &onnx_graph, const FuncGraphPtr &anf_graph,
|
||||
const std::string &subgrah_name, const std::string &if_node_name,
|
||||
const std::string &root_node_name);
|
||||
onnx::ModelProto onnx_model_;
|
||||
onnx::GraphProto onnx_graph_;
|
||||
std::unordered_map<std::string, AnfNodePtr> nodes_;
|
||||
FuncGraphPtr func_graph_ptr_ = nullptr;
|
||||
onnx::GraphProto onnx_root_graph_;
|
||||
std::vector<FuncGraphPtr> all_subgraphs_;
|
||||
std::unordered_map<std::string, AnfNodePtr> anf_nodes_map_;
|
||||
std::unordered_map<std::string, std::unordered_map<std::string, AnfNodePtr> *> control_nodes_map_;
|
||||
std::unordered_map<std::string, std::string> child_root_map_; // for nest control flow node
|
||||
FuncGraphPtr anf_root_graph_ = nullptr;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -0,0 +1,42 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tools/converter/parser/onnx/onnx_nonzero_parser.h"
|
||||
#include <memory>
|
||||
#include "tools/converter/parser/onnx/onnx_model_parser.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
lite::PrimitiveC *OnnxNonZeroParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
|
||||
const onnx::NodeProto &onnx_node) {
|
||||
MS_LOG(DEBUG) << "onnx NonZeroParser";
|
||||
auto attr = std::make_unique<schema::NonZeroT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return nullptr;
|
||||
}
|
||||
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "new primitive failed";
|
||||
return nullptr;
|
||||
}
|
||||
primitive->value.type = schema::PrimitiveType_NonZero;
|
||||
primitive->value.value = attr.release();
|
||||
return PrimitiveC::Create(primitive.release());
|
||||
}
|
||||
OnnxNodeRegistrar g_onnxNonZeroParser("NonZero", new OnnxNonZeroParser());
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,34 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_NONZERO_PARSER_H
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_NONZERO_PARSER_H
|
||||
|
||||
#include "tools/converter/parser/onnx/onnx_node_parser.h"
|
||||
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class OnnxNonZeroParser : public OnnxNodeParser {
|
||||
public:
|
||||
OnnxNonZeroParser() : OnnxNodeParser("NonZero") {}
|
||||
~OnnxNonZeroParser() override = default;
|
||||
|
||||
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_NONZERO_PARSER_H
|
|
@ -96,6 +96,19 @@ bool RemoveIdentityOpPass::Run(const FuncGraphPtr &func_graph) {
|
|||
status = ReplaceIdentity(node, manager);
|
||||
} else if (type == schema::PrimitiveType_TupleGetItem) {
|
||||
status = ReplaceTupleGetItem(node, manager);
|
||||
} else if (type == schema::PrimitiveType_If || type == schema::PrimitiveType_While) {
|
||||
auto sub_func_graph = GetValueNode<FuncGraphPtr>(node->cast<CNodePtr>()->input(1));
|
||||
if (sub_func_graph == nullptr) {
|
||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
||||
return false;
|
||||
}
|
||||
(void)Run(sub_func_graph);
|
||||
sub_func_graph = GetValueNode<FuncGraphPtr>(node->cast<CNodePtr>()->input(2));
|
||||
if (sub_func_graph == nullptr) {
|
||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
||||
return false;
|
||||
}
|
||||
(void)Run(sub_func_graph);
|
||||
}
|
||||
if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) {
|
||||
MS_LOG(ERROR) << "remove identity pass is failed.";
|
||||
|
|
|
@ -296,6 +296,45 @@ STATUS OnnxInputAdjustOpPass::AdjustStridedSlice(const FuncGraphPtr &func_graph,
|
|||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
STATUS OnnxInputAdjustOpPass::AdjustResize(const CNodePtr &cnode) {
|
||||
MS_ASSERT(cnode != nullptr);
|
||||
auto node = cnode->input(0);
|
||||
MS_ASSERT(value_node != nullptr);
|
||||
auto value_node = node->cast<ValueNodePtr>();
|
||||
if (value_node == nullptr) {
|
||||
MS_LOG(ERROR) << "cnode input0 is not a valuenode.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
MS_ASSERT(value_node->value() != nullptr);
|
||||
auto primitive_c = value_node->value()->cast<PrimitiveCPtr>();
|
||||
if (primitive_c == nullptr) {
|
||||
MS_LOG(ERROR) << "cnode has no primitive_c.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
auto primitive = primitive_c->primitiveT();
|
||||
if (primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "cnode has no schema::primitive.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
if (primitive->value.type != schema::PrimitiveType_Resize) {
|
||||
MS_LOG(DEBUG) << "cnode is not cast node.";
|
||||
return RET_OK;
|
||||
}
|
||||
auto value = primitive->value.value;
|
||||
if (value == nullptr) {
|
||||
MS_LOG(ERROR) << "value is nullptr.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
auto attr = reinterpret_cast<schema::ResizeT *>(value);
|
||||
if (cnode->inputs().size() > 3 &&
|
||||
attr->coordinateTransformMode == schema::CoordinateTransformMode_TF_CROP_AND_RESIZE) {
|
||||
auto new_resize_inputs = cnode->inputs();
|
||||
new_resize_inputs.erase(new_resize_inputs.begin() + 1);
|
||||
cnode->set_inputs(new_resize_inputs);
|
||||
}
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
STATUS OnnxInputAdjustOpPass::AdjustConvOrDeConv(const CNodePtr &cnode) {
|
||||
MS_ASSERT(cnode != nullptr);
|
||||
if (!CheckInputs(cnode)) {
|
||||
|
|
|
@ -40,6 +40,7 @@ class OnnxInputAdjustOpPass : public Pass {
|
|||
STATUS AdjustConvOrDeConv(const CNodePtr &cnode);
|
||||
STATUS AdjustTile(const CNodePtr &cnode);
|
||||
STATUS AdjustCast(const CNodePtr &cnode);
|
||||
STATUS AdjustResize(const CNodePtr &cnode);
|
||||
STATUS ReplaceConstant(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
|
||||
STATUS ReplaceTransposeWithGraphInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
|
||||
bool Run(const FuncGraphPtr &func_graph) override;
|
||||
|
|
Loading…
Reference in New Issue