fix anf exporter

yankai 2020-08-03 11:21:02 +08:00
parent 488d1904b6
commit 15d0365748
15 changed files with 235 additions and 21 deletions

View File

@@ -1,7 +1,7 @@
/**
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
*
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -27,6 +27,7 @@
#include "mindspore/core/ir/primitive.h"
#include "src/ir/primitive_t_value.h"
#include "base/core_ops.h"
#include "src/ir/tensor.h"
namespace mindspore::lite {
schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &funcGraph) {
@@ -223,9 +224,28 @@ void AnfExporter::SetOpInputNode(const CNodePtr &cnode, schema::MetaGraphT *meta
nodeIdMap[paramNode->fullname_with_scope()] = meta_graph->allTensors.size();
fbNode->inputIndex.emplace_back(meta_graph->allTensors.size());
meta_graph->allTensors.emplace_back(std::move(paramTensor));
} else if (inputNode->isa<ValueNode>()) {
auto valueNode = inputNode->cast<ValueNodePtr>();
auto paramTensor = std::make_unique<schema::TensorT>();
auto value = valueNode->value();
if (value->isa<lite::tensor::Tensor>()) {
auto valueAbstract = valueNode->abstract();
auto abstractTensor = utils::cast<abstract::AbstractTensorPtr>(valueAbstract);
auto typePtr = abstractTensor->element()->GetTypeTrack();
paramTensor->dataType = typePtr->type_id();
paramTensor->dims = utils::cast<abstract::ShapePtr>(abstractTensor->BuildShape())->shape();
paramTensor->nodeType = schema::NodeType_ValueNode;
auto data = value->cast<lite::tensor::TensorPtr>();
paramTensor->data.resize(data->Size());
memcpy(paramTensor->data.data(), data->Data(), data->Size());
nodeIdMap[valueNode->fullname_with_scope()] = meta_graph->allTensors.size();
fbNode->inputIndex.emplace_back(meta_graph->allTensors.size());
meta_graph->allTensors.emplace_back(std::move(paramTensor));
} else {
MS_LOG(ERROR) << "Not support value type , need add support.";
}
}
}
if (isGraphInput) {
graphInputNodes.emplace_back(fbNode);
}
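The new ValueNode branch above copies a constant tensor's type, shape, and raw bytes into the exported MetaGraph. A minimal sketch of that path, pulled out into a standalone helper, is shown below; the helper name ExportConstTensor and the null guard are illustrative additions, while the individual calls mirror the hunk above.

/* Sketch only: condenses the ValueNode handling from SetOpInputNode; ExportConstTensor is a hypothetical name. */
std::unique_ptr<schema::TensorT> ExportConstTensor(const ValueNodePtr &valueNode) {
  auto srcTensor = valueNode->value()->cast<lite::tensor::TensorPtr>();
  if (srcTensor == nullptr) {  // added guard; the commit instead logs an error for unsupported value types
    MS_LOG(ERROR) << "value is not a lite::tensor::Tensor";
    return nullptr;
  }
  auto abstractTensor = utils::cast<abstract::AbstractTensorPtr>(valueNode->abstract());
  auto paramTensor = std::make_unique<schema::TensorT>();
  paramTensor->dataType = abstractTensor->element()->GetTypeTrack()->type_id();
  paramTensor->dims = utils::cast<abstract::ShapePtr>(abstractTensor->BuildShape())->shape();
  paramTensor->nodeType = schema::NodeType_ValueNode;
  paramTensor->data.resize(srcTensor->Size());
  memcpy(paramTensor->data.data(), srcTensor->Data(), srcTensor->Size());
  return paramTensor;
}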

View File

@@ -0,0 +1,36 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/common/anf_exporter/anf_populater/anf_depthwiseconv2d_populater.h"
#include <vector>
#include <string>
#include <memory>
#include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h"
#include "ir/func_graph.h"
#include "ir/primitive.h"
namespace mindspore::lite {
int mindspore::lite::AnfDepwiseconv2DPopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node,
std::vector<schema::TensorT *> *outputs) {
auto attr = std::make_unique<schema::DepthwiseConv2DT>();
node->nodeType = schema::NodeType_CNode;
node->primitive = std::make_unique<schema::PrimitiveT>();
node->primitive->value.type = schema::PrimitiveType_DepthwiseConv2D;
node->primitive->value.value = attr.release();
return 0;
}
AnfNodePopulaterRegistrar anfdepthwise2dParser("DepthwiseConv2D", new AnfDepwiseconv2DPopulater());
AnfNodePopulaterRegistrar anfdepthwise2dnativeParser("DepthwiseConv2dNative", new AnfDepwiseconv2DPopulater());
} // namespace mindspore::lite
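This populater, like the other new populaters in this commit, only tags the node with its primitive type and leaves the DepthwiseConv2DT attributes at their defaults. A minimal sketch of how such a populater is looked up and invoked through the registry touched later in this commit follows; the wrapper, lookup key, and error handling are assumptions, only GetInstance, GetNodePopulater, and Parse come from the code itself.

// Sketch only: resolve a populater by op name and let it fill in the schema node.
// PopulateNode is a hypothetical wrapper, not part of this commit.
int PopulateNode(const CNodePtr &cnodePtr, schema::CNodeT *node) {
  auto *populater = AnfNodePopulaterRegistry::GetInstance()->GetNodePopulater("DepthwiseConv2D");
  if (populater == nullptr) {
    MS_LOG(ERROR) << "no populater registered for DepthwiseConv2D";  // assumed handling
    return -1;  // assumed error value
  }
  std::vector<schema::TensorT *> outputs;
  return populater->Parse(cnodePtr, node, &outputs);
}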

View File

@@ -0,0 +1,29 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_ANF_DEPTHWISECONV2D_PARSER_H
#define MINDSPORE_ANF_DEPTHWISECONV2D_PARSER_H
#include "src/common/anf_exporter/anf_populater/anf_node_populater.h"
#include <vector>
namespace mindspore::lite {
class AnfDepwiseconv2DPopulater : public AnfNodePopulater {
public:
AnfDepwiseconv2DPopulater() = default;
~AnfDepwiseconv2DPopulater() override = default;
int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector<schema::TensorT *> *outputs) override;
};
} // namespace mindspore::lite
#endif // MINDSPORE_ANF_DEPTHWISECONV2D_PARSER_H

View File

@@ -0,0 +1,35 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/common/anf_exporter/anf_populater/anf_dequant_populater.h"
#include <vector>
#include <string>
#include <memory>
#include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h"
#include "ir/func_graph.h"
#include "ir/primitive.h"
namespace mindspore::lite {
int mindspore::lite::AnfDequantPopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node,
std::vector<schema::TensorT *> *outputs) {
auto attr = std::make_unique<schema::OnnxInt8DequantizeT>();
node->nodeType = schema::NodeType_CNode;
node->primitive = std::make_unique<schema::PrimitiveT>();
node->primitive->value.type = schema::PrimitiveType_OnnxInt8Dequantize;
node->primitive->value.value = attr.release();
return 0;
}
AnfNodePopulaterRegistrar anfDequantParser("Dequant", new AnfDequantPopulater());
} // namespace mindspore::lite

View File

@@ -0,0 +1,29 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_ANF_DEQUANT_PARSER_H
#define MINDSPORE_ANF_DEQUANT_PARSER_H
#include "src/common/anf_exporter/anf_populater/anf_node_populater.h"
#include <vector>
namespace mindspore::lite {
class AnfDequantPopulater : public AnfNodePopulater {
public:
AnfDequantPopulater() = default;
~AnfDequantPopulater() override = default;
int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector<schema::TensorT *> *outputs) override;
};
} // namespace mindspore::lite
#endif // MINDSPORE_ANF_DEQUANT_PARSER_H

View File

@@ -32,4 +32,5 @@ int mindspore::lite::AnfMulPopulater::Parse(mindspore::CNodePtr cnodePtr, schema
return 0;
}
AnfNodePopulaterRegistrar anfMulParser("Mul", new AnfMulPopulater());
AnfNodePopulaterRegistrar anfMatMulParser("MatMul", new AnfMulPopulater());
} // namespace mindspore::lite

View File

@@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -26,12 +26,6 @@ namespace mindspore {
namespace lite {
AnfNodePopulaterRegistry *AnfNodePopulaterRegistry::GetInstance() {
static AnfNodePopulaterRegistry instance;
instance.SetNodePopulater("BiasAdd", new AnfBiasAddPopulater());
instance.SetNodePopulater("Conv2D", new AnfConvPopulater());
instance.SetNodePopulater("MatMul", new AnfMatmulPopulater());
instance.SetNodePopulater("MaxPool", new AnfPoolPopulater());
instance.SetNodePopulater("ReLU", new AnfActivationPopulater());
instance.SetNodePopulater("Flatten", new AnfFlattenPopulater());
return &instance;
}
AnfNodePopulater *AnfNodePopulaterRegistry::GetNodePopulater(const std::string &name) {
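The hard-coded SetNodePopulater calls removed from GetInstance are replaced by the file-scope AnfNodePopulaterRegistrar objects defined in the new populater source files. The registrar's definition is not part of this hunk; a minimal sketch of what it is assumed to do:

// Sketch only: the constructor registers the populater with the singleton registry,
// so static registrar objects self-register during program start-up.
class AnfNodePopulaterRegistrar {
 public:
  AnfNodePopulaterRegistrar(const std::string &name, AnfNodePopulater *populater) {
    AnfNodePopulaterRegistry::GetInstance()->SetNodePopulater(name, populater);
  }
};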

View File

@@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@@ -0,0 +1,35 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/common/anf_exporter/anf_populater/anf_quant_populater.h"
#include <vector>
#include <string>
#include <memory>
#include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h"
#include "ir/func_graph.h"
#include "ir/primitive.h"
namespace mindspore::lite {
int mindspore::lite::AnfQuantPopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node,
std::vector<schema::TensorT *> *outputs) {
auto attr = std::make_unique<schema::OnnxInt8QuantizeT>();
node->nodeType = schema::NodeType_CNode;
node->primitive = std::make_unique<schema::PrimitiveT>();
node->primitive->value.type = schema::PrimitiveType_OnnxInt8Quantize;
node->primitive->value.value = attr.release();
return 0;
}
AnfNodePopulaterRegistrar anfQuantParser("Quant", new AnfQuantPopulater());
} // namespace mindspore::lite

View File

@@ -0,0 +1,29 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_ANF_QUANT_PARSER_H
#define MINDSPORE_ANF_QUANT_PARSER_H
#include "src/common/anf_exporter/anf_populater/anf_node_populater.h"
#include <vector>
namespace mindspore::lite {
class AnfQuantPopulater : public AnfNodePopulater {
public:
AnfQuantPopulater() = default;
~AnfQuantPopulater() override = default;
int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector<schema::TensorT *> *outputs) override;
};
} // namespace mindspore::lite
#endif // MINDSPORE_ANF_QUANT_PARSER_H

View File

@@ -123,7 +123,7 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
MS_LOG(ERROR) << "new weightFormatPass failed";
return RET_ERROR;
}
// weightFormatPass->SetQuantType(ctx.quantType);
weightFormatPass->SetQuantType(ctx.quantType);
weightFormatPass->SetFmkType(ctx.fmk);
weightFormatOptimizer.AddPass(weightFormatPass);
status = weightFormatOptimizer.Run(graphDefT);
@@ -141,7 +141,7 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
MS_LOG(ERROR) << "new formatTransPass failed";
return RET_ERROR;
}
// formatTransPass->SetQuantType(ctx.quantType);
formatTransPass->SetQuantType(ctx.quantType);
formatTransPass->SetFmk(ctx.fmk);
formatTransOptimizer.AddPass(formatTransPass);
formatTransOptimizer.AddPass(new (std::nothrow) FormatTransFusionPass());

View File

@@ -191,7 +191,7 @@ NodeIter FormatTransPass::InsertFormatTransNode(schema::MetaGraphT *graph, NodeI
return InsertNode(graph, existNodeIter, place, inoutIdx, std::move(transNode), errorCode);
}
// void FormatTransPass::SetQuantType(QuantType quantType) { this->quantType = quantType; }
void FormatTransPass::SetQuantType(QuantType quantType) { this->quantType = quantType; }
void FormatTransPass::SetFmk(converter::FmkType fmkType) { this->fmkType = fmkType; }

View File

@@ -33,7 +33,7 @@ class FormatTransPass : public GraphPass {
STATUS Run(schema::MetaGraphT *graph) override;
// void SetQuantType(QuantType quantType);
void SetQuantType(QuantType quantType);
void SetFmk(converter::FmkType fmkType);

View File

@@ -43,7 +43,7 @@ int WeightFormatPass::Run(GraphNode *graphNode) {
return 0;
}
// void WeightFormatPass::SetQuantType(QuantType quantType) { this->quantType = quantType; }
void WeightFormatPass::SetQuantType(QuantType quantType) { this->quantType = quantType; }
void WeightFormatPass::SetFmkType(converter::FmkType fmkType) { this->fmkType = fmkType; }
@@ -223,11 +223,11 @@ int WeightFormatPass::QuantDataFormatTrans(GraphNode *graphNode) {
auto weightIndex = node->inputIndex.at(1);
MS_ASSERT(subGraph->allTensors.size() > weightIndex);
auto &weightTensor = subGraph->allTensors[weightIndex];
MS_ASSERT(weightTensor->dataType == -22); // DataType_DT_FLOAT
MS_ASSERT(weightTensor->dataType == kNumberTypeInt8);  // quantized weights are expected to be int8
STATUS status;
if (opType == schema::PrimitiveType_Conv2D) { // weight should be HWCK
if (weightTensor->format == schema::Format_KCHW) { // from caffe
if (weightTensor->dataType == -22) { // DataType_DT_UINT8) {
if (weightTensor->dataType == kNumberTypeInt8) {  // quantized int8 weights
MS_LOG(DEBUG) << "**weight tensor index: %d, format: %d, datatype: " << weightIndex << weightTensor->format
<< weightTensor->dataType;
status = TransFilterFormat<uint8_t>(weightTensor.get(), kKCHW2HWCK);
@@ -237,7 +237,7 @@ int WeightFormatPass::QuantDataFormatTrans(GraphNode *graphNode) {
status = TransFilterFormat<float>(weightTensor.get(), kKCHW2HWCK);
}
} else if (weightTensor->format == schema::Format_KHWC) { // from onnx
if (weightTensor->dataType == -22) { // DataType_DT_UINT8) {
if (weightTensor->dataType == kNumberTypeInt8) {  // quantized int8 weights
status = TransFilterFormat<uint8_t>(weightTensor.get(), kKHWC2HWCK);
} else {
status = TransFilterFormat<float>(weightTensor.get(), kKHWC2HWCK);
@@ -259,7 +259,7 @@ int WeightFormatPass::QuantDataFormatTrans(GraphNode *graphNode) {
}
} else if (opType == schema::PrimitiveType_DepthwiseConv2D) { // weight should be HWCK
if (weightTensor->format == schema::Format_CKHW) { // from caffe
if (weightTensor->dataType == -22) { // DataType_DT_UINT8) {
if (weightTensor->dataType == kNumberTypeInt8) {  // quantized int8 weights
MS_LOG(DEBUG) << "**weight tensor index: %d, format: %d, datatype: " << weightIndex, weightTensor->format,
weightTensor->dataType;
status = TransFilterFormat<uint8_t>(weightTensor.get(), kCKHW2HWCK);
@@ -272,11 +272,17 @@ int WeightFormatPass::QuantDataFormatTrans(GraphNode *graphNode) {
} else if (weightTensor->format == schema::Format_HWCK) { // from tf
return 0;
} else if (weightTensor->format == schema::Format_CHWK) { // from onnx
if (weightTensor->dataType == -22) { // DataType_DT_UINT8) {
if (weightTensor->dataType == kNumberTypeInt8) {  // quantized int8 weights
status = TransFilterFormat<uint8_t>(weightTensor.get(), kCHWK2HWCK);
} else {
status = TransFilterFormat<float>(weightTensor.get(), kCHWK2HWCK);
}
} else if (weightTensor->format == schema::Format_KCHW) {
if (weightTensor->dataType == kNumberTypeInt8) {  // quantized int8 weights
status = TransFilterFormat<uint8_t>(weightTensor.get(), kKCHW2HWCK);
} else {
status = TransFilterFormat<float>(weightTensor.get(), kKCHW2HWCK);
}
} else {
MS_LOG(ERROR) << "Unsupported weightTensor format: " << weightTensor->format;
return -1;
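The dataType checks in this pass now compare against kNumberTypeInt8 instead of the magic value -22, and a KCHW branch is added so depthwise weights arriving in that layout are also converted to HWCK. A condensed sketch of the repeated dtype dispatch is below; the helper name and the enum type name for the conversion constants are assumptions, while routing int8 data through the uint8_t instantiation mirrors the code above.

// Sketch only: pick the element type from the tensor's dataType and run the layout conversion.
static STATUS TransFilterByDataType(schema::TensorT *tensor, kTransFilterType convertType) {
  if (tensor->dataType == kNumberTypeInt8) {
    return TransFilterFormat<uint8_t>(tensor, convertType);  // quantized weights reuse the uint8_t path
  }
  return TransFilterFormat<float>(tensor, convertType);
}
// e.g. for a caffe depthwise weight stored as CKHW:
// status = TransFilterByDataType(weightTensor.get(), kCKHW2HWCK);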

View File

@@ -29,7 +29,7 @@ class WeightFormatPass : public NodePass {
~WeightFormatPass() override = default;
// void SetQuantType(QuantType quantType);
void SetQuantType(QuantType quantType);
void SetFmkType(converter::FmkType fmkType);