forked from mindspore-Ecosystem/mindspore
!4361 fixed format trans
Merge pull request !4361 from wangchangkai/master
This commit is contained in:
commit
b59e072c91
|
@ -1168,6 +1168,7 @@ int AnfImporterFromProtobuf::Import() {
|
|||
const onnx::GraphProto &graphBuild = onnx_model_->graph();
|
||||
if (!BuildFuncGraph(dstGraph, graphBuild)) {
|
||||
MS_LOG(ERROR) << "Build funcgraph failed!";
|
||||
func_graph_ = nullptr;
|
||||
return RET_ERROR;
|
||||
}
|
||||
func_graph_ = dstGraph;
|
||||
|
|
|
@ -96,7 +96,7 @@ static const std::vector<schema::PrimitiveType> nhwcOpList = {
|
|||
schema::PrimitiveType_Conv2D, schema::PrimitiveType_DeConv2D,
|
||||
schema::PrimitiveType_DepthwiseConv2D, schema::PrimitiveType_DeDepthwiseConv2D,
|
||||
schema::PrimitiveType_Pooling, schema::PrimitiveType_Resize,
|
||||
schema::PrimitiveType_BatchNorm};
|
||||
schema::PrimitiveType_BatchNorm, schema::PrimitiveType_FusedBatchNorm};
|
||||
|
||||
static const std::vector<schema::PrimitiveType> fp32FullOpList = {
|
||||
schema::PrimitiveType_Concat, schema::PrimitiveType_Add,
|
||||
|
|
|
@ -234,7 +234,7 @@ static STATUS TransFilterData(schema::TensorT *tensor, kTransFilterType type, in
|
|||
buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
|
||||
} else if (type == kCKHW2KHWC) {
|
||||
p2Buff =
|
||||
buf.get() + ((k * filterH * filterW * filterC) + (h * filterW * filterK) + (w * filterC) + (c));
|
||||
buf.get() + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c));
|
||||
} else {
|
||||
p2Buff =
|
||||
buf.get() + ((h * filterW * filterK * filterC) + (w * filterK * filterC) + (k * filterC) + (c));
|
||||
|
|
|
@ -346,7 +346,7 @@ int WeightFormatPass::NonQuantDataFormatTrans(GraphNode *graphNode) {
|
|||
// todo(00445839): consider varible weight condition
|
||||
}
|
||||
} else if (opType == schema::PrimitiveType_DepthwiseConv2D) { // weight should be CKHW
|
||||
if (graphNode->subGraph->fmkType == converter::FmkType_MS) {
|
||||
if (fmkType == converter::FmkType_MS) {
|
||||
weightTensor->format = schema::Format_CKHW;
|
||||
}
|
||||
if (weightTensor->format == schema::Format_CKHW) { // from caffe or onnx or ms
|
||||
|
|
Loading…
Reference in New Issue