add splitable operators

This commit is contained in:
Yi Huaijie 2020-10-20 10:54:34 +08:00
parent e14a6fc015
commit fe1b06c659
2 changed files with 64 additions and 5 deletions

View File

@ -230,6 +230,47 @@ constexpr char ASSIGN[] = "Assign";
constexpr char GET_NEXT[] = "GetNext";
constexpr char SQUEEZE[] = "Squeeze";
constexpr char NEG[] = "Neg";
constexpr char ABS[] = "Abs";
constexpr char ACOSH[] = "Acosh";
constexpr char ASIN[] = "Asin";
constexpr char ASINH[] = "Asinh";
constexpr char ATAN[] = "Atan";
constexpr char ATANH[] = "Atanh";
constexpr char CEIL[] = "Ceil";
constexpr char COSH[] = "Cosh";
constexpr char EXPM1[] = "Expm1";
constexpr char LOG1P[] = "Log1p";
constexpr char SIN[] = "Sin";
constexpr char SINH[] = "Sinh";
constexpr char TAN[] = "Tan";
constexpr char RSQRT[] = "Rsqrt";
constexpr char INV[] = "Inv";
constexpr char RECIPROCAL[] = "Reciprocal";
constexpr char ROUND[] = "Round";
constexpr char FLOOR[] = "Floor";
constexpr char SIGN[] = "Sign";
constexpr char ERF[] = "Erf";
constexpr char ERFC[] = "Erfc";
constexpr char ZEROSLIKE[] = "ZerosLike";
constexpr char ONESLIKE[] = "OnesLike";
constexpr char BESSELI0E[] = "BesselI0e";
constexpr char BESSELI1E[] = "BesselI1e";
constexpr char FLOORMOD[] = "FloorMod";
constexpr char ASSIGN_ADD[] = "AssignAdd";
constexpr char ATAN2[] = "Atan2";
constexpr char DIVNONAN[] = "DivNoNan";
constexpr char LOGICALAND[] = "LogicalAnd";
constexpr char LOGICALOR[] = "LogicalOr";
constexpr char ELU[] = "Elu";
constexpr char RELU6[] = "ReLU6";
constexpr char RELUV2[] = "ReLUV2";
constexpr char SOFTPLUS[] = "Softplus";
constexpr char SOFTSIGN[] = "Softsign";
constexpr char GREATEREQUAL[] = "GreaterEqual";
constexpr char LESSEQUAL[] = "LessEqual";
constexpr char LESS[] = "Less";
constexpr char APPROXIMATEEQUAL[] = "ApproximateEqual";
constexpr char MOD[] = "Mod";
constexpr char BATCH_MATMUL[] = "BatchMatMul";
constexpr char EXPAND_DIMS[] = "ExpandDims";
constexpr char SQUARE[] = "Square";
@ -297,7 +338,6 @@ constexpr char COL2IMV1[] = "col2im_v1";
constexpr char RESOLVE[] = "resolve";
constexpr char EMBED[] = "embed";
constexpr char CREATINSTANCE[] = "create_instance";
constexpr char ZEROSLIKE[] = "ZerosLike";
constexpr char REF_TO_EMBED[] = "RefToEmbed";
constexpr char STOP_GRADIENT[] = "stop_gradient";

View File

@ -248,9 +248,25 @@ std::vector<TypePtr> ExtractOutputTypeByNode(const CNodePtr &node) {
}
bool IsElementWiseOperator(const std::string &op_name) {
static const std::set<std::string> elementwise_op = {ACTIVATION, GELU, TANH, SOFTMAX, LOG_SOFTMAX, RELU,
SQRT, CAST, POW, EXP, LOG, COS,
ACOS, LOGICALNOT, NEG, SQUARE, SIGMOID};
static const std::set<std::string> elementwise_op = {ACTIVATION, GELU, TANH,
SOFTMAX, LOG_SOFTMAX, RELU,
SQRT, CAST, POW,
EXP, LOG, COS,
ACOS, LOGICALNOT, NEG,
SQUARE, SIGMOID, ABS,
ACOSH, ASIN, ASINH,
ATAN, ATANH, CEIL,
COSH, EXPM1, LOG1P,
SIN, SINH, TAN,
RSQRT, RECIPROCAL, INV,
ROUND, FLOOR, SIGN,
ERF, ERFC, ZEROSLIKE,
ONESLIKE, BESSELI0E, MOD,
ASSIGN, ASSIGN_ADD, ATAN2,
DIVNONAN, LOGICALAND, ELU,
LOGICALOR, RELU6, SOFTPLUS,
SOFTSIGN, LESS, LESSEQUAL,
BESSELI1E, GREATEREQUAL, APPROXIMATEEQUAL};
auto iter = elementwise_op.find(op_name);
return (iter != elementwise_op.end());
}
@ -265,7 +281,10 @@ bool IsSplittableOperator(const std::string &op_name) {
LOG, REDUCE_MEAN, REAL_DIV, SIGMOID, POW, MAXIMUM, MINIMUM, EQUAL, NOT_EQUAL, LOGICALNOT, GATHERV2, SQRT, CONCAT,
STRIDEDSLICE, GET_NEXT, CAST, NEG, SQUARE, BATCH_MATMUL, EXPAND_DIMS, SQUEEZE, SPARSE_GATHERV2, TILE, DROPOUT,
SOFTMAX_CROSS_ENTROPY_WITH_LOGITS, SIGMOID_CROSS_ENTROPY_WITH_LOGITS, SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS,
EMBEDDING_LOOKUP, FUSE_BATCH_NORM_EX, SPLIT, BROADCAST_TO};
EMBEDDING_LOOKUP, FUSE_BATCH_NORM_EX, SPLIT, BROADCAST_TO, ABS, ACOSH, ASIN, ASINH, ATAN, ATANH, CEIL, COSH,
EXPM1, LOG1P, SIN, SINH, TAN, RSQRT, INV, RECIPROCAL, ROUND, FLOOR, SIGN, ERF, ERFC, ZEROSLIKE, ONESLIKE,
BESSELI0E, BESSELI1E, FLOORMOD, ASSIGN, ASSIGN_ADD, ATAN2, DIVNONAN, LOGICALAND, LOGICALOR, ELU, RELU6, RELUV2,
SOFTPLUS, SOFTSIGN, GREATEREQUAL, LESSEQUAL, LESS, APPROXIMATEEQUAL, MOD};
// clang-format on
auto iter = splittable_op.find(op_name);