forked from mindspore-Ecosystem/mindspore
!9596 add infer function for fused sparse adam
From: @liubuyu Reviewed-by: @kisnwang,@zhoufeng54 Signed-off-by: @zhoufeng54
This commit is contained in:
commit
d88aa05859
|
@ -49,6 +49,8 @@ AbstractBasePtr InferImplPoolingGrad(const AnalysisEnginePtr &, const PrimitiveP
|
|||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplFusedBatchNorm(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplFusedSparseAdam(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplFusedBatchNormGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplBatchNormGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
|
|
|
@ -252,6 +252,21 @@ AbstractBasePtr InferImplReluGrad(const AnalysisEnginePtr &, const PrimitivePtr
|
|||
return out->Broaden();
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplFusedSparseAdam(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// the output is useless, so we dont have to focus on the output shape
|
||||
MS_EXCEPTION_IF_NULL(args_spec_list[1]);
|
||||
MS_EXCEPTION_IF_NULL(args_spec_list[2]);
|
||||
MS_EXCEPTION_IF_NULL(args_spec_list[3]);
|
||||
|
||||
auto dx = args_spec_list[1]->Broaden();
|
||||
auto dscale = args_spec_list[2]->Broaden();
|
||||
auto dbias = args_spec_list[3]->Broaden();
|
||||
|
||||
AbstractBasePtrList rets = {dx, dscale, dbias};
|
||||
return std::make_shared<AbstractTuple>(rets);
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplConv2DBackpropInput(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: three tensors(doutput, input, filters).
|
||||
|
|
|
@ -101,6 +101,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
|
|||
{prim::kPrimPooling, {InferImplPooling, true}},
|
||||
{prim::kPrimPoolingGrad, {InferImplPoolingGrad, true}},
|
||||
{prim::kPrimFusedBatchNorm, {InferImplFusedBatchNorm, true}},
|
||||
{prim::kPrimFusedSparseAdam, {InferImplFusedSparseAdam, true}},
|
||||
{prim::kPrimFusedBatchNormGrad, {InferImplFusedBatchNormGrad, true}},
|
||||
{prim::kPrimBatchNormGrad, {InferImplBatchNormGrad, true}},
|
||||
{prim::kPrimReluGrad, {InferImplReluGrad, true}},
|
||||
|
|
|
@ -140,6 +140,7 @@ inline const PrimitivePtr kPrimApplyCenteredRMSProp = std::make_shared<Primitive
|
|||
inline const PrimitivePtr kPrimAvgPool = std::make_shared<Primitive>("AvgPool");
|
||||
inline const PrimitivePtr kPrimAvgPoolGrad = std::make_shared<Primitive>("AvgPoolGrad");
|
||||
inline const PrimitivePtr kPrimAvgPoolGradVm = std::make_shared<Primitive>("AvgPoolGradVm");
|
||||
inline const PrimitivePtr kPrimFusedSparseAdam = std::make_shared<Primitive>("FusedSparseAdam");
|
||||
inline const PrimitivePtr kPrimFusedBatchNorm = std::make_shared<Primitive>("FusedBatchNorm");
|
||||
inline const PrimitivePtr kPrimFusedBatchNormEx = std::make_shared<Primitive>("FusedBatchNormEx");
|
||||
inline const PrimitivePtr kPrimConv2D = std::make_shared<Primitive>("Conv2D");
|
||||
|
|
Loading…
Reference in New Issue