forked from mindspore-Ecosystem/mindspore
!24063 npu InstanceNorm support nhwc input
Merge pull request !24063 from zhaozhenlong/lite/issue/npu-instance-nhwc
This commit is contained in:
commit
91aefa5baf
|
@ -25,7 +25,7 @@ namespace mindspore {
|
|||
std::set<mindspore::schema::PrimitiveType> nchw_nodes = {
|
||||
schema::PrimitiveType_Conv2DFusion, schema::PrimitiveType_Conv2dTransposeFusion, schema::PrimitiveType_Resize,
|
||||
schema::PrimitiveType_MaxPoolFusion, schema::PrimitiveType_AvgPoolFusion, schema::PrimitiveType_ScaleFusion,
|
||||
schema::PrimitiveType_CropAndResize};
|
||||
schema::PrimitiveType_CropAndResize, schema::PrimitiveType_InstanceNorm};
|
||||
|
||||
int NPUTransformPass::InsertPreNodes(NPUOp *op, std::vector<NPUOp *> *trans_ops) {
|
||||
bool is_input_op = op->in_ops().empty();
|
||||
|
@ -171,6 +171,20 @@ int NPUTransformPass::Run(NPUGraph *subgraph) {
|
|||
i++;
|
||||
continue;
|
||||
}
|
||||
if (op->type() == schema::PrimitiveType_InstanceNorm) {
|
||||
if (op->inputs().empty()) {
|
||||
MS_LOG(ERROR) << op->name() << " inputs empty";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (op->inputs().front().format() == mindspore::Format::NCHW) {
|
||||
i++;
|
||||
continue;
|
||||
}
|
||||
if (op->inputs().front().format() != mindspore::Format::NHWC) {
|
||||
MS_LOG(ERROR) << "instance_norm input[0] should be NHWC or NCHW";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
if (op->type() == schema::PrimitiveType_Resize && op->inputs()[0].Shape()[1] > op->outputs()[0].Shape()[1]) {
|
||||
i++;
|
||||
continue;
|
||||
|
|
Loading…
Reference in New Issue