!48544 add op adapter

Merge pull request !48544 from qiuzhongya/op_adapter
This commit is contained in:
i-robot 2023-02-08 06:47:10 +00:00 committed by Gitee
commit b8b68f4b0d
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
8 changed files with 24 additions and 2 deletions

View File

@ -167,7 +167,8 @@ constexpr const char kNameExpm1[] = "Expm1";
constexpr const char kNameInplaceAddD[] = "InplaceAdd"; constexpr const char kNameInplaceAddD[] = "InplaceAdd";
constexpr const char kNameInplaceSubD[] = "InplaceSub"; constexpr const char kNameInplaceSubD[] = "InplaceSub";
constexpr const char kNameInplaceUpdateD[] = "InplaceUpdate"; constexpr const char kNameInplaceUpdateD[] = "InplaceUpdate";
constexpr const char kNameInTopKD[] = "InTopK"; constexpr const char kNameInTopK[] = "InTopK";
constexpr const char kNameInTopKD[] = "InTopKD";
constexpr const char kNameInv[] = "Inv"; constexpr const char kNameInv[] = "Inv";
constexpr const char kNameInvGrad[] = "InvGrad"; constexpr const char kNameInvGrad[] = "InvGrad";
constexpr const char kNameInvert[] = "Invert"; constexpr const char kNameInvert[] = "Invert";
@ -304,6 +305,7 @@ constexpr const char kNamePrint[] = "Print";
constexpr const char kNameApplyFtrl[] = "ApplyFtrl"; constexpr const char kNameApplyFtrl[] = "ApplyFtrl";
constexpr const char kNameDiag[] = "Diag"; constexpr const char kNameDiag[] = "Diag";
constexpr const char kNameDiagPart[] = "DiagPart"; constexpr const char kNameDiagPart[] = "DiagPart";
constexpr const char kNameDiagPartD[] = "DiagPartD";
constexpr const char kNameSpaceToBatch[] = "SpaceToBatch"; constexpr const char kNameSpaceToBatch[] = "SpaceToBatch";
constexpr const char kNameBatchToSpace[] = "BatchToSpace"; constexpr const char kNameBatchToSpace[] = "BatchToSpace";
constexpr const char kNameTan[] = "Tan"; constexpr const char kNameTan[] = "Tan";

View File

@ -42,6 +42,12 @@ REG_ADPT_DESC(Add, prim::kPrimAdd->name(),
std::make_shared<OpAdapter<Add>>(ExtraAttr({{"mode", MakeValue(static_cast<int64_t>(1))}})), std::make_shared<OpAdapter<Add>>(ExtraAttr({{"mode", MakeValue(static_cast<int64_t>(1))}})),
std::make_shared<OpAdapter<Add>>(ExtraAttr({{"mode", MakeValue(static_cast<int64_t>(1))}})))) std::make_shared<OpAdapter<Add>>(ExtraAttr({{"mode", MakeValue(static_cast<int64_t>(1))}}))))
// AddV2
INPUT_MAP(AddV2) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}};
ATTR_MAP(AddV2) = EMPTY_ATTR_MAP;
OUTPUT_MAP(AddV2) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(AddV2, prim::kPrimAddV2->name(), ADPT_DESC(AddV2))
// AccumulateNV2 // AccumulateNV2
INPUT_MAP(AccumulateNV2) = EMPTY_INPUT_MAP; INPUT_MAP(AccumulateNV2) = EMPTY_INPUT_MAP;
DYN_INPUT_MAP(AccumulateNV2) = {{1, DYN_INPUT_DESC(x)}}; DYN_INPUT_MAP(AccumulateNV2) = {{1, DYN_INPUT_DESC(x)}};

View File

@ -196,6 +196,9 @@ DECLARE_OP_USE_OUTPUT(Assign)
DECLARE_OP_ADAPTER(Add) DECLARE_OP_ADAPTER(Add)
DECLARE_OP_USE_OUTPUT(Add) DECLARE_OP_USE_OUTPUT(Add)
DECLARE_OP_ADAPTER(AddV2)
DECLARE_OP_USE_OUTPUT(AddV2)
DECLARE_OP_ADAPTER(Cos) DECLARE_OP_ADAPTER(Cos)
DECLARE_OP_USE_OUTPUT(Cos) DECLARE_OP_USE_OUTPUT(Cos)

View File

@ -142,6 +142,7 @@ INPUT_MAP(DiagPart) = {{1, INPUT_DESC(x)}};
ATTR_MAP(DiagPart) = EMPTY_ATTR_MAP; ATTR_MAP(DiagPart) = EMPTY_ATTR_MAP;
OUTPUT_MAP(DiagPart) = {{0, OUTPUT_DESC(y)}}; OUTPUT_MAP(DiagPart) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(DiagPart, kNameDiagPart, ADPT_DESC(DiagPart)) REG_ADPT_DESC(DiagPart, kNameDiagPart, ADPT_DESC(DiagPart))
REG_ADPT_DESC(DiagPartD, kNameDiagPartD, ADPT_DESC(DiagPart))
// BatchMatMul // BatchMatMul
INPUT_MAP(BatchMatMul) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}}; INPUT_MAP(BatchMatMul) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}};

View File

@ -207,6 +207,12 @@ OUTPUT_MAP(HardShrink) = {{0, OUTPUT_DESC(output_y)}};
REG_ADPT_DESC(HShrink, prim::kPrimHShrink->name(), ADPT_DESC(HardShrink)) REG_ADPT_DESC(HShrink, prim::kPrimHShrink->name(), ADPT_DESC(HardShrink))
REG_ADPT_DESC(HardShrink, kHardShrinkOpName, ADPT_DESC(HardShrink)) REG_ADPT_DESC(HardShrink, kHardShrinkOpName, ADPT_DESC(HardShrink))
// HardShrinkGrad
INPUT_MAP(HardShrinkGrad) = {{1, INPUT_DESC(gradients)}, {2, INPUT_DESC(features)}};
ATTR_MAP(HardShrinkGrad) = {{"lambd", ATTR_DESC(lambd, AnyTraits<float>())}};
OUTPUT_MAP(HardShrinkGrad) = {{0, OUTPUT_DESC(backprops)}};
REG_ADPT_DESC(HardShrinkGrad, kHardShrinkGradOpName, ADPT_DESC(HardShrinkGrad))
// SoftShrink // SoftShrink
INPUT_MAP(SoftShrink) = {{1, INPUT_DESC(input_x)}}; INPUT_MAP(SoftShrink) = {{1, INPUT_DESC(input_x)}};
ATTR_MAP(SoftShrink) = {{"lambd", ATTR_DESC(lambd, AnyTraits<float>())}}; ATTR_MAP(SoftShrink) = {{"lambd", ATTR_DESC(lambd, AnyTraits<float>())}};

View File

@ -109,6 +109,9 @@ DECLARE_OP_USE_OUTPUT(LeakyRelu)
DECLARE_OP_ADAPTER(HardShrink) DECLARE_OP_ADAPTER(HardShrink)
DECLARE_OP_USE_OUTPUT(HardShrink) DECLARE_OP_USE_OUTPUT(HardShrink)
DECLARE_OP_ADAPTER(HardShrinkGrad)
DECLARE_OP_USE_OUTPUT(HardShrinkGrad)
DECLARE_OP_ADAPTER(SoftShrink) DECLARE_OP_ADAPTER(SoftShrink)
DECLARE_OP_USE_OUTPUT(SoftShrink) DECLARE_OP_USE_OUTPUT(SoftShrink)

View File

@ -73,8 +73,8 @@ REG_ADPT_DESC(TopKV2, kNameTopKV2, ADPT_DESC(TopK))
INPUT_MAP(InTopKD) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}}; INPUT_MAP(InTopKD) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}};
ATTR_MAP(InTopKD) = {{"k", ATTR_DESC(k, AnyTraits<int64_t>())}}; ATTR_MAP(InTopKD) = {{"k", ATTR_DESC(k, AnyTraits<int64_t>())}};
OUTPUT_MAP(InTopKD) = {{0, OUTPUT_DESC(y)}}; OUTPUT_MAP(InTopKD) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(InTopK, kNameInTopK, ADPT_DESC(InTopKD))
REG_ADPT_DESC(InTopKD, kNameInTopKD, ADPT_DESC(InTopKD)) REG_ADPT_DESC(InTopKD, kNameInTopKD, ADPT_DESC(InTopKD))
// OneHot // OneHot
INPUT_MAP(OneHot) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(depth)}, {3, INPUT_DESC(on_value)}, {4, INPUT_DESC(off_value)}}; INPUT_MAP(OneHot) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(depth)}, {3, INPUT_DESC(on_value)}, {4, INPUT_DESC(off_value)}};
ATTR_INPUT_MAP(OneHot) = {{"depth", "depth"}}; ATTR_INPUT_MAP(OneHot) = {{"depth", "depth"}};

View File

@ -135,6 +135,7 @@ ATTR_INPUT_MAP(BatchToSpaceND) = {{"block_shape", "block_shape"}, {"crops", "cro
ATTR_MAP(BatchToSpaceND) = EMPTY_ATTR_MAP; ATTR_MAP(BatchToSpaceND) = EMPTY_ATTR_MAP;
OUTPUT_MAP(BatchToSpaceND) = {{0, OUTPUT_DESC(y)}}; OUTPUT_MAP(BatchToSpaceND) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(BatchToSpaceND, kNameBatchToSpaceNd, ADPT_DESC(BatchToSpaceND)) REG_ADPT_DESC(BatchToSpaceND, kNameBatchToSpaceNd, ADPT_DESC(BatchToSpaceND))
REG_ADPT_DESC(BatchToSpaceNDD, kBatchToSpaceNDDOpName, ADPT_DESC(BatchToSpaceND))
REG_ADPT_DESC(BatchToSpaceTF, kNameBatchToSpaceTF, ADPT_DESC(BatchToSpaceND)) REG_ADPT_DESC(BatchToSpaceTF, kNameBatchToSpaceTF, ADPT_DESC(BatchToSpaceND))
REG_ADPT_DESC(kNameBatchToSpaceNdV2, kNameBatchToSpaceNdV2, ADPT_DESC(BatchToSpaceND)) REG_ADPT_DESC(kNameBatchToSpaceNdV2, kNameBatchToSpaceNdV2, ADPT_DESC(BatchToSpaceND))
} // namespace mindspore::transform } // namespace mindspore::transform