!9596 add infer function for fused sparse adam

From: @liubuyu
Reviewed-by: @kisnwang,@zhoufeng54
Signed-off-by: @zhoufeng54
This commit is contained in:
mindspore-ci-bot 2020-12-14 22:39:07 +08:00 committed by Gitee
commit d88aa05859
4 changed files with 19 additions and 0 deletions

View File

@ -49,6 +49,8 @@ AbstractBasePtr InferImplPoolingGrad(const AnalysisEnginePtr &, const PrimitiveP
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplFusedBatchNorm(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplFusedSparseAdam(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplFusedBatchNormGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplBatchNormGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,

View File

@ -252,6 +252,21 @@ AbstractBasePtr InferImplReluGrad(const AnalysisEnginePtr &, const PrimitivePtr
return out->Broaden();
}
AbstractBasePtr InferImplFusedSparseAdam(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// the output is useless, so we dont have to focus on the output shape
MS_EXCEPTION_IF_NULL(args_spec_list[1]);
MS_EXCEPTION_IF_NULL(args_spec_list[2]);
MS_EXCEPTION_IF_NULL(args_spec_list[3]);
auto dx = args_spec_list[1]->Broaden();
auto dscale = args_spec_list[2]->Broaden();
auto dbias = args_spec_list[3]->Broaden();
AbstractBasePtrList rets = {dx, dscale, dbias};
return std::make_shared<AbstractTuple>(rets);
}
AbstractBasePtr InferImplConv2DBackpropInput(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: three tensors(doutput, input, filters).

View File

@ -101,6 +101,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
{prim::kPrimPooling, {InferImplPooling, true}},
{prim::kPrimPoolingGrad, {InferImplPoolingGrad, true}},
{prim::kPrimFusedBatchNorm, {InferImplFusedBatchNorm, true}},
{prim::kPrimFusedSparseAdam, {InferImplFusedSparseAdam, true}},
{prim::kPrimFusedBatchNormGrad, {InferImplFusedBatchNormGrad, true}},
{prim::kPrimBatchNormGrad, {InferImplBatchNormGrad, true}},
{prim::kPrimReluGrad, {InferImplReluGrad, true}},

View File

@ -140,6 +140,7 @@ inline const PrimitivePtr kPrimApplyCenteredRMSProp = std::make_shared<Primitive
inline const PrimitivePtr kPrimAvgPool = std::make_shared<Primitive>("AvgPool");
inline const PrimitivePtr kPrimAvgPoolGrad = std::make_shared<Primitive>("AvgPoolGrad");
inline const PrimitivePtr kPrimAvgPoolGradVm = std::make_shared<Primitive>("AvgPoolGradVm");
inline const PrimitivePtr kPrimFusedSparseAdam = std::make_shared<Primitive>("FusedSparseAdam");
inline const PrimitivePtr kPrimFusedBatchNorm = std::make_shared<Primitive>("FusedBatchNorm");
inline const PrimitivePtr kPrimFusedBatchNormEx = std::make_shared<Primitive>("FusedBatchNormEx");
inline const PrimitivePtr kPrimConv2D = std::make_shared<Primitive>("Conv2D");