anf exporter fixed
This commit is contained in:
parent
58523a41fe
commit
5ca7be576b
|
@ -387,7 +387,8 @@ void AnfExporter::SetOpOutputNode(const CNodePtr &cnode, const std::unique_ptr<s
|
||||||
}
|
}
|
||||||
meta_graphT->allTensors.emplace_back(msTensor);
|
meta_graphT->allTensors.emplace_back(msTensor);
|
||||||
if (IsPrimitiveCNode(cnode, schema::PrimitiveType_Conv2D)
|
if (IsPrimitiveCNode(cnode, schema::PrimitiveType_Conv2D)
|
||||||
|| IsPrimitiveCNode(cnode, schema::PrimitiveType_DepthwiseConv2D)) {
|
|| IsPrimitiveCNode(cnode, schema::PrimitiveType_DepthwiseConv2D)
|
||||||
|
|| IsPrimitiveCNode(cnode, schema::PrimitiveType_FusedBatchNorm)) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -29,7 +29,7 @@
|
||||||
namespace mindspore::lite {
|
namespace mindspore::lite {
|
||||||
void AnfConvPopulater::PopulaterConv2DMultiGroup(const PrimitivePtr &prim,
|
void AnfConvPopulater::PopulaterConv2DMultiGroup(const PrimitivePtr &prim,
|
||||||
const std::unique_ptr<schema::PrimitiveT> &primitive,
|
const std::unique_ptr<schema::PrimitiveT> &primitive,
|
||||||
const int &group) {
|
const int &group, const std::vector<AnfNodePtr> &inputs) {
|
||||||
auto attr = std::make_unique<schema::DepthwiseConv2DT>();
|
auto attr = std::make_unique<schema::DepthwiseConv2DT>();
|
||||||
auto format = GetValue<std::string>(prim->GetAttr("data_format"));
|
auto format = GetValue<std::string>(prim->GetAttr("data_format"));
|
||||||
if (format == "NCHW") {
|
if (format == "NCHW") {
|
||||||
|
@ -66,6 +66,28 @@ void AnfConvPopulater::PopulaterConv2DMultiGroup(const PrimitivePtr &prim,
|
||||||
attr->padMode = schema::PadMode_NOTSET;
|
attr->padMode = schema::PadMode_NOTSET;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int channel_mutiplier = 1;
|
||||||
|
if (prim->GetAttr("channel_mutiplier") != nullptr) {
|
||||||
|
channel_mutiplier = GetValue<int>(prim->GetAttr("channel_multiplier"));
|
||||||
|
}
|
||||||
|
attr->channelMultiplier = channel_mutiplier;
|
||||||
|
|
||||||
|
MS_ASSERT(inputs.size() == kAnfPopulaterTwo);
|
||||||
|
auto inputNode = inputs[kAnfPopulaterOne];
|
||||||
|
MS_ASSERT(inputNode != nullptr);
|
||||||
|
if (inputNode->isa<Parameter>()) {
|
||||||
|
auto paramNode = inputNode->cast<ParameterPtr>();
|
||||||
|
auto abstractBase = paramNode->abstract();
|
||||||
|
MS_ASSERT(abstractBase != nullptr);
|
||||||
|
if (utils::isa<abstract::AbstractTensorPtr>(abstractBase)) {
|
||||||
|
auto abstractTensor = utils::cast<abstract::AbstractTensorPtr>(abstractBase);
|
||||||
|
MS_ASSERT(abstractTensor != nullptr);
|
||||||
|
if (abstractTensor->format() == schema::Format_NCHW) {
|
||||||
|
abstractTensor->set_format(schema::Format_KCHW);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
primitive->value.type = schema::PrimitiveType_DepthwiseConv2D;
|
primitive->value.type = schema::PrimitiveType_DepthwiseConv2D;
|
||||||
primitive->value.value = attr.release();
|
primitive->value.value = attr.release();
|
||||||
}
|
}
|
||||||
|
@ -214,7 +236,7 @@ int AnfConvPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primit
|
||||||
|
|
||||||
int group = GetValue<int>(prim->GetAttr("group"));
|
int group = GetValue<int>(prim->GetAttr("group"));
|
||||||
if (group > 1) {
|
if (group > 1) {
|
||||||
PopulaterConv2DMultiGroup(prim, primitive, group);
|
PopulaterConv2DMultiGroup(prim, primitive, group, inputs);
|
||||||
} else {
|
} else {
|
||||||
PopulaterConv2DSingleGroup(prim, primitive, group);
|
PopulaterConv2DSingleGroup(prim, primitive, group);
|
||||||
}
|
}
|
||||||
|
|
|
@ -35,7 +35,7 @@ class AnfConvPopulater : public AnfNodePopulater {
|
||||||
private:
|
private:
|
||||||
void PopulaterConv2DMultiGroup(
|
void PopulaterConv2DMultiGroup(
|
||||||
const PrimitivePtr &prim,
|
const PrimitivePtr &prim,
|
||||||
const std::unique_ptr<schema::PrimitiveT> &primitive, const int &group);
|
const std::unique_ptr<schema::PrimitiveT> &primitive, const int &group, const std::vector<AnfNodePtr> &inputs);
|
||||||
void PopulaterConv2DSingleGroup(
|
void PopulaterConv2DSingleGroup(
|
||||||
const PrimitivePtr &prim,
|
const PrimitivePtr &prim,
|
||||||
const std::unique_ptr<schema::PrimitiveT> &primitive, const int &group);
|
const std::unique_ptr<schema::PrimitiveT> &primitive, const int &group);
|
||||||
|
|
|
@ -1129,7 +1129,12 @@ bool AnfImporterFromProtobuf::BuildReturnForFuncGraph(const FuncGraphPtr &output
|
||||||
auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, output_shape);
|
auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, output_shape);
|
||||||
|
|
||||||
inputs.clear();
|
inputs.clear();
|
||||||
inputs.push_back(NewValueNode(prim::kPrimReturn));
|
auto primReturn = std::make_unique<schema::PrimitiveT>();
|
||||||
|
MS_ASSERT(primReturn != nullptr);
|
||||||
|
primReturn->value.type = schema::PrimitiveType_Return;
|
||||||
|
std::shared_ptr<PrimitiveTValue> primitiveTReturnValuePtr = std::make_shared<PrimitiveTValue>(primReturn.release());
|
||||||
|
MS_ASSERT(primitiveTReturnValuePtr != nullptr);
|
||||||
|
inputs.push_back(NewValueNode(primitiveTReturnValuePtr));
|
||||||
inputs.push_back(cnode_ptr);
|
inputs.push_back(cnode_ptr);
|
||||||
auto return_node = outputFuncGraph->NewCNode(inputs);
|
auto return_node = outputFuncGraph->NewCNode(inputs);
|
||||||
MS_EXCEPTION_IF_NULL(return_node);
|
MS_EXCEPTION_IF_NULL(return_node);
|
||||||
|
|
|
@ -18,6 +18,7 @@
|
||||||
#include "tools/common/converter_op_utils.h"
|
#include "tools/common/converter_op_utils.h"
|
||||||
#include "utils/log_adapter.h"
|
#include "utils/log_adapter.h"
|
||||||
#include "src/common/utils.h"
|
#include "src/common/utils.h"
|
||||||
|
#include "tools/common/node_util.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace lite {
|
namespace lite {
|
||||||
|
@ -166,6 +167,9 @@ STATUS WeightFormatHardCodePass::HardCodeMS(const std::unique_ptr<CNodeT> &node,
|
||||||
if (opType == PrimitiveType_Conv2D) {
|
if (opType == PrimitiveType_Conv2D) {
|
||||||
weightTensor->format = Format_KCHW;
|
weightTensor->format = Format_KCHW;
|
||||||
} else if (opType == PrimitiveType_DepthwiseConv2D) {
|
} else if (opType == PrimitiveType_DepthwiseConv2D) {
|
||||||
|
if (weightTensor->format == Format_KCHW) {
|
||||||
|
TransFilterFormat<float>(weightTensor.get(), kKCHW2CKHW);
|
||||||
|
}
|
||||||
weightTensor->format = Format_CKHW;
|
weightTensor->format = Format_CKHW;
|
||||||
} else {
|
} else {
|
||||||
MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(opType) << ", node: " << node->name;
|
MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(opType) << ", node: " << node->name;
|
||||||
|
|
Loading…
Reference in New Issue