anf exporter fixed

This commit is contained in:
kai00 2020-08-19 09:59:42 +08:00
parent 58523a41fe
commit 5ca7be576b
5 changed files with 37 additions and 5 deletions

View File

@ -387,7 +387,8 @@ void AnfExporter::SetOpOutputNode(const CNodePtr &cnode, const std::unique_ptr<s
} }
meta_graphT->allTensors.emplace_back(msTensor); meta_graphT->allTensors.emplace_back(msTensor);
if (IsPrimitiveCNode(cnode, schema::PrimitiveType_Conv2D) if (IsPrimitiveCNode(cnode, schema::PrimitiveType_Conv2D)
|| IsPrimitiveCNode(cnode, schema::PrimitiveType_DepthwiseConv2D)) { || IsPrimitiveCNode(cnode, schema::PrimitiveType_DepthwiseConv2D)
|| IsPrimitiveCNode(cnode, schema::PrimitiveType_FusedBatchNorm)) {
break; break;
} }
} }

View File

@ -29,7 +29,7 @@
namespace mindspore::lite { namespace mindspore::lite {
void AnfConvPopulater::PopulaterConv2DMultiGroup(const PrimitivePtr &prim, void AnfConvPopulater::PopulaterConv2DMultiGroup(const PrimitivePtr &prim,
const std::unique_ptr<schema::PrimitiveT> &primitive, const std::unique_ptr<schema::PrimitiveT> &primitive,
const int &group) { const int &group, const std::vector<AnfNodePtr> &inputs) {
auto attr = std::make_unique<schema::DepthwiseConv2DT>(); auto attr = std::make_unique<schema::DepthwiseConv2DT>();
auto format = GetValue<std::string>(prim->GetAttr("data_format")); auto format = GetValue<std::string>(prim->GetAttr("data_format"));
if (format == "NCHW") { if (format == "NCHW") {
@ -66,6 +66,28 @@ void AnfConvPopulater::PopulaterConv2DMultiGroup(const PrimitivePtr &prim,
attr->padMode = schema::PadMode_NOTSET; attr->padMode = schema::PadMode_NOTSET;
} }
int channel_mutiplier = 1;
if (prim->GetAttr("channel_mutiplier") != nullptr) {
channel_mutiplier = GetValue<int>(prim->GetAttr("channel_multiplier"));
}
attr->channelMultiplier = channel_mutiplier;
MS_ASSERT(inputs.size() == kAnfPopulaterTwo);
auto inputNode = inputs[kAnfPopulaterOne];
MS_ASSERT(inputNode != nullptr);
if (inputNode->isa<Parameter>()) {
auto paramNode = inputNode->cast<ParameterPtr>();
auto abstractBase = paramNode->abstract();
MS_ASSERT(abstractBase != nullptr);
if (utils::isa<abstract::AbstractTensorPtr>(abstractBase)) {
auto abstractTensor = utils::cast<abstract::AbstractTensorPtr>(abstractBase);
MS_ASSERT(abstractTensor != nullptr);
if (abstractTensor->format() == schema::Format_NCHW) {
abstractTensor->set_format(schema::Format_KCHW);
}
}
}
primitive->value.type = schema::PrimitiveType_DepthwiseConv2D; primitive->value.type = schema::PrimitiveType_DepthwiseConv2D;
primitive->value.value = attr.release(); primitive->value.value = attr.release();
} }
@ -214,7 +236,7 @@ int AnfConvPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primit
int group = GetValue<int>(prim->GetAttr("group")); int group = GetValue<int>(prim->GetAttr("group"));
if (group > 1) { if (group > 1) {
PopulaterConv2DMultiGroup(prim, primitive, group); PopulaterConv2DMultiGroup(prim, primitive, group, inputs);
} else { } else {
PopulaterConv2DSingleGroup(prim, primitive, group); PopulaterConv2DSingleGroup(prim, primitive, group);
} }

View File

@ -35,7 +35,7 @@ class AnfConvPopulater : public AnfNodePopulater {
private: private:
void PopulaterConv2DMultiGroup( void PopulaterConv2DMultiGroup(
const PrimitivePtr &prim, const PrimitivePtr &prim,
const std::unique_ptr<schema::PrimitiveT> &primitive, const int &group); const std::unique_ptr<schema::PrimitiveT> &primitive, const int &group, const std::vector<AnfNodePtr> &inputs);
void PopulaterConv2DSingleGroup( void PopulaterConv2DSingleGroup(
const PrimitivePtr &prim, const PrimitivePtr &prim,
const std::unique_ptr<schema::PrimitiveT> &primitive, const int &group); const std::unique_ptr<schema::PrimitiveT> &primitive, const int &group);

View File

@ -1129,7 +1129,12 @@ bool AnfImporterFromProtobuf::BuildReturnForFuncGraph(const FuncGraphPtr &output
auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, output_shape); auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, output_shape);
inputs.clear(); inputs.clear();
inputs.push_back(NewValueNode(prim::kPrimReturn)); auto primReturn = std::make_unique<schema::PrimitiveT>();
MS_ASSERT(primReturn != nullptr);
primReturn->value.type = schema::PrimitiveType_Return;
std::shared_ptr<PrimitiveTValue> primitiveTReturnValuePtr = std::make_shared<PrimitiveTValue>(primReturn.release());
MS_ASSERT(primitiveTReturnValuePtr != nullptr);
inputs.push_back(NewValueNode(primitiveTReturnValuePtr));
inputs.push_back(cnode_ptr); inputs.push_back(cnode_ptr);
auto return_node = outputFuncGraph->NewCNode(inputs); auto return_node = outputFuncGraph->NewCNode(inputs);
MS_EXCEPTION_IF_NULL(return_node); MS_EXCEPTION_IF_NULL(return_node);

View File

@ -18,6 +18,7 @@
#include "tools/common/converter_op_utils.h" #include "tools/common/converter_op_utils.h"
#include "utils/log_adapter.h" #include "utils/log_adapter.h"
#include "src/common/utils.h" #include "src/common/utils.h"
#include "tools/common/node_util.h"
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
@ -166,6 +167,9 @@ STATUS WeightFormatHardCodePass::HardCodeMS(const std::unique_ptr<CNodeT> &node,
if (opType == PrimitiveType_Conv2D) { if (opType == PrimitiveType_Conv2D) {
weightTensor->format = Format_KCHW; weightTensor->format = Format_KCHW;
} else if (opType == PrimitiveType_DepthwiseConv2D) { } else if (opType == PrimitiveType_DepthwiseConv2D) {
if (weightTensor->format == Format_KCHW) {
TransFilterFormat<float>(weightTensor.get(), kKCHW2CKHW);
}
weightTensor->format = Format_CKHW; weightTensor->format = Format_CKHW;
} else { } else {
MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(opType) << ", node: " << node->name; MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(opType) << ", node: " << node->name;