!10076 add MaxPool2D/1D unify_mindir

From: @yuchaojie
Reviewed-by: @kisnwang,@zhoufeng54
Signed-off-by: @zhoufeng54
This commit is contained in:
mindspore-ci-bot 2020-12-17 14:21:46 +08:00 committed by Gitee
commit be57909336
3 changed files with 3 additions and 20 deletions

View File

@@ -73,8 +73,6 @@ const AnfNodePtr MaxPoolWithArgmaxUnifyMindIR::Process(const FuncGraphPtr &graph
   auto shapes = {output_shape, argmax_shape};
   AnfAlgo::SetOutputInferTypeAndShape(types, shapes, maxpool_with_argmax.get());
-  auto manager = graph->manager();
-  MS_EXCEPTION_IF_NULL(manager);
   return maxpool_with_argmax;
 }
@@ -107,8 +105,6 @@ const AnfNodePtr MaxPoolGradWithArgmaxUnifyMindIR::Process(const FuncGraphPtr &g
   argmax_shape[2] = ksize[1] * ksize[2];
   AnfAlgo::SetOutputInferTypeAndShape({argmax_dtype}, {argmax_shape}, tuple_getitem0_anf.get());
-  auto manager = graph->manager();
-  MS_EXCEPTION_IF_NULL(manager);
   return maxpool_grad_with_argmax;
 }
 }  // namespace opt

View File

@@ -439,6 +439,7 @@ void AscendSession::UnifyMindIR(const KernelGraphPtr &graph) {
   }
   auto optimizer = std::make_shared<opt::GraphOptimizer>();
   auto unify_mindir_pm = std::make_shared<opt::PassManager>("unify_mindir_pm");
+  unify_mindir_pm->AddPass(std::make_shared<opt::MaxPool2MaxPoolWithArgmax>());
   unify_mindir_pm->AddPass(std::make_shared<opt::MaxPoolWithArgmaxUnifyMindIR>());
   unify_mindir_pm->AddPass(std::make_shared<opt::MaxPoolGradWithArgmaxUnifyMindIR>());
   unify_mindir_pm->AddPass(std::make_shared<opt::Conv2DUnifyMindIR>());

View File

@@ -124,15 +124,8 @@ class MaxPool2d(_PoolNd):
                                   strides=self.stride,
                                   padding=self.pad_mode,
                                   data_format=self.format)
-        self.max_pool_with_arg_max = P.MaxPoolWithArgmax(ksize=self.kernel_size,
-                                                         strides=self.stride,
-                                                         padding=self.pad_mode)
-        self.is_tbe = context.get_context("device_target") == "Ascend"

     def construct(self, x):
-        if self.is_tbe and self.training:
-            out = self.max_pool_with_arg_max(x)[0]
-        else:
-            out = self.max_pool(x)
+        out = self.max_pool(x)
         return out
@@ -198,21 +191,14 @@ class MaxPool1d(_PoolNd):
         self.max_pool = P.MaxPool(ksize=self.kernel_size,
                                   strides=self.stride,
                                   padding=self.pad_mode)
-        self.max_pool_with_arg_max = P.MaxPoolWithArgmax(ksize=self.kernel_size,
-                                                         strides=self.stride,
-                                                         padding=self.pad_mode)
         self.shape = F.shape
         self.reduce_mean = P.ReduceMean(keep_dims=True)
         self.expand = P.ExpandDims()
         self.squeeze = P.Squeeze(2)
-        self.is_tbe = context.get_context("device_target") == "Ascend"

     def construct(self, x):
         _shape_check(self.shape(x))
         x = self.expand(x, 2)
-        if self.is_tbe and self.training:
-            output = self.max_pool_with_arg_max(x)[0]
-        else:
-            output = self.max_pool(x)
+        output = self.max_pool(x)
         output = self.squeeze(output)
         return output