!24063 npu InstanceNorm support nhwc input
Merge pull request !24063 from zhaozhenlong/lite/issue/npu-instance-nhwc
This commit is contained in:
commit
91aefa5baf
|
@ -25,7 +25,7 @@ namespace mindspore {
|
||||||
std::set<mindspore::schema::PrimitiveType> nchw_nodes = {
|
std::set<mindspore::schema::PrimitiveType> nchw_nodes = {
|
||||||
schema::PrimitiveType_Conv2DFusion, schema::PrimitiveType_Conv2dTransposeFusion, schema::PrimitiveType_Resize,
|
schema::PrimitiveType_Conv2DFusion, schema::PrimitiveType_Conv2dTransposeFusion, schema::PrimitiveType_Resize,
|
||||||
schema::PrimitiveType_MaxPoolFusion, schema::PrimitiveType_AvgPoolFusion, schema::PrimitiveType_ScaleFusion,
|
schema::PrimitiveType_MaxPoolFusion, schema::PrimitiveType_AvgPoolFusion, schema::PrimitiveType_ScaleFusion,
|
||||||
schema::PrimitiveType_CropAndResize};
|
schema::PrimitiveType_CropAndResize, schema::PrimitiveType_InstanceNorm};
|
||||||
|
|
||||||
int NPUTransformPass::InsertPreNodes(NPUOp *op, std::vector<NPUOp *> *trans_ops) {
|
int NPUTransformPass::InsertPreNodes(NPUOp *op, std::vector<NPUOp *> *trans_ops) {
|
||||||
bool is_input_op = op->in_ops().empty();
|
bool is_input_op = op->in_ops().empty();
|
||||||
|
@ -171,6 +171,20 @@ int NPUTransformPass::Run(NPUGraph *subgraph) {
|
||||||
i++;
|
i++;
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
if (op->type() == schema::PrimitiveType_InstanceNorm) {
|
||||||
|
if (op->inputs().empty()) {
|
||||||
|
MS_LOG(ERROR) << op->name() << " inputs empty";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
if (op->inputs().front().format() == mindspore::Format::NCHW) {
|
||||||
|
i++;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (op->inputs().front().format() != mindspore::Format::NHWC) {
|
||||||
|
MS_LOG(ERROR) << "instance_norm input[0] should be NHWC or NCHW";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
}
|
||||||
if (op->type() == schema::PrimitiveType_Resize && op->inputs()[0].Shape()[1] > op->outputs()[0].Shape()[1]) {
|
if (op->type() == schema::PrimitiveType_Resize && op->inputs()[0].Shape()[1] > op->outputs()[0].Shape()[1]) {
|
||||||
i++;
|
i++;
|
||||||
continue;
|
continue;
|
||||||
|
|
Loading…
Reference in New Issue