!11122 add adapter of Asin, Asinh, Atan etc. operators for graphengine.

From: @wangshuide2020
Reviewed-by: @liangchenghui,@c_34
Signed-off-by: @liangchenghui
This commit is contained in:
mindspore-ci-bot 2021-01-11 09:56:57 +08:00 committed by Gitee
commit 59758b9a28
7 changed files with 259 additions and 0 deletions

View File

@ -31,6 +31,7 @@ constexpr const char kNameSimpleMean[] = "SimpleMean";
constexpr const char kNameSimpleMeanGrad[] = "SimpleMeanGrad"; constexpr const char kNameSimpleMeanGrad[] = "SimpleMeanGrad";
constexpr const char kNameAllReduce[] = "AllReduce"; constexpr const char kNameAllReduce[] = "AllReduce";
constexpr const char kNameBroadcast[] = "Broadcast"; constexpr const char kNameBroadcast[] = "Broadcast";
constexpr const char kNameBroadcastTo[] = "BroadcastTo";
constexpr const char kNameAllgather[] = "AllGather"; constexpr const char kNameAllgather[] = "AllGather";
constexpr const char kNameReduceScatter[] = "ReduceScatter"; constexpr const char kNameReduceScatter[] = "ReduceScatter";
constexpr const char kNameReduceSum[] = "ReduceSum"; constexpr const char kNameReduceSum[] = "ReduceSum";
@ -52,6 +53,7 @@ constexpr const char kNameLogicalOr[] = "LogicalOr";
constexpr const char kNameExp[] = "Exp"; constexpr const char kNameExp[] = "Exp";
constexpr const char kNameLessEqual[] = "LessEqual"; constexpr const char kNameLessEqual[] = "LessEqual";
constexpr const char kNameGreaterEqual[] = "GreaterEqual"; constexpr const char kNameGreaterEqual[] = "GreaterEqual";
constexpr const char kNameApproximateEqual[] = "ApproximateEqual";
constexpr const char kNameEqual[] = "Equal"; constexpr const char kNameEqual[] = "Equal";
constexpr const char kNameNotEqual[] = "NotEqual"; constexpr const char kNameNotEqual[] = "NotEqual";
constexpr const char kNameFlattenGrad[] = "FlattenGrad"; constexpr const char kNameFlattenGrad[] = "FlattenGrad";
@ -75,6 +77,12 @@ constexpr const char kNameConfusionMatrix[] = "ConfusionMatrix";
constexpr const char kNameResizeNearestNeighborD[] = "ResizeNearestNeighbor"; constexpr const char kNameResizeNearestNeighborD[] = "ResizeNearestNeighbor";
constexpr const char kNameResizeNearestNeighborGrad[] = "ResizeNearestNeighborGrad"; constexpr const char kNameResizeNearestNeighborGrad[] = "ResizeNearestNeighborGrad";
constexpr const char kNameApplyAdam[] = "Adam"; constexpr const char kNameApplyAdam[] = "Adam";
constexpr const char kNameApplyAdagrad[] = "ApplyAdagrad";
constexpr const char kNameApplyAdadelta[] = "ApplyAdadelta";
constexpr const char kNameApplyAdaMax[] = "ApplyAdaMax";
constexpr const char kNameApplyGradientDescent[] = "ApplyGradientDescent";
constexpr const char kNameApplyPowerSign[] = "ApplyPowerSign";
constexpr const char kNameApplyProximalGradientDescent[] = "ApplyProximalGradientDescent";
constexpr const char kNameExtractImagePatches[] = "ExtractImagePatches"; constexpr const char kNameExtractImagePatches[] = "ExtractImagePatches";
constexpr const char kNameReLU6[] = "ReLU6"; constexpr const char kNameReLU6[] = "ReLU6";
constexpr const char kNameReLU6Grad[] = "ReLU6Grad"; constexpr const char kNameReLU6Grad[] = "ReLU6Grad";
@ -116,13 +124,26 @@ constexpr const char kNameNPUAllocFloatStatus[] = "NPUAllocFloatStatus";
constexpr const char kNameNPUClearFloatStatus[] = "NPUClearFloatStatus"; constexpr const char kNameNPUClearFloatStatus[] = "NPUClearFloatStatus";
constexpr const char kNameReshape[] = "Reshape"; constexpr const char kNameReshape[] = "Reshape";
constexpr const char kNameTransShape[] = "TransShape"; constexpr const char kNameTransShape[] = "TransShape";
constexpr const char kNameDiv[] = "Div";
constexpr const char kNameRealDiv[] = "RealDiv"; constexpr const char kNameRealDiv[] = "RealDiv";
constexpr const char kNameBitwiseAnd[] = "BitwiseAnd";
constexpr const char kNameBitwiseOr[] = "BitwiseOr";
constexpr const char kNameBitwiseXor[] = "BitwiseXor";
constexpr const char kNameCeil[] = "Ceil";
constexpr const char kNameCosineEmbeddingLoss[] = "CosineEmbeddingLoss";
constexpr const char kNameXdivy[] = "Xdivy";
constexpr const char kNameTile[] = "Tile"; constexpr const char kNameTile[] = "Tile";
constexpr const char kNameCos[] = "Cos"; constexpr const char kNameCos[] = "Cos";
constexpr const char kNameCosh[] = "Cosh";
constexpr const char kNameACos[] = "ACos"; constexpr const char kNameACos[] = "ACos";
constexpr const char kNameACosGrad[] = "ACosGrad"; constexpr const char kNameACosGrad[] = "ACosGrad";
constexpr const char kNameFloorDiv[] = "FloorDiv"; constexpr const char kNameFloorDiv[] = "FloorDiv";
constexpr const char kNameSin[] = "Sin"; constexpr const char kNameSin[] = "Sin";
constexpr const char kNameSinh[] = "Sinh";
constexpr const char kNameAsin[] = "Asin";
constexpr const char kNameAsinGrad[] = "AsinGrad";
constexpr const char kNameAsinh[] = "Asinh";
constexpr const char kNameAsinhGrad[] = "AsinhGrad";
constexpr const char kNamePrelu[] = "PReLU"; constexpr const char kNamePrelu[] = "PReLU";
constexpr const char kNamePreluGrad[] = "PReLUGrad"; constexpr const char kNamePreluGrad[] = "PReLUGrad";
constexpr const char kNameSigmoid[] = "Sigmoid"; constexpr const char kNameSigmoid[] = "Sigmoid";
@ -180,6 +201,10 @@ constexpr const char kNameDiag[] = "Diag";
constexpr const char kNameDiagPart[] = "DiagPart"; constexpr const char kNameDiagPart[] = "DiagPart";
constexpr const char kNameSpaceToBatch[] = "SpaceToBatch"; constexpr const char kNameSpaceToBatch[] = "SpaceToBatch";
constexpr const char kNameBatchToSpace[] = "BatchToSpace"; constexpr const char kNameBatchToSpace[] = "BatchToSpace";
constexpr const char kNameTan[] = "Tan";
constexpr const char kNameAtan[] = "Atan";
constexpr const char kNameAtanGrad[] = "AtanGrad";
constexpr const char kNameAtanh[] = "Atanh";
constexpr const char kNameAtan2[] = "Atan2"; constexpr const char kNameAtan2[] = "Atan2";
constexpr const char kNameApplyRMSProp[] = "ApplyRMSProp"; constexpr const char kNameApplyRMSProp[] = "ApplyRMSProp";
constexpr const char kNameApplyCenteredRMSProp[] = "ApplyCenteredRMSProp"; constexpr const char kNameApplyCenteredRMSProp[] = "ApplyCenteredRMSProp";

View File

@ -58,6 +58,12 @@ ATTR_MAP(Cos) = EMPTY_ATTR_MAP;
OUTPUT_MAP(Cos) = {{0, OUTPUT_DESC(y)}}; OUTPUT_MAP(Cos) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(Cos, kNameCos, ADPT_DESC(Cos)) REG_ADPT_DESC(Cos, kNameCos, ADPT_DESC(Cos))
// Cosh
INPUT_MAP(Cosh) = {{1, INPUT_DESC(x)}};
ATTR_MAP(Cosh) = EMPTY_ATTR_MAP;
OUTPUT_MAP(Cosh) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(Cosh, kNameCosh, ADPT_DESC(Cosh))
// Acos // Acos
INPUT_MAP(Acos) = {{1, INPUT_DESC(x)}}; INPUT_MAP(Acos) = {{1, INPUT_DESC(x)}};
ATTR_MAP(Acos) = EMPTY_ATTR_MAP; ATTR_MAP(Acos) = EMPTY_ATTR_MAP;
@ -82,6 +88,12 @@ ATTR_MAP(AcoshGrad) = EMPTY_ATTR_MAP;
OUTPUT_MAP(AcoshGrad) = {{0, OUTPUT_DESC(z)}}; OUTPUT_MAP(AcoshGrad) = {{0, OUTPUT_DESC(z)}};
REG_ADPT_DESC(AcoshGrad, kNameAcoshGrad, ADPT_DESC(AcoshGrad)) REG_ADPT_DESC(AcoshGrad, kNameAcoshGrad, ADPT_DESC(AcoshGrad))
// Div
INPUT_MAP(Div) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}};
ATTR_MAP(Div) = EMPTY_ATTR_MAP;
OUTPUT_MAP(Div) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(Div, kNameDiv, ADPT_DESC(Div))
// Floor // Floor
INPUT_MAP(Floor) = {{1, INPUT_DESC(x)}}; INPUT_MAP(Floor) = {{1, INPUT_DESC(x)}};
ATTR_MAP(Floor) = EMPTY_ATTR_MAP; ATTR_MAP(Floor) = EMPTY_ATTR_MAP;
@ -106,6 +118,73 @@ ATTR_MAP(Sin) = EMPTY_ATTR_MAP;
OUTPUT_MAP(Sin) = {{0, OUTPUT_DESC(y)}}; OUTPUT_MAP(Sin) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(Sin, kNameSin, ADPT_DESC(Sin)) REG_ADPT_DESC(Sin, kNameSin, ADPT_DESC(Sin))
// Sinh
INPUT_MAP(Sinh) = {{1, INPUT_DESC(x)}};
ATTR_MAP(Sinh) = EMPTY_ATTR_MAP;
OUTPUT_MAP(Sinh) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(Sinh, kNameSinh, ADPT_DESC(Sinh))
// Asin
INPUT_MAP(Asin) = {{1, INPUT_DESC(x)}};
ATTR_MAP(Asin) = EMPTY_ATTR_MAP;
OUTPUT_MAP(Asin) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(Asin, kNameAsin, ADPT_DESC(Asin))
// AsinGrad
INPUT_MAP(AsinGrad) = {{1, INPUT_DESC(y)}, {2, INPUT_DESC(dy)}};
ATTR_MAP(AsinGrad) = EMPTY_ATTR_MAP;
OUTPUT_MAP(AsinGrad) = {{0, OUTPUT_DESC(z)}};
REG_ADPT_DESC(AsinGrad, kNameAsinGrad, ADPT_DESC(AsinGrad))
// Asinh
INPUT_MAP(Asinh) = {{1, INPUT_DESC(x)}};
ATTR_MAP(Asinh) = EMPTY_ATTR_MAP;
OUTPUT_MAP(Asinh) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(Asinh, kNameAsinh, ADPT_DESC(Asinh))
// AsinhGrad
INPUT_MAP(AsinhGrad) = {{1, INPUT_DESC(y)}, {2, INPUT_DESC(dy)}};
ATTR_MAP(AsinhGrad) = EMPTY_ATTR_MAP;
OUTPUT_MAP(AsinhGrad) = {{0, OUTPUT_DESC(z)}};
REG_ADPT_DESC(AsinhGrad, kNameAsinhGrad, ADPT_DESC(AsinhGrad))
// BitwiseAnd
INPUT_MAP(BitwiseAnd) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}};
ATTR_MAP(BitwiseAnd) = EMPTY_ATTR_MAP;
OUTPUT_MAP(BitwiseAnd) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(BitwiseAnd, kNameBitwiseAnd, ADPT_DESC(BitwiseAnd))
// BitwiseOr
INPUT_MAP(BitwiseOr) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}};
ATTR_MAP(BitwiseOr) = EMPTY_ATTR_MAP;
OUTPUT_MAP(BitwiseOr) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(BitwiseOr, kNameBitwiseOr, ADPT_DESC(BitwiseOr))
// BitwiseXor
INPUT_MAP(BitwiseXor) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}};
ATTR_MAP(BitwiseXor) = EMPTY_ATTR_MAP;
OUTPUT_MAP(BitwiseXor) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(BitwiseXor, kNameBitwiseXor, ADPT_DESC(BitwiseXor))
// Ceil
INPUT_MAP(Ceil) = {{1, INPUT_DESC(x)}};
ATTR_MAP(Ceil) = EMPTY_ATTR_MAP;
OUTPUT_MAP(Ceil) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(Ceil, kNameCeil, ADPT_DESC(Ceil))
// CosineEmbeddingLoss
INPUT_MAP(CosineEmbeddingLoss) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}, {3, INPUT_DESC(target)}};
ATTR_MAP(CosineEmbeddingLoss) = {{"margin", ATTR_DESC(margin, AnyTraits<float>())},
{"reduction", ATTR_DESC(reduction, AnyTraits<std::string>())}};
OUTPUT_MAP(CosineEmbeddingLoss) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(CosineEmbeddingLoss, kNameCosineEmbeddingLoss, ADPT_DESC(CosineEmbeddingLoss))
// Xdivy
INPUT_MAP(Xdivy) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}};
ATTR_MAP(Xdivy) = EMPTY_ATTR_MAP;
OUTPUT_MAP(Xdivy) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(Xdivy, kNameXdivy, ADPT_DESC(Xdivy))
// Exp // Exp
INPUT_MAP(Exp) = {{1, INPUT_DESC(x)}}; INPUT_MAP(Exp) = {{1, INPUT_DESC(x)}};
ATTR_MAP(Exp) = EMPTY_ATTR_MAP; ATTR_MAP(Exp) = EMPTY_ATTR_MAP;
@ -291,6 +370,12 @@ ATTR_MAP(Equal) = EMPTY_ATTR_MAP;
OUTPUT_MAP(Equal) = {{0, OUTPUT_DESC(y)}}; OUTPUT_MAP(Equal) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(Equal, kNameEqual, ADPT_DESC(Equal)) REG_ADPT_DESC(Equal, kNameEqual, ADPT_DESC(Equal))
// ApproximateEqual
INPUT_MAP(ApproximateEqual) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}};
ATTR_MAP(ApproximateEqual) = {{"tolerance", ATTR_DESC(tolerance, AnyTraits<float>())}};
OUTPUT_MAP(ApproximateEqual) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(ApproximateEqual, kNameApproximateEqual, ADPT_DESC(ApproximateEqual))
// NotEqual // NotEqual
INPUT_MAP(NotEqual) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}}; INPUT_MAP(NotEqual) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}};
ATTR_MAP(NotEqual) = EMPTY_ATTR_MAP; ATTR_MAP(NotEqual) = EMPTY_ATTR_MAP;
@ -357,6 +442,30 @@ ATTR_MAP(Round) = EMPTY_ATTR_MAP;
OUTPUT_MAP(Round) = {{0, OUTPUT_DESC(y)}}; OUTPUT_MAP(Round) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(Round, kNameRound, ADPT_DESC(Round)) REG_ADPT_DESC(Round, kNameRound, ADPT_DESC(Round))
// Tan
INPUT_MAP(Tan) = {{1, INPUT_DESC(x)}};
ATTR_MAP(Tan) = EMPTY_ATTR_MAP;
OUTPUT_MAP(Tan) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(Tan, kNameTan, ADPT_DESC(Tan))
// Atan
INPUT_MAP(Atan) = {{1, INPUT_DESC(x)}};
ATTR_MAP(Atan) = EMPTY_ATTR_MAP;
OUTPUT_MAP(Atan) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(Atan, kNameAtan, ADPT_DESC(Atan))
// AtanGrad
INPUT_MAP(AtanGrad) = {{1, INPUT_DESC(y)}, {2, INPUT_DESC(dy)}};
ATTR_MAP(AtanGrad) = EMPTY_ATTR_MAP;
OUTPUT_MAP(AtanGrad) = {{0, OUTPUT_DESC(z)}};
REG_ADPT_DESC(AtanGrad, kNameAtanGrad, ADPT_DESC(AtanGrad))
// Atanh
INPUT_MAP(Atanh) = {{1, INPUT_DESC(x)}};
ATTR_MAP(Atanh) = EMPTY_ATTR_MAP;
OUTPUT_MAP(Atanh) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(Atanh, kNameAtanh, ADPT_DESC(Atanh))
// Atan2 // Atan2
INPUT_MAP(Atan2) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}}; INPUT_MAP(Atan2) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}};
ATTR_MAP(Atan2) = EMPTY_ATTR_MAP; ATTR_MAP(Atan2) = EMPTY_ATTR_MAP;

View File

@ -87,6 +87,24 @@ DECLARE_OP_USE_OUTPUT(MinimumGrad)
DECLARE_OP_ADAPTER(RealDiv) DECLARE_OP_ADAPTER(RealDiv)
DECLARE_OP_USE_OUTPUT(RealDiv) DECLARE_OP_USE_OUTPUT(RealDiv)
DECLARE_OP_ADAPTER(BitwiseAnd)
DECLARE_OP_USE_OUTPUT(BitwiseAnd)
DECLARE_OP_ADAPTER(BitwiseOr)
DECLARE_OP_USE_OUTPUT(BitwiseOr)
DECLARE_OP_ADAPTER(BitwiseXor)
DECLARE_OP_USE_OUTPUT(BitwiseXor)
DECLARE_OP_ADAPTER(Ceil)
DECLARE_OP_USE_OUTPUT(Ceil)
DECLARE_OP_ADAPTER(CosineEmbeddingLoss)
DECLARE_OP_USE_OUTPUT(CosineEmbeddingLoss)
DECLARE_OP_ADAPTER(Xdivy)
DECLARE_OP_USE_OUTPUT(Xdivy)
DECLARE_OP_ADAPTER(Cast) DECLARE_OP_ADAPTER(Cast)
DECLARE_OP_USE_INPUT_ATTR(Cast) DECLARE_OP_USE_INPUT_ATTR(Cast)
DECLARE_OP_USE_OUTPUT(Cast) DECLARE_OP_USE_OUTPUT(Cast)
@ -106,6 +124,9 @@ DECLARE_OP_USE_OUTPUT(Pow)
DECLARE_OP_ADAPTER(Equal) DECLARE_OP_ADAPTER(Equal)
DECLARE_OP_USE_OUTPUT(Equal) DECLARE_OP_USE_OUTPUT(Equal)
DECLARE_OP_ADAPTER(ApproximateEqual)
DECLARE_OP_USE_OUTPUT(ApproximateEqual)
DECLARE_OP_ADAPTER(NotEqual) DECLARE_OP_ADAPTER(NotEqual)
DECLARE_OP_USE_OUTPUT(NotEqual) DECLARE_OP_USE_OUTPUT(NotEqual)
@ -133,6 +154,9 @@ DECLARE_OP_USE_OUTPUT(Add)
DECLARE_OP_ADAPTER(Cos) DECLARE_OP_ADAPTER(Cos)
DECLARE_OP_USE_OUTPUT(Cos) DECLARE_OP_USE_OUTPUT(Cos)
DECLARE_OP_ADAPTER(Cosh)
DECLARE_OP_USE_OUTPUT(Cosh)
DECLARE_OP_ADAPTER(Acos) DECLARE_OP_ADAPTER(Acos)
DECLARE_OP_USE_OUTPUT(Acos) DECLARE_OP_USE_OUTPUT(Acos)
@ -145,6 +169,9 @@ DECLARE_OP_USE_OUTPUT(Acosh)
DECLARE_OP_ADAPTER(AcoshGrad) DECLARE_OP_ADAPTER(AcoshGrad)
DECLARE_OP_USE_OUTPUT(AcoshGrad) DECLARE_OP_USE_OUTPUT(AcoshGrad)
DECLARE_OP_ADAPTER(Div)
DECLARE_OP_USE_OUTPUT(Div)
DECLARE_OP_ADAPTER(Floor) DECLARE_OP_ADAPTER(Floor)
DECLARE_OP_USE_OUTPUT(Floor) DECLARE_OP_USE_OUTPUT(Floor)
@ -157,6 +184,21 @@ DECLARE_OP_USE_OUTPUT(FloorMod)
DECLARE_OP_ADAPTER(Sin) DECLARE_OP_ADAPTER(Sin)
DECLARE_OP_USE_OUTPUT(Sin) DECLARE_OP_USE_OUTPUT(Sin)
DECLARE_OP_ADAPTER(Sinh)
DECLARE_OP_USE_OUTPUT(Sinh)
DECLARE_OP_ADAPTER(Asin)
DECLARE_OP_USE_OUTPUT(Asin)
DECLARE_OP_ADAPTER(AsinGrad)
DECLARE_OP_USE_OUTPUT(AsinGrad)
DECLARE_OP_ADAPTER(Asinh)
DECLARE_OP_USE_OUTPUT(Asinh)
DECLARE_OP_ADAPTER(AsinhGrad)
DECLARE_OP_USE_OUTPUT(AsinhGrad)
DECLARE_OP_ADAPTER(Exp) DECLARE_OP_ADAPTER(Exp)
DECLARE_OP_USE_OUTPUT(Exp) DECLARE_OP_USE_OUTPUT(Exp)
@ -187,6 +229,18 @@ DECLARE_OP_USE_OUTPUT(Sign)
DECLARE_OP_ADAPTER(Round) DECLARE_OP_ADAPTER(Round)
DECLARE_OP_USE_OUTPUT(Round) DECLARE_OP_USE_OUTPUT(Round)
DECLARE_OP_ADAPTER(Tan)
DECLARE_OP_USE_OUTPUT(Tan)
DECLARE_OP_ADAPTER(Atan)
DECLARE_OP_USE_OUTPUT(Atan)
DECLARE_OP_ADAPTER(AtanGrad)
DECLARE_OP_USE_OUTPUT(AtanGrad)
DECLARE_OP_ADAPTER(Atanh)
DECLARE_OP_USE_OUTPUT(Atanh)
DECLARE_OP_ADAPTER(Atan2) DECLARE_OP_ADAPTER(Atan2)
DECLARE_OP_USE_OUTPUT(Atan2) DECLARE_OP_USE_OUTPUT(Atan2)

View File

@ -61,6 +61,50 @@ REG_ADPT_DESC(ApplyAdamD, kNameApplyAdam, ADPT_DESC(ApplyAdamD))
REG_ADPT_DESC(ApplyAdam, kNameApplyAdam, ADPT_DESC(ApplyAdam)) REG_ADPT_DESC(ApplyAdam, kNameApplyAdam, ADPT_DESC(ApplyAdam))
#endif #endif
// ApplyAdagradD
INPUT_MAP(ApplyAdagradD) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(accum)}, {3, INPUT_DESC(lr)}, {4, INPUT_DESC(grad)}};
ATTR_MAP(ApplyAdagradD) = {{"update_slots", ATTR_DESC(update_slots, AnyTraits<bool>())},
{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
OUTPUT_MAP(ApplyAdagradD) = {{0, OUTPUT_DESC(var)}, {1, OUTPUT_DESC(accum)}};
REG_ADPT_DESC(ApplyAdagradD, kNameApplyAdagrad, ADPT_DESC(ApplyAdagradD))
// ApplyAdadeltaD
INPUT_MAP(ApplyAdadeltaD) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(accum)}, {3, INPUT_DESC(accum_update)},
{4, INPUT_DESC(lr)}, {5, INPUT_DESC(rho)}, {6, INPUT_DESC(epsilon)},
{7, INPUT_DESC(grad)}};
ATTR_MAP(ApplyAdadeltaD) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
OUTPUT_MAP(ApplyAdadeltaD) = {{0, OUTPUT_DESC(var)}, {1, OUTPUT_DESC(accum)}, {2, OUTPUT_DESC(accum_update)}};
REG_ADPT_DESC(ApplyAdadeltaD, kNameApplyAdadelta, ADPT_DESC(ApplyAdadeltaD))
// ApplyAdaMaxD
INPUT_MAP(ApplyAdaMaxD) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(m)}, {3, INPUT_DESC(v)},
{4, INPUT_DESC(beta1_power)}, {5, INPUT_DESC(lr)}, {6, INPUT_DESC(beta1)},
{7, INPUT_DESC(beta2)}, {8, INPUT_DESC(epsilon)}, {9, INPUT_DESC(grad)}};
ATTR_MAP(ApplyAdaMaxD) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
OUTPUT_MAP(ApplyAdaMaxD) = {{0, OUTPUT_DESC(var)}, {1, OUTPUT_DESC(m)}, {2, OUTPUT_DESC(v)}};
REG_ADPT_DESC(ApplyAdaMaxD, kNameApplyAdaMax, ADPT_DESC(ApplyAdaMaxD))
// ApplyGradientDescent
INPUT_MAP(ApplyGradientDescent) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(alpha)}, {3, INPUT_DESC(delta)}};
ATTR_MAP(ApplyGradientDescent) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
OUTPUT_MAP(ApplyGradientDescent) = {{0, OUTPUT_DESC(var)}};
REG_ADPT_DESC(ApplyGradientDescent, kNameApplyGradientDescent, ADPT_DESC(ApplyGradientDescent))
// ApplyPowerSignD
INPUT_MAP(ApplyPowerSignD) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(m)}, {3, INPUT_DESC(lr)},
{4, INPUT_DESC(logbase)}, {5, INPUT_DESC(sign_decay)}, {6, INPUT_DESC(beta)},
{7, INPUT_DESC(grad)}};
ATTR_MAP(ApplyPowerSignD) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
OUTPUT_MAP(ApplyPowerSignD) = {{0, OUTPUT_DESC(var)}, {1, OUTPUT_DESC(m)}};
REG_ADPT_DESC(ApplyPowerSignD, kNameApplyPowerSign, ADPT_DESC(ApplyPowerSignD))
// ApplyProximalGradientDescent
INPUT_MAP(ApplyProximalGradientDescent) = {
{1, INPUT_DESC(var)}, {2, INPUT_DESC(alpha)}, {3, INPUT_DESC(l1)}, {4, INPUT_DESC(l2)}, {5, INPUT_DESC(delta)}};
ATTR_MAP(ApplyProximalGradientDescent) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
OUTPUT_MAP(ApplyProximalGradientDescent) = {{0, OUTPUT_DESC(var)}};
REG_ADPT_DESC(ApplyProximalGradientDescent, kNameApplyProximalGradientDescent, ADPT_DESC(ApplyProximalGradientDescent))
// SGD // SGD
INPUT_MAP(SGD) = {{1, INPUT_DESC(parameters)}, {2, INPUT_DESC(gradient)}, {3, INPUT_DESC(learning_rate)}, INPUT_MAP(SGD) = {{1, INPUT_DESC(parameters)}, {2, INPUT_DESC(gradient)}, {3, INPUT_DESC(learning_rate)},
{4, INPUT_DESC(accum)}, {5, INPUT_DESC(momentum)}, {6, INPUT_DESC(stat)}}; {4, INPUT_DESC(accum)}, {5, INPUT_DESC(momentum)}, {6, INPUT_DESC(stat)}};

View File

@ -29,6 +29,24 @@ DECLARE_OP_USE_OUTPUT(ApplyAdam)
DECLARE_OP_ADAPTER(ApplyAdamD) DECLARE_OP_ADAPTER(ApplyAdamD)
DECLARE_OP_USE_OUTPUT(ApplyAdamD) DECLARE_OP_USE_OUTPUT(ApplyAdamD)
DECLARE_OP_ADAPTER(ApplyAdagradD)
DECLARE_OP_USE_OUTPUT(ApplyAdagradD)
DECLARE_OP_ADAPTER(ApplyAdadeltaD)
DECLARE_OP_USE_OUTPUT(ApplyAdadeltaD)
DECLARE_OP_ADAPTER(ApplyAdaMaxD)
DECLARE_OP_USE_OUTPUT(ApplyAdaMaxD)
DECLARE_OP_ADAPTER(ApplyGradientDescent)
DECLARE_OP_USE_OUTPUT(ApplyGradientDescent)
DECLARE_OP_ADAPTER(ApplyPowerSignD)
DECLARE_OP_USE_OUTPUT(ApplyPowerSignD)
DECLARE_OP_ADAPTER(ApplyProximalGradientDescent)
DECLARE_OP_USE_OUTPUT(ApplyProximalGradientDescent)
DECLARE_OP_ADAPTER(SGD) DECLARE_OP_ADAPTER(SGD)
DECLARE_OP_USE_OUTPUT(SGD) DECLARE_OP_USE_OUTPUT(SGD)

View File

@ -24,6 +24,12 @@ ATTR_MAP(PadD) = {{"paddings", ATTR_DESC(paddings, AnyTraits<std::vector<std::ve
OUTPUT_MAP(PadD) = {{0, OUTPUT_DESC(y)}}; OUTPUT_MAP(PadD) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(PadD, kNamePadD, ADPT_DESC(PadD)) REG_ADPT_DESC(PadD, kNamePadD, ADPT_DESC(PadD))
// BroadcastToD
INPUT_MAP(BroadcastToD) = {{1, INPUT_DESC(x)}};
ATTR_MAP(BroadcastToD) = {{"shape", ATTR_DESC(shape, AnyTraits<int64_t>(), AnyTraits<std::vector<int64_t>>())}};
OUTPUT_MAP(BroadcastToD) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(BroadcastToD, kNameBroadcastTo, ADPT_DESC(BroadcastToD))
// Diag // Diag
INPUT_MAP(Diag) = {{1, INPUT_DESC(x)}}; INPUT_MAP(Diag) = {{1, INPUT_DESC(x)}};
ATTR_MAP(Diag) = EMPTY_ATTR_MAP; ATTR_MAP(Diag) = EMPTY_ATTR_MAP;

View File

@ -26,6 +26,9 @@ namespace mindspore::transform {
DECLARE_OP_ADAPTER(PadD) DECLARE_OP_ADAPTER(PadD)
DECLARE_OP_USE_OUTPUT(PadD) DECLARE_OP_USE_OUTPUT(PadD)
DECLARE_OP_ADAPTER(BroadcastToD)
DECLARE_OP_USE_OUTPUT(BroadcastToD)
DECLARE_OP_ADAPTER(Diag) DECLARE_OP_ADAPTER(Diag)
DECLARE_OP_USE_OUTPUT(Diag) DECLARE_OP_USE_OUTPUT(Diag)
} // namespace mindspore::transform } // namespace mindspore::transform