forked from mindspore-Ecosystem/mindspore
!48544 add op adapter
Merge pull request !48544 from qiuzhongya/op_adapter
This commit is contained in:
commit
b8b68f4b0d
|
@ -167,7 +167,8 @@ constexpr const char kNameExpm1[] = "Expm1";
|
||||||
constexpr const char kNameInplaceAddD[] = "InplaceAdd";
|
constexpr const char kNameInplaceAddD[] = "InplaceAdd";
|
||||||
constexpr const char kNameInplaceSubD[] = "InplaceSub";
|
constexpr const char kNameInplaceSubD[] = "InplaceSub";
|
||||||
constexpr const char kNameInplaceUpdateD[] = "InplaceUpdate";
|
constexpr const char kNameInplaceUpdateD[] = "InplaceUpdate";
|
||||||
constexpr const char kNameInTopKD[] = "InTopK";
|
constexpr const char kNameInTopK[] = "InTopK";
|
||||||
|
constexpr const char kNameInTopKD[] = "InTopKD";
|
||||||
constexpr const char kNameInv[] = "Inv";
|
constexpr const char kNameInv[] = "Inv";
|
||||||
constexpr const char kNameInvGrad[] = "InvGrad";
|
constexpr const char kNameInvGrad[] = "InvGrad";
|
||||||
constexpr const char kNameInvert[] = "Invert";
|
constexpr const char kNameInvert[] = "Invert";
|
||||||
|
@ -304,6 +305,7 @@ constexpr const char kNamePrint[] = "Print";
|
||||||
constexpr const char kNameApplyFtrl[] = "ApplyFtrl";
|
constexpr const char kNameApplyFtrl[] = "ApplyFtrl";
|
||||||
constexpr const char kNameDiag[] = "Diag";
|
constexpr const char kNameDiag[] = "Diag";
|
||||||
constexpr const char kNameDiagPart[] = "DiagPart";
|
constexpr const char kNameDiagPart[] = "DiagPart";
|
||||||
|
constexpr const char kNameDiagPartD[] = "DiagPartD";
|
||||||
constexpr const char kNameSpaceToBatch[] = "SpaceToBatch";
|
constexpr const char kNameSpaceToBatch[] = "SpaceToBatch";
|
||||||
constexpr const char kNameBatchToSpace[] = "BatchToSpace";
|
constexpr const char kNameBatchToSpace[] = "BatchToSpace";
|
||||||
constexpr const char kNameTan[] = "Tan";
|
constexpr const char kNameTan[] = "Tan";
|
||||||
|
|
|
@ -42,6 +42,12 @@ REG_ADPT_DESC(Add, prim::kPrimAdd->name(),
|
||||||
std::make_shared<OpAdapter<Add>>(ExtraAttr({{"mode", MakeValue(static_cast<int64_t>(1))}})),
|
std::make_shared<OpAdapter<Add>>(ExtraAttr({{"mode", MakeValue(static_cast<int64_t>(1))}})),
|
||||||
std::make_shared<OpAdapter<Add>>(ExtraAttr({{"mode", MakeValue(static_cast<int64_t>(1))}}))))
|
std::make_shared<OpAdapter<Add>>(ExtraAttr({{"mode", MakeValue(static_cast<int64_t>(1))}}))))
|
||||||
|
|
||||||
|
// AddV2
|
||||||
|
INPUT_MAP(AddV2) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}};
|
||||||
|
ATTR_MAP(AddV2) = EMPTY_ATTR_MAP;
|
||||||
|
OUTPUT_MAP(AddV2) = {{0, OUTPUT_DESC(y)}};
|
||||||
|
REG_ADPT_DESC(AddV2, prim::kPrimAddV2->name(), ADPT_DESC(AddV2))
|
||||||
|
|
||||||
// AccumulateNV2
|
// AccumulateNV2
|
||||||
INPUT_MAP(AccumulateNV2) = EMPTY_INPUT_MAP;
|
INPUT_MAP(AccumulateNV2) = EMPTY_INPUT_MAP;
|
||||||
DYN_INPUT_MAP(AccumulateNV2) = {{1, DYN_INPUT_DESC(x)}};
|
DYN_INPUT_MAP(AccumulateNV2) = {{1, DYN_INPUT_DESC(x)}};
|
||||||
|
|
|
@ -196,6 +196,9 @@ DECLARE_OP_USE_OUTPUT(Assign)
|
||||||
DECLARE_OP_ADAPTER(Add)
|
DECLARE_OP_ADAPTER(Add)
|
||||||
DECLARE_OP_USE_OUTPUT(Add)
|
DECLARE_OP_USE_OUTPUT(Add)
|
||||||
|
|
||||||
|
DECLARE_OP_ADAPTER(AddV2)
|
||||||
|
DECLARE_OP_USE_OUTPUT(AddV2)
|
||||||
|
|
||||||
DECLARE_OP_ADAPTER(Cos)
|
DECLARE_OP_ADAPTER(Cos)
|
||||||
DECLARE_OP_USE_OUTPUT(Cos)
|
DECLARE_OP_USE_OUTPUT(Cos)
|
||||||
|
|
||||||
|
|
|
@ -142,6 +142,7 @@ INPUT_MAP(DiagPart) = {{1, INPUT_DESC(x)}};
|
||||||
ATTR_MAP(DiagPart) = EMPTY_ATTR_MAP;
|
ATTR_MAP(DiagPart) = EMPTY_ATTR_MAP;
|
||||||
OUTPUT_MAP(DiagPart) = {{0, OUTPUT_DESC(y)}};
|
OUTPUT_MAP(DiagPart) = {{0, OUTPUT_DESC(y)}};
|
||||||
REG_ADPT_DESC(DiagPart, kNameDiagPart, ADPT_DESC(DiagPart))
|
REG_ADPT_DESC(DiagPart, kNameDiagPart, ADPT_DESC(DiagPart))
|
||||||
|
REG_ADPT_DESC(DiagPartD, kNameDiagPartD, ADPT_DESC(DiagPart))
|
||||||
|
|
||||||
// BatchMatMul
|
// BatchMatMul
|
||||||
INPUT_MAP(BatchMatMul) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}};
|
INPUT_MAP(BatchMatMul) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}};
|
||||||
|
|
|
@ -207,6 +207,12 @@ OUTPUT_MAP(HardShrink) = {{0, OUTPUT_DESC(output_y)}};
|
||||||
REG_ADPT_DESC(HShrink, prim::kPrimHShrink->name(), ADPT_DESC(HardShrink))
|
REG_ADPT_DESC(HShrink, prim::kPrimHShrink->name(), ADPT_DESC(HardShrink))
|
||||||
REG_ADPT_DESC(HardShrink, kHardShrinkOpName, ADPT_DESC(HardShrink))
|
REG_ADPT_DESC(HardShrink, kHardShrinkOpName, ADPT_DESC(HardShrink))
|
||||||
|
|
||||||
|
// HardShrinkGrad
|
||||||
|
INPUT_MAP(HardShrinkGrad) = {{1, INPUT_DESC(gradients)}, {2, INPUT_DESC(features)}};
|
||||||
|
ATTR_MAP(HardShrinkGrad) = {{"lambd", ATTR_DESC(lambd, AnyTraits<float>())}};
|
||||||
|
OUTPUT_MAP(HardShrinkGrad) = {{0, OUTPUT_DESC(backprops)}};
|
||||||
|
REG_ADPT_DESC(HardShrinkGrad, kHardShrinkGradOpName, ADPT_DESC(HardShrinkGrad))
|
||||||
|
|
||||||
// SoftShrink
|
// SoftShrink
|
||||||
INPUT_MAP(SoftShrink) = {{1, INPUT_DESC(input_x)}};
|
INPUT_MAP(SoftShrink) = {{1, INPUT_DESC(input_x)}};
|
||||||
ATTR_MAP(SoftShrink) = {{"lambd", ATTR_DESC(lambd, AnyTraits<float>())}};
|
ATTR_MAP(SoftShrink) = {{"lambd", ATTR_DESC(lambd, AnyTraits<float>())}};
|
||||||
|
|
|
@ -109,6 +109,9 @@ DECLARE_OP_USE_OUTPUT(LeakyRelu)
|
||||||
DECLARE_OP_ADAPTER(HardShrink)
|
DECLARE_OP_ADAPTER(HardShrink)
|
||||||
DECLARE_OP_USE_OUTPUT(HardShrink)
|
DECLARE_OP_USE_OUTPUT(HardShrink)
|
||||||
|
|
||||||
|
DECLARE_OP_ADAPTER(HardShrinkGrad)
|
||||||
|
DECLARE_OP_USE_OUTPUT(HardShrinkGrad)
|
||||||
|
|
||||||
DECLARE_OP_ADAPTER(SoftShrink)
|
DECLARE_OP_ADAPTER(SoftShrink)
|
||||||
DECLARE_OP_USE_OUTPUT(SoftShrink)
|
DECLARE_OP_USE_OUTPUT(SoftShrink)
|
||||||
|
|
||||||
|
|
|
@ -73,8 +73,8 @@ REG_ADPT_DESC(TopKV2, kNameTopKV2, ADPT_DESC(TopK))
|
||||||
INPUT_MAP(InTopKD) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}};
|
INPUT_MAP(InTopKD) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}};
|
||||||
ATTR_MAP(InTopKD) = {{"k", ATTR_DESC(k, AnyTraits<int64_t>())}};
|
ATTR_MAP(InTopKD) = {{"k", ATTR_DESC(k, AnyTraits<int64_t>())}};
|
||||||
OUTPUT_MAP(InTopKD) = {{0, OUTPUT_DESC(y)}};
|
OUTPUT_MAP(InTopKD) = {{0, OUTPUT_DESC(y)}};
|
||||||
|
REG_ADPT_DESC(InTopK, kNameInTopK, ADPT_DESC(InTopKD))
|
||||||
REG_ADPT_DESC(InTopKD, kNameInTopKD, ADPT_DESC(InTopKD))
|
REG_ADPT_DESC(InTopKD, kNameInTopKD, ADPT_DESC(InTopKD))
|
||||||
|
|
||||||
// OneHot
|
// OneHot
|
||||||
INPUT_MAP(OneHot) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(depth)}, {3, INPUT_DESC(on_value)}, {4, INPUT_DESC(off_value)}};
|
INPUT_MAP(OneHot) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(depth)}, {3, INPUT_DESC(on_value)}, {4, INPUT_DESC(off_value)}};
|
||||||
ATTR_INPUT_MAP(OneHot) = {{"depth", "depth"}};
|
ATTR_INPUT_MAP(OneHot) = {{"depth", "depth"}};
|
||||||
|
|
|
@ -135,6 +135,7 @@ ATTR_INPUT_MAP(BatchToSpaceND) = {{"block_shape", "block_shape"}, {"crops", "cro
|
||||||
ATTR_MAP(BatchToSpaceND) = EMPTY_ATTR_MAP;
|
ATTR_MAP(BatchToSpaceND) = EMPTY_ATTR_MAP;
|
||||||
OUTPUT_MAP(BatchToSpaceND) = {{0, OUTPUT_DESC(y)}};
|
OUTPUT_MAP(BatchToSpaceND) = {{0, OUTPUT_DESC(y)}};
|
||||||
REG_ADPT_DESC(BatchToSpaceND, kNameBatchToSpaceNd, ADPT_DESC(BatchToSpaceND))
|
REG_ADPT_DESC(BatchToSpaceND, kNameBatchToSpaceNd, ADPT_DESC(BatchToSpaceND))
|
||||||
|
REG_ADPT_DESC(BatchToSpaceNDD, kBatchToSpaceNDDOpName, ADPT_DESC(BatchToSpaceND))
|
||||||
REG_ADPT_DESC(BatchToSpaceTF, kNameBatchToSpaceTF, ADPT_DESC(BatchToSpaceND))
|
REG_ADPT_DESC(BatchToSpaceTF, kNameBatchToSpaceTF, ADPT_DESC(BatchToSpaceND))
|
||||||
REG_ADPT_DESC(kNameBatchToSpaceNdV2, kNameBatchToSpaceNdV2, ADPT_DESC(BatchToSpaceND))
|
REG_ADPT_DESC(kNameBatchToSpaceNdV2, kNameBatchToSpaceNdV2, ADPT_DESC(BatchToSpaceND))
|
||||||
} // namespace mindspore::transform
|
} // namespace mindspore::transform
|
||||||
|
|
Loading…
Reference in New Issue