!24063 npu InstanceNorm support nhwc input

Merge pull request !24063 from zhaozhenlong/lite/issue/npu-instance-nhwc
This commit is contained in:
i-robot 2021-09-24 09:53:41 +00:00 committed by Gitee
commit 91aefa5baf
1 changed files with 15 additions and 1 deletions

View File

@ -25,7 +25,7 @@ namespace mindspore {
std::set<mindspore::schema::PrimitiveType> nchw_nodes = {
schema::PrimitiveType_Conv2DFusion, schema::PrimitiveType_Conv2dTransposeFusion, schema::PrimitiveType_Resize,
schema::PrimitiveType_MaxPoolFusion, schema::PrimitiveType_AvgPoolFusion, schema::PrimitiveType_ScaleFusion,
schema::PrimitiveType_CropAndResize};
schema::PrimitiveType_CropAndResize, schema::PrimitiveType_InstanceNorm};
int NPUTransformPass::InsertPreNodes(NPUOp *op, std::vector<NPUOp *> *trans_ops) {
bool is_input_op = op->in_ops().empty();
@ -171,6 +171,20 @@ int NPUTransformPass::Run(NPUGraph *subgraph) {
i++;
continue;
}
if (op->type() == schema::PrimitiveType_InstanceNorm) {
if (op->inputs().empty()) {
MS_LOG(ERROR) << op->name() << " inputs empty";
return RET_ERROR;
}
if (op->inputs().front().format() == mindspore::Format::NCHW) {
i++;
continue;
}
if (op->inputs().front().format() != mindspore::Format::NHWC) {
MS_LOG(ERROR) << "instance_norm input[0] should be NHWC or NCHW";
return RET_ERROR;
}
}
if (op->type() == schema::PrimitiveType_Resize && op->inputs()[0].Shape()[1] > op->outputs()[0].Shape()[1]) {
i++;
continue;