!4361 fixed format trans

Merge pull request !4361 from wangchangkai/master
This commit is contained in:
mindspore-ci-bot 2020-08-13 14:47:19 +08:00 committed by Gitee
commit b59e072c91
4 changed files with 4 additions and 3 deletions

View File

@ -1168,6 +1168,7 @@ int AnfImporterFromProtobuf::Import() {
const onnx::GraphProto &graphBuild = onnx_model_->graph();
if (!BuildFuncGraph(dstGraph, graphBuild)) {
MS_LOG(ERROR) << "Build funcgraph failed!";
func_graph_ = nullptr;
return RET_ERROR;
}
func_graph_ = dstGraph;

View File

@ -96,7 +96,7 @@ static const std::vector<schema::PrimitiveType> nhwcOpList = {
schema::PrimitiveType_Conv2D, schema::PrimitiveType_DeConv2D,
schema::PrimitiveType_DepthwiseConv2D, schema::PrimitiveType_DeDepthwiseConv2D,
schema::PrimitiveType_Pooling, schema::PrimitiveType_Resize,
schema::PrimitiveType_BatchNorm};
schema::PrimitiveType_BatchNorm, schema::PrimitiveType_FusedBatchNorm};
static const std::vector<schema::PrimitiveType> fp32FullOpList = {
schema::PrimitiveType_Concat, schema::PrimitiveType_Add,

View File

@ -234,7 +234,7 @@ static STATUS TransFilterData(schema::TensorT *tensor, kTransFilterType type, in
buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
} else if (type == kCKHW2KHWC) {
p2Buff =
buf.get() + ((k * filterH * filterW * filterC) + (h * filterW * filterK) + (w * filterC) + (c));
buf.get() + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c));
} else {
p2Buff =
buf.get() + ((h * filterW * filterK * filterC) + (w * filterK * filterC) + (k * filterC) + (c));

View File

@ -346,7 +346,7 @@ int WeightFormatPass::NonQuantDataFormatTrans(GraphNode *graphNode) {
// todo(00445839): consider varible weight condition
}
} else if (opType == schema::PrimitiveType_DepthwiseConv2D) { // weight should be CKHW
if (graphNode->subGraph->fmkType == converter::FmkType_MS) {
if (fmkType == converter::FmkType_MS) {
weightTensor->format = schema::Format_CKHW;
}
if (weightTensor->format == schema::Format_CKHW) { // from caffe or onnx or ms