From 91b29210bb6eb7b823b04e8021fd762100aaca46 Mon Sep 17 00:00:00 2001 From: zhaozhenlong Date: Fri, 24 Sep 2021 15:35:36 +0800 Subject: [PATCH] npu InstanceNorm support nhwc input --- .../src/delegate/npu/pass/npu_transform_pass.cc | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/mindspore/lite/src/delegate/npu/pass/npu_transform_pass.cc b/mindspore/lite/src/delegate/npu/pass/npu_transform_pass.cc index e87355c14c9..f9e4621a21a 100644 --- a/mindspore/lite/src/delegate/npu/pass/npu_transform_pass.cc +++ b/mindspore/lite/src/delegate/npu/pass/npu_transform_pass.cc @@ -25,7 +25,7 @@ namespace mindspore { std::set nchw_nodes = { schema::PrimitiveType_Conv2DFusion, schema::PrimitiveType_Conv2dTransposeFusion, schema::PrimitiveType_Resize, schema::PrimitiveType_MaxPoolFusion, schema::PrimitiveType_AvgPoolFusion, schema::PrimitiveType_ScaleFusion, - schema::PrimitiveType_CropAndResize}; + schema::PrimitiveType_CropAndResize, schema::PrimitiveType_InstanceNorm}; int NPUTransformPass::InsertPreNodes(NPUOp *op, std::vector *trans_ops) { bool is_input_op = op->in_ops().empty(); @@ -171,6 +171,20 @@ int NPUTransformPass::Run(NPUGraph *subgraph) { i++; continue; } + if (op->type() == schema::PrimitiveType_InstanceNorm) { + if (op->inputs().empty()) { + MS_LOG(ERROR) << op->name() << " inputs empty"; + return RET_ERROR; + } + if (op->inputs().front().format() == mindspore::Format::NCHW) { + i++; + continue; + } + if (op->inputs().front().format() != mindspore::Format::NHWC) { + MS_LOG(ERROR) << "instance_norm input[0] should be NHWC or NCHW"; + return RET_ERROR; + } + } if (op->type() == schema::PrimitiveType_Resize && op->inputs()[0].Shape()[1] > op->outputs()[0].Shape()[1]) { i++; continue;