forked from mindspore-Ecosystem/mindspore
!10076 add MaxPool2D/1D unify_mindir
From: @yuchaojie Reviewed-by: @kisnwang,@zhoufeng54 Signed-off-by: @zhoufeng54
This commit is contained in:
commit
be57909336
|
@ -73,8 +73,6 @@ const AnfNodePtr MaxPoolWithArgmaxUnifyMindIR::Process(const FuncGraphPtr &graph
|
||||||
auto shapes = {output_shape, argmax_shape};
|
auto shapes = {output_shape, argmax_shape};
|
||||||
AnfAlgo::SetOutputInferTypeAndShape(types, shapes, maxpool_with_argmax.get());
|
AnfAlgo::SetOutputInferTypeAndShape(types, shapes, maxpool_with_argmax.get());
|
||||||
|
|
||||||
auto manager = graph->manager();
|
|
||||||
MS_EXCEPTION_IF_NULL(manager);
|
|
||||||
return maxpool_with_argmax;
|
return maxpool_with_argmax;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -107,8 +105,6 @@ const AnfNodePtr MaxPoolGradWithArgmaxUnifyMindIR::Process(const FuncGraphPtr &g
|
||||||
argmax_shape[2] = ksize[1] * ksize[2];
|
argmax_shape[2] = ksize[1] * ksize[2];
|
||||||
AnfAlgo::SetOutputInferTypeAndShape({argmax_dtype}, {argmax_shape}, tuple_getitem0_anf.get());
|
AnfAlgo::SetOutputInferTypeAndShape({argmax_dtype}, {argmax_shape}, tuple_getitem0_anf.get());
|
||||||
|
|
||||||
auto manager = graph->manager();
|
|
||||||
MS_EXCEPTION_IF_NULL(manager);
|
|
||||||
return maxpool_grad_with_argmax;
|
return maxpool_grad_with_argmax;
|
||||||
}
|
}
|
||||||
} // namespace opt
|
} // namespace opt
|
||||||
|
|
|
@ -439,6 +439,7 @@ void AscendSession::UnifyMindIR(const KernelGraphPtr &graph) {
|
||||||
}
|
}
|
||||||
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||||
auto unify_mindir_pm = std::make_shared<opt::PassManager>("unify_mindir_pm");
|
auto unify_mindir_pm = std::make_shared<opt::PassManager>("unify_mindir_pm");
|
||||||
|
unify_mindir_pm->AddPass(std::make_shared<opt::MaxPool2MaxPoolWithArgmax>());
|
||||||
unify_mindir_pm->AddPass(std::make_shared<opt::MaxPoolWithArgmaxUnifyMindIR>());
|
unify_mindir_pm->AddPass(std::make_shared<opt::MaxPoolWithArgmaxUnifyMindIR>());
|
||||||
unify_mindir_pm->AddPass(std::make_shared<opt::MaxPoolGradWithArgmaxUnifyMindIR>());
|
unify_mindir_pm->AddPass(std::make_shared<opt::MaxPoolGradWithArgmaxUnifyMindIR>());
|
||||||
unify_mindir_pm->AddPass(std::make_shared<opt::Conv2DUnifyMindIR>());
|
unify_mindir_pm->AddPass(std::make_shared<opt::Conv2DUnifyMindIR>());
|
||||||
|
|
|
@ -124,15 +124,8 @@ class MaxPool2d(_PoolNd):
|
||||||
strides=self.stride,
|
strides=self.stride,
|
||||||
padding=self.pad_mode,
|
padding=self.pad_mode,
|
||||||
data_format=self.format)
|
data_format=self.format)
|
||||||
self.max_pool_with_arg_max = P.MaxPoolWithArgmax(ksize=self.kernel_size,
|
|
||||||
strides=self.stride,
|
|
||||||
padding=self.pad_mode)
|
|
||||||
self.is_tbe = context.get_context("device_target") == "Ascend"
|
|
||||||
|
|
||||||
def construct(self, x):
|
def construct(self, x):
|
||||||
if self.is_tbe and self.training:
|
|
||||||
out = self.max_pool_with_arg_max(x)[0]
|
|
||||||
else:
|
|
||||||
out = self.max_pool(x)
|
out = self.max_pool(x)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
@ -198,21 +191,14 @@ class MaxPool1d(_PoolNd):
|
||||||
self.max_pool = P.MaxPool(ksize=self.kernel_size,
|
self.max_pool = P.MaxPool(ksize=self.kernel_size,
|
||||||
strides=self.stride,
|
strides=self.stride,
|
||||||
padding=self.pad_mode)
|
padding=self.pad_mode)
|
||||||
self.max_pool_with_arg_max = P.MaxPoolWithArgmax(ksize=self.kernel_size,
|
|
||||||
strides=self.stride,
|
|
||||||
padding=self.pad_mode)
|
|
||||||
self.shape = F.shape
|
self.shape = F.shape
|
||||||
self.reduce_mean = P.ReduceMean(keep_dims=True)
|
self.reduce_mean = P.ReduceMean(keep_dims=True)
|
||||||
self.expand = P.ExpandDims()
|
self.expand = P.ExpandDims()
|
||||||
self.squeeze = P.Squeeze(2)
|
self.squeeze = P.Squeeze(2)
|
||||||
self.is_tbe = context.get_context("device_target") == "Ascend"
|
|
||||||
|
|
||||||
def construct(self, x):
|
def construct(self, x):
|
||||||
_shape_check(self.shape(x))
|
_shape_check(self.shape(x))
|
||||||
x = self.expand(x, 2)
|
x = self.expand(x, 2)
|
||||||
if self.is_tbe and self.training:
|
|
||||||
output = self.max_pool_with_arg_max(x)[0]
|
|
||||||
else:
|
|
||||||
output = self.max_pool(x)
|
output = self.max_pool(x)
|
||||||
output = self.squeeze(output)
|
output = self.squeeze(output)
|
||||||
return output
|
return output
|
||||||
|
|
Loading…
Reference in New Issue