forked from OSSInnovation/mindspore
fix anf exporter
This commit is contained in:
parent
488d1904b6
commit
15d0365748
|
@ -1,7 +1,7 @@
|
|||
/**
|
||||
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
|
||||
*
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -27,6 +27,7 @@
|
|||
#include "mindspore/core/ir/primitive.h"
|
||||
#include "src/ir/primitive_t_value.h"
|
||||
#include "base/core_ops.h"
|
||||
#include "src/ir/tensor.h"
|
||||
|
||||
namespace mindspore::lite {
|
||||
schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &funcGraph) {
|
||||
|
@ -223,9 +224,28 @@ void AnfExporter::SetOpInputNode(const CNodePtr &cnode, schema::MetaGraphT *meta
|
|||
nodeIdMap[paramNode->fullname_with_scope()] = meta_graph->allTensors.size();
|
||||
fbNode->inputIndex.emplace_back(meta_graph->allTensors.size());
|
||||
meta_graph->allTensors.emplace_back(std::move(paramTensor));
|
||||
} else if (inputNode->isa<ValueNode>()) {
|
||||
auto valueNode = inputNode->cast<ValueNodePtr>();
|
||||
auto paramTensor = std::make_unique<schema::TensorT>();
|
||||
auto value = valueNode->value();
|
||||
if (value->isa<lite::tensor::Tensor>()) {
|
||||
auto valueAbstract = valueNode->abstract();
|
||||
auto abstractTensor = utils::cast<abstract::AbstractTensorPtr>(valueAbstract);
|
||||
auto typePtr = abstractTensor->element()->GetTypeTrack();
|
||||
paramTensor->dataType = typePtr->type_id();
|
||||
paramTensor->dims = utils::cast<abstract::ShapePtr>(abstractTensor->BuildShape())->shape();
|
||||
paramTensor->nodeType = schema::NodeType_ValueNode;
|
||||
auto data = value->cast<lite::tensor::TensorPtr>();
|
||||
paramTensor->data.resize(data->Size());
|
||||
memcpy(paramTensor->data.data(), data->Data(), data->Size());
|
||||
nodeIdMap[valueNode->fullname_with_scope()] = meta_graph->allTensors.size();
|
||||
fbNode->inputIndex.emplace_back(meta_graph->allTensors.size());
|
||||
meta_graph->allTensors.emplace_back(std::move(paramTensor));
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Not support value type , need add support.";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (isGraphInput) {
|
||||
graphInputNodes.emplace_back(fbNode);
|
||||
}
|
||||
|
|
|
@ -0,0 +1,36 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "src/common/anf_exporter/anf_populater/anf_depthwiseconv2d_populater.h"
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h"
|
||||
#include "ir/func_graph.h"
|
||||
#include "ir/primitive.h"
|
||||
|
||||
namespace mindspore::lite {
|
||||
int mindspore::lite::AnfDepwiseconv2DPopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node,
|
||||
std::vector<schema::TensorT *> *outputs) {
|
||||
auto attr = std::make_unique<schema::DepthwiseConv2DT>();
|
||||
node->nodeType = schema::NodeType_CNode;
|
||||
node->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
node->primitive->value.type = schema::PrimitiveType_DepthwiseConv2D;
|
||||
node->primitive->value.value = attr.release();
|
||||
return 0;
|
||||
}
|
||||
AnfNodePopulaterRegistrar anfdepthwise2dParser("DepthwiseConv2D", new AnfDepwiseconv2DPopulater());
|
||||
AnfNodePopulaterRegistrar anfdepthwise2dnativeParser("DepthwiseConv2dNative", new AnfDepwiseconv2DPopulater());
|
||||
} // namespace mindspore::lite
|
|
@ -0,0 +1,29 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_ANF_DEPTHWISECONV2D_PARSER_H
|
||||
#define MINDSPORE_ANF_DEPTHWISECONV2D_PARSER_H
|
||||
#include "src/common/anf_exporter/anf_populater/anf_node_populater.h"
|
||||
#include <vector>
|
||||
namespace mindspore::lite {
|
||||
class AnfDepwiseconv2DPopulater : public AnfNodePopulater {
|
||||
public:
|
||||
AnfDepwiseconv2DPopulater() = default;
|
||||
~AnfDepwiseconv2DPopulater() override = default;
|
||||
int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector<schema::TensorT *> *outputs) override;
|
||||
};
|
||||
} // namespace mindspore::lite
|
||||
|
||||
#endif // MINDSPORE_ANF_DEPTHWISECONV2D_PARSER_H
|
|
@ -0,0 +1,35 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "src/common/anf_exporter/anf_populater/anf_dequant_populater.h"
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h"
|
||||
#include "ir/func_graph.h"
|
||||
#include "ir/primitive.h"
|
||||
|
||||
namespace mindspore::lite {
|
||||
int mindspore::lite::AnfDequantPopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node,
|
||||
std::vector<schema::TensorT *> *outputs) {
|
||||
auto attr = std::make_unique<schema::OnnxInt8DequantizeT>();
|
||||
node->nodeType = schema::NodeType_CNode;
|
||||
node->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
node->primitive->value.type = schema::PrimitiveType_OnnxInt8Dequantize;
|
||||
node->primitive->value.value = attr.release();
|
||||
return 0;
|
||||
}
|
||||
AnfNodePopulaterRegistrar anfDequantParser("Dequant", new AnfDequantPopulater());
|
||||
} // namespace mindspore::lite
|
|
@ -0,0 +1,29 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_ANF_DEQUANT_PARSER_H
|
||||
#define MINDSPORE_ANF_DEQUANT_PARSER_H
|
||||
#include "src/common/anf_exporter/anf_populater/anf_node_populater.h"
|
||||
#include <vector>
|
||||
namespace mindspore::lite {
|
||||
class AnfDequantPopulater : public AnfNodePopulater {
|
||||
public:
|
||||
AnfDequantPopulater() = default;
|
||||
~AnfDequantPopulater() override = default;
|
||||
int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector<schema::TensorT *> *outputs) override;
|
||||
};
|
||||
} // namespace mindspore::lite
|
||||
|
||||
#endif // MINDSPORE_ANF_DEQUANT_PARSER_H
|
|
@ -32,4 +32,5 @@ int mindspore::lite::AnfMulPopulater::Parse(mindspore::CNodePtr cnodePtr, schema
|
|||
return 0;
|
||||
}
|
||||
AnfNodePopulaterRegistrar anfMulParser("Mul", new AnfMulPopulater());
|
||||
AnfNodePopulaterRegistrar anfMatMulParser("MatMul", new AnfMulPopulater());
|
||||
} // namespace mindspore::lite
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -26,12 +26,6 @@ namespace mindspore {
|
|||
namespace lite {
|
||||
AnfNodePopulaterRegistry *AnfNodePopulaterRegistry::GetInstance() {
|
||||
static AnfNodePopulaterRegistry instance;
|
||||
instance.SetNodePopulater("BiasAdd", new AnfBiasAddPopulater());
|
||||
instance.SetNodePopulater("Conv2D", new AnfConvPopulater());
|
||||
instance.SetNodePopulater("MatMul", new AnfMatmulPopulater());
|
||||
instance.SetNodePopulater("MaxPool", new AnfPoolPopulater());
|
||||
instance.SetNodePopulater("ReLU", new AnfActivationPopulater());
|
||||
instance.SetNodePopulater("Flatten", new AnfFlattenPopulater());
|
||||
return &instance;
|
||||
}
|
||||
AnfNodePopulater *AnfNodePopulaterRegistry::GetNodePopulater(const std::string &name) {
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
|
|
@ -0,0 +1,35 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "src/common/anf_exporter/anf_populater/anf_quant_populater.h"
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h"
|
||||
#include "ir/func_graph.h"
|
||||
#include "ir/primitive.h"
|
||||
|
||||
namespace mindspore::lite {
|
||||
int mindspore::lite::AnfQuantPopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node,
|
||||
std::vector<schema::TensorT *> *outputs) {
|
||||
auto attr = std::make_unique<schema::OnnxInt8QuantizeT>();
|
||||
node->nodeType = schema::NodeType_CNode;
|
||||
node->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
node->primitive->value.type = schema::PrimitiveType_OnnxInt8Quantize;
|
||||
node->primitive->value.value = attr.release();
|
||||
return 0;
|
||||
}
|
||||
AnfNodePopulaterRegistrar anfQuantParser("Quant", new AnfQuantPopulater());
|
||||
} // namespace mindspore::lite
|
|
@ -0,0 +1,29 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_ANF_QUANT_PARSER_H
|
||||
#define MINDSPORE_ANF_QUANT_PARSER_H
|
||||
#include "src/common/anf_exporter/anf_populater/anf_node_populater.h"
|
||||
#include <vector>
|
||||
namespace mindspore::lite {
|
||||
class AnfQuantPopulater : public AnfNodePopulater {
|
||||
public:
|
||||
AnfQuantPopulater() = default;
|
||||
~AnfQuantPopulater() override = default;
|
||||
int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector<schema::TensorT *> *outputs) override;
|
||||
};
|
||||
} // namespace mindspore::lite
|
||||
|
||||
#endif // MINDSPORE_ANF_QUANT_PARSER_H
|
|
@ -123,7 +123,7 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
|
|||
MS_LOG(ERROR) << "new weightFormatPass failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
// weightFormatPass->SetQuantType(ctx.quantType);
|
||||
weightFormatPass->SetQuantType(ctx.quantType);
|
||||
weightFormatPass->SetFmkType(ctx.fmk);
|
||||
weightFormatOptimizer.AddPass(weightFormatPass);
|
||||
status = weightFormatOptimizer.Run(graphDefT);
|
||||
|
@ -141,7 +141,7 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
|
|||
MS_LOG(ERROR) << "new formatTransPass failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
// formatTransPass->SetQuantType(ctx.quantType);
|
||||
formatTransPass->SetQuantType(ctx.quantType);
|
||||
formatTransPass->SetFmk(ctx.fmk);
|
||||
formatTransOptimizer.AddPass(formatTransPass);
|
||||
formatTransOptimizer.AddPass(new (std::nothrow) FormatTransFusionPass());
|
||||
|
|
|
@ -191,7 +191,7 @@ NodeIter FormatTransPass::InsertFormatTransNode(schema::MetaGraphT *graph, NodeI
|
|||
return InsertNode(graph, existNodeIter, place, inoutIdx, std::move(transNode), errorCode);
|
||||
}
|
||||
|
||||
// void FormatTransPass::SetQuantType(QuantType quantType) { this->quantType = quantType; }
|
||||
void FormatTransPass::SetQuantType(QuantType quantType) { this->quantType = quantType; }
|
||||
|
||||
void FormatTransPass::SetFmk(converter::FmkType fmkType) { this->fmkType = fmkType; }
|
||||
|
||||
|
|
|
@ -33,7 +33,7 @@ class FormatTransPass : public GraphPass {
|
|||
|
||||
STATUS Run(schema::MetaGraphT *graph) override;
|
||||
|
||||
// void SetQuantType(QuantType quantType);
|
||||
void SetQuantType(QuantType quantType);
|
||||
|
||||
void SetFmk(converter::FmkType fmkType);
|
||||
|
||||
|
|
|
@ -43,7 +43,7 @@ int WeightFormatPass::Run(GraphNode *graphNode) {
|
|||
return 0;
|
||||
}
|
||||
|
||||
// void WeightFormatPass::SetQuantType(QuantType quantType) { this->quantType = quantType; }
|
||||
void WeightFormatPass::SetQuantType(QuantType quantType) { this->quantType = quantType; }
|
||||
|
||||
void WeightFormatPass::SetFmkType(converter::FmkType fmkType) { this->fmkType = fmkType; }
|
||||
|
||||
|
@ -223,11 +223,11 @@ int WeightFormatPass::QuantDataFormatTrans(GraphNode *graphNode) {
|
|||
auto weightIndex = node->inputIndex.at(1);
|
||||
MS_ASSERT(subGraph->allTensors.size() > weightIndex);
|
||||
auto &weightTensor = subGraph->allTensors[weightIndex];
|
||||
MS_ASSERT(weightTensor->dataType == -22); // DataType_DT_FLOAT
|
||||
MS_ASSERT(weightTensor->dataType == kNumberTypeInt8); // DataType_DT_FLOAT
|
||||
STATUS status;
|
||||
if (opType == schema::PrimitiveType_Conv2D) { // weight should be HWCK
|
||||
if (weightTensor->format == schema::Format_KCHW) { // from caffe
|
||||
if (weightTensor->dataType == -22) { // DataType_DT_UINT8) {
|
||||
if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) {
|
||||
MS_LOG(DEBUG) << "**weight tensor index: %d, format: %d, datatype: " << weightIndex << weightTensor->format
|
||||
<< weightTensor->dataType;
|
||||
status = TransFilterFormat<uint8_t>(weightTensor.get(), kKCHW2HWCK);
|
||||
|
@ -237,7 +237,7 @@ int WeightFormatPass::QuantDataFormatTrans(GraphNode *graphNode) {
|
|||
status = TransFilterFormat<float>(weightTensor.get(), kKCHW2HWCK);
|
||||
}
|
||||
} else if (weightTensor->format == schema::Format_KHWC) { // from onnx
|
||||
if (weightTensor->dataType == -22) { // DataType_DT_UINT8) {
|
||||
if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) {
|
||||
status = TransFilterFormat<uint8_t>(weightTensor.get(), kKHWC2HWCK);
|
||||
} else {
|
||||
status = TransFilterFormat<float>(weightTensor.get(), kKHWC2HWCK);
|
||||
|
@ -259,7 +259,7 @@ int WeightFormatPass::QuantDataFormatTrans(GraphNode *graphNode) {
|
|||
}
|
||||
} else if (opType == schema::PrimitiveType_DepthwiseConv2D) { // weight should be HWCK
|
||||
if (weightTensor->format == schema::Format_CKHW) { // from caffe
|
||||
if (weightTensor->dataType == -22) { // DataType_DT_UINT8) {
|
||||
if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) {
|
||||
MS_LOG(DEBUG) << "**weight tensor index: %d, format: %d, datatype: " << weightIndex, weightTensor->format,
|
||||
weightTensor->dataType;
|
||||
status = TransFilterFormat<uint8_t>(weightTensor.get(), kCKHW2HWCK);
|
||||
|
@ -272,11 +272,17 @@ int WeightFormatPass::QuantDataFormatTrans(GraphNode *graphNode) {
|
|||
} else if (weightTensor->format == schema::Format_HWCK) { // from tf
|
||||
return 0;
|
||||
} else if (weightTensor->format == schema::Format_CHWK) { // from onnx
|
||||
if (weightTensor->dataType == -22) { // DataType_DT_UINT8) {
|
||||
if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) {
|
||||
status = TransFilterFormat<uint8_t>(weightTensor.get(), kCHWK2HWCK);
|
||||
} else {
|
||||
status = TransFilterFormat<float>(weightTensor.get(), kCHWK2HWCK);
|
||||
}
|
||||
} else if (weightTensor->format == schema::Format_KCHW) {
|
||||
if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) {
|
||||
status = TransFilterFormat<uint8_t>(weightTensor.get(), kKCHW2HWCK);
|
||||
} else {
|
||||
status = TransFilterFormat<float>(weightTensor.get(), kKCHW2HWCK);
|
||||
}
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Unsupported weightTensor format: " << weightTensor->format;
|
||||
return -1;
|
||||
|
|
|
@ -29,7 +29,7 @@ class WeightFormatPass : public NodePass {
|
|||
|
||||
~WeightFormatPass() override = default;
|
||||
|
||||
// void SetQuantType(QuantType quantType);
|
||||
void SetQuantType(QuantType quantType);
|
||||
|
||||
void SetFmkType(converter::FmkType fmkType);
|
||||
|
||||
|
|
Loading…
Reference in New Issue