forked from mindspore-Ecosystem/mindspore
!15181 fix_codex
From: @lingyunli63 Reviewed-by: @gaoxiong1,@ckey_dou Signed-off-by: @ckey_dou
This commit is contained in:
commit
641ff8954b
|
@ -102,7 +102,8 @@ bool AkgKernelBuilder::AkgOpParallelBuild(const std::vector<JsonNodePair> &build
|
|||
return true;
|
||||
}
|
||||
|
||||
kernel::KernelBuildClient *client = GetClient();
|
||||
auto client = GetClient();
|
||||
MS_EXCEPTION_IF_NULL(client);
|
||||
if (!client->AkgStart(PROCESS_NUM, TIME_OUT)) {
|
||||
MS_LOG(ERROR) << "Akg start failed.";
|
||||
return false;
|
||||
|
|
|
@ -250,7 +250,7 @@ AnfNodePtr SimplifySelect(const AnfNodePtr &node) {
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
AnfNodePtr SimplifyMul(const AnfNodePtr &node) {
|
||||
AnfNodePtr SimplifyMul1(const AnfNodePtr &node) {
|
||||
if (!IsPrimitiveCNode(node, prim::kPrimMul)) {
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -278,6 +278,28 @@ AnfNodePtr SimplifyMul(const AnfNodePtr &node) {
|
|||
auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimSqrt), node_tmp}, node);
|
||||
return new_cnode;
|
||||
};
|
||||
// (x*C1)*(y*C2) ==> (x*y)*(C1*C2)
|
||||
MATCH_REPLACE_LAMBDA(node, (const_1 * x) * (const_2 * y), const_dup_lambda);
|
||||
// (x*C1)*C2 ==> x*(C1*C2)
|
||||
MATCH_REPLACE_LAMBDA(node, (const_1 * x) * const_2, const_dup_lambda2);
|
||||
// exp(x)*exp(y) ==> exp(x+y)
|
||||
MATCH_REPLACE_LAMBDA(node, PUnaryOperation(prim::kPrimExp, x) * PUnaryOperation(prim::kPrimExp, y), exp_merge_lambda);
|
||||
// sqrt(x)*sqrt(x) ==> x
|
||||
MATCH_REPLACE_IF(node, PUnaryOperation(prim::kPrimSqrt, x) * PUnaryOperation(prim::kPrimSqrt, y), x,
|
||||
PIsEqual<AnfNodePtr>()(x.GetNode(node), y.GetNode(node)));
|
||||
// sqrt(x)*sqrt(y) ==> sqrt(x*y)
|
||||
MATCH_REPLACE_LAMBDA_IF(node, PUnaryOperation(prim::kPrimSqrt, x) * PUnaryOperation(prim::kPrimSqrt, y),
|
||||
sqrt_merge_lambda, !PIsEqual<AnfNodePtr>()(x.GetNode(node), y.GetNode(node)));
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
AnfNodePtr SimplifyMul2(const AnfNodePtr &node) {
|
||||
if (!IsPrimitiveCNode(node, prim::kPrimMul)) {
|
||||
return nullptr;
|
||||
}
|
||||
PatternNode<AnfNodePtr> x, y;
|
||||
PConstant<AnfNodePtr> const_1(node), const_2(node);
|
||||
|
||||
auto rsqrt_merge_lambda = [&node, &x]() -> AnfNodePtr {
|
||||
auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimReciprocal), x.GetNode(node)}, node);
|
||||
return new_cnode;
|
||||
|
@ -296,18 +318,6 @@ AnfNodePtr SimplifyMul(const AnfNodePtr &node) {
|
|||
auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimMul), x.GetNode(node), new_rhs}, node);
|
||||
return new_cnode;
|
||||
};
|
||||
// (x*C1)*(y*C2) ==> (x*y)*(C1*C2)
|
||||
MATCH_REPLACE_LAMBDA(node, (const_1 * x) * (const_2 * y), const_dup_lambda);
|
||||
// (x*C1)*C2 ==> x*(C1*C2)
|
||||
MATCH_REPLACE_LAMBDA(node, (const_1 * x) * const_2, const_dup_lambda2);
|
||||
// exp(x)*exp(y) ==> exp(x+y)
|
||||
MATCH_REPLACE_LAMBDA(node, PUnaryOperation(prim::kPrimExp, x) * PUnaryOperation(prim::kPrimExp, y), exp_merge_lambda);
|
||||
// sqrt(x)*sqrt(x) ==> x
|
||||
MATCH_REPLACE_IF(node, PUnaryOperation(prim::kPrimSqrt, x) * PUnaryOperation(prim::kPrimSqrt, y), x,
|
||||
PIsEqual<AnfNodePtr>()(x.GetNode(node), y.GetNode(node)));
|
||||
// sqrt(x)*sqrt(y) ==> sqrt(x*y)
|
||||
MATCH_REPLACE_LAMBDA_IF(node, PUnaryOperation(prim::kPrimSqrt, x) * PUnaryOperation(prim::kPrimSqrt, y),
|
||||
sqrt_merge_lambda, !PIsEqual<AnfNodePtr>()(x.GetNode(node), y.GetNode(node)));
|
||||
// rsqrt(x)*rsqrt(x) ==> 1/x
|
||||
MATCH_REPLACE_LAMBDA_IF(node, PUnaryOperation(prim::kPrimRsqrt, x) * PUnaryOperation(prim::kPrimRsqrt, y),
|
||||
rsqrt_merge_lambda, PIsEqual<AnfNodePtr>()(x.GetNode(node), y.GetNode(node)));
|
||||
|
@ -323,12 +333,12 @@ AnfNodePtr SimplifyMul(const AnfNodePtr &node) {
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
AnfNodePtr SimplifyDiv(const AnfNodePtr &node) {
|
||||
AnfNodePtr SimplifyDiv1(const AnfNodePtr &node) {
|
||||
if (!IsPrimitiveCNode(node, prim::kPrimRealDiv)) {
|
||||
return nullptr;
|
||||
}
|
||||
PatternNode<AnfNodePtr> x, y, u, v;
|
||||
PConstant<AnfNodePtr> const_1(node), const_2(node);
|
||||
PConstant<AnfNodePtr> const_1(node);
|
||||
PConstant<AnfNodePtr> const_one(node, false, 1);
|
||||
PConstant<AnfNodePtr> const_one_scalar(node, false, 1, true);
|
||||
|
||||
|
@ -353,6 +363,28 @@ AnfNodePtr SimplifyDiv(const AnfNodePtr &node) {
|
|||
auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimSqrt), x.GetNode(node)}, node);
|
||||
return new_cnode;
|
||||
};
|
||||
// x/1 ==> x
|
||||
MATCH_REPLACE(node, PBinOperation(prim::kPrimScalarDiv, x, const_one_scalar, false), x);
|
||||
MATCH_REPLACE(node, x / const_one, x);
|
||||
// e^x/e^y ==> e^(x-y)
|
||||
MATCH_REPLACE_LAMBDA(node, PUnaryOperation(prim::kPrimExp, x) / PUnaryOperation(prim::kPrimExp, y), div_exp_lambda_1);
|
||||
// x / e^y ==> x * e^(-y)
|
||||
MATCH_REPLACE_LAMBDA(node, x / PUnaryOperation(prim::kPrimExp, y), div_exp_lambda_2);
|
||||
// x / y^const ==> x * y^(-const)
|
||||
MATCH_REPLACE_LAMBDA(node, x / PBinOperation(prim::kPrimPow, y, const_1), div_pow_const);
|
||||
// x / sqrt(x) ==> sqrt(x)
|
||||
MATCH_REPLACE_LAMBDA_IF(node, x / PUnaryOperation(prim::kPrimSqrt, y), div_sqrt_lambda_1,
|
||||
PIsEqual<AnfNodePtr>()(x.GetNode(node), y.GetNode(node)));
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
AnfNodePtr SimplifyDiv2(const AnfNodePtr &node) {
|
||||
if (!IsPrimitiveCNode(node, prim::kPrimRealDiv)) {
|
||||
return nullptr;
|
||||
}
|
||||
PatternNode<AnfNodePtr> x, y, u, v;
|
||||
PConstant<AnfNodePtr> const_1(node);
|
||||
|
||||
auto div_sqrt_lambda_2 = [&node, &x, &y]() -> AnfNodePtr {
|
||||
auto node_rsqrt = NewCNodeWithInfo({NewValueNode(prim::kPrimRsqrt), y.GetNode(node)}, node);
|
||||
auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimMul), x.GetNode(node), node_rsqrt}, node);
|
||||
|
@ -377,6 +409,25 @@ AnfNodePtr SimplifyDiv(const AnfNodePtr &node) {
|
|||
auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimRealDiv), new_lhs, new_rhs}, node);
|
||||
return new_cnode;
|
||||
};
|
||||
// x / sqrt(y) ==> x * rsqrt(y)
|
||||
MATCH_REPLACE_LAMBDA_IF(node, x / PUnaryOperation(prim::kPrimSqrt, y), div_sqrt_lambda_2,
|
||||
!PIsEqual<AnfNodePtr>()(x.GetNode(node), y.GetNode(node)));
|
||||
// x / rsqrt(y) ==> x * sqrt(y)
|
||||
MATCH_REPLACE_LAMBDA(node, x / PUnaryOperation(prim::kPrimRsqrt, y), div_rsqrt_lambda);
|
||||
// // x / const ==> x * (1/const)
|
||||
MATCH_REPLACE_LAMBDA(node, x / const_1, div_const);
|
||||
// (x/y) / (u/v) ==> (x*v) / (y*u)
|
||||
MATCH_REPLACE_LAMBDA(node, (x / y) / (u / v), div_lambda_1);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
AnfNodePtr SimplifyDiv3(const AnfNodePtr &node) {
|
||||
if (!IsPrimitiveCNode(node, prim::kPrimRealDiv)) {
|
||||
return nullptr;
|
||||
}
|
||||
PatternNode<AnfNodePtr> x, y, u, v;
|
||||
PConstant<AnfNodePtr> const_1(node), const_2(node);
|
||||
|
||||
auto div_lambda_2 = [&node, &x, &y, &u]() -> AnfNodePtr {
|
||||
auto new_rhs = NewCNodeWithInfo({NewValueNode(prim::kPrimMul), y.GetNode(node), u.GetNode(node)}, node);
|
||||
auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimRealDiv), x.GetNode(node), new_rhs}, node);
|
||||
|
@ -392,29 +443,8 @@ AnfNodePtr SimplifyDiv(const AnfNodePtr &node) {
|
|||
auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimRealDiv), x.GetNode(node), new_rhs}, node);
|
||||
return new_cnode;
|
||||
};
|
||||
// x/1 ==> x
|
||||
MATCH_REPLACE(node, PBinOperation(prim::kPrimScalarDiv, x, const_one_scalar, false), x);
|
||||
MATCH_REPLACE(node, x / const_one, x);
|
||||
// e^x/e^y ==> e^(x-y)
|
||||
MATCH_REPLACE_LAMBDA(node, PUnaryOperation(prim::kPrimExp, x) / PUnaryOperation(prim::kPrimExp, y), div_exp_lambda_1);
|
||||
// x / e^y ==> x * e^(-y)
|
||||
MATCH_REPLACE_LAMBDA(node, x / PUnaryOperation(prim::kPrimExp, y), div_exp_lambda_2);
|
||||
// x / y^const ==> x * y^(-const)
|
||||
MATCH_REPLACE_LAMBDA(node, x / PBinOperation(prim::kPrimPow, y, const_1), div_pow_const);
|
||||
// x / sqrt(x) ==> sqrt(x)
|
||||
MATCH_REPLACE_LAMBDA_IF(node, x / PUnaryOperation(prim::kPrimSqrt, y), div_sqrt_lambda_1,
|
||||
PIsEqual<AnfNodePtr>()(x.GetNode(node), y.GetNode(node)));
|
||||
// x / sqrt(y) ==> x * rsqrt(y)
|
||||
MATCH_REPLACE_LAMBDA_IF(node, x / PUnaryOperation(prim::kPrimSqrt, y), div_sqrt_lambda_2,
|
||||
!PIsEqual<AnfNodePtr>()(x.GetNode(node), y.GetNode(node)));
|
||||
// x / rsqrt(y) ==> x * sqrt(y)
|
||||
MATCH_REPLACE_LAMBDA(node, x / PUnaryOperation(prim::kPrimRsqrt, y), div_rsqrt_lambda);
|
||||
// Neg(x) / const = x / (-const)
|
||||
MATCH_REPLACE_LAMBDA(node, PUnaryOperation(prim::kPrimNeg, x) / const_1, neg_div_lambda);
|
||||
// // x / const ==> x * (1/const)
|
||||
MATCH_REPLACE_LAMBDA(node, x / const_1, div_const);
|
||||
// (x/y) / (u/v) ==> (x*v) / (y*u)
|
||||
MATCH_REPLACE_LAMBDA(node, (x / y) / (u / v), div_lambda_1);
|
||||
// (x/y) / u ==> x / (y*u)
|
||||
MATCH_REPLACE_LAMBDA(node, (x / y) / u, div_lambda_2);
|
||||
// x / (u/v) ==> (x*v) / u
|
||||
|
@ -556,52 +586,22 @@ std::vector<std::pair<int64_t, int64_t>> GetUnmodifiedDim(const ShapeVector &a,
|
|||
return unmodified;
|
||||
}
|
||||
|
||||
AnfNodePtr SimplifyReduce(const AnfNodePtr &node) {
|
||||
if (!IsPrimitiveCNode(node, prim::kPrimReduceMax) && !IsPrimitiveCNode(node, prim::kPrimReduceMin) &&
|
||||
!IsPrimitiveCNode(node, prim::kPrimReduceSum)) {
|
||||
std::list<PrimitivePtr> RedOps = {prim::kPrimReduceSum, prim::kPrimReduceMax, prim::kPrimReduceMin};
|
||||
|
||||
bool IsRedOps(const AnfNodePtr &node) {
|
||||
if (std::any_of(RedOps.begin(), RedOps.end(),
|
||||
[&node](const PrimitivePtr &ops) { return IsPrimitiveCNode(node, ops); })) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// Reduce(Reshape(A)) = Reduce(A) if reduce dimensions is not in reshape dimensions
|
||||
AnfNodePtr SimplifyReduce1(const AnfNodePtr &node) {
|
||||
if (!IsRedOps(node)) {
|
||||
return nullptr;
|
||||
}
|
||||
PatternNode<AnfNodePtr> x;
|
||||
auto trans_reduce_lambda = [&node, &x](PrimitivePtr &operation) -> AnfNodePtr {
|
||||
auto shape = GetNodeShape(node);
|
||||
if (shape.size() != 0 && shape.size() != 1) {
|
||||
return nullptr;
|
||||
} else {
|
||||
auto tmp_node = node->cast<CNodePtr>();
|
||||
auto transpose_node = tmp_node->input(1);
|
||||
auto transpose_dimensions =
|
||||
GetValue<std::vector<int64_t>>(AnfAlgo::GetNodeAttr<ValuePtr>(transpose_node, "perm"));
|
||||
ShapeVector new_dimensions;
|
||||
auto reduce_dimensions = TransAxisValueToVector(AnfAlgo::GetNodeAttr<ValuePtr>(tmp_node, "axis"));
|
||||
std::transform(reduce_dimensions.begin(), reduce_dimensions.end(), std::back_inserter(new_dimensions),
|
||||
[&transpose_dimensions](const int64_t &dim) { return transpose_dimensions[dim]; });
|
||||
std::sort(new_dimensions.begin(), new_dimensions.end());
|
||||
auto new_cnode = NewCNodeWithInfo({NewValueNode(operation), x.GetNode(node)}, node);
|
||||
AnfAlgo::SetNodeAttr("axis", MakeValue(new_dimensions), new_cnode);
|
||||
AnfAlgo::CopyNodeAttr("keep_dims", node, new_cnode);
|
||||
return new_cnode;
|
||||
}
|
||||
};
|
||||
auto reduce_reduce_lambda = [&node, &x](PrimitivePtr &operation) -> AnfNodePtr {
|
||||
auto tmp_node = node->cast<CNodePtr>();
|
||||
auto arg_node = tmp_node->input(1);
|
||||
auto arg_dimensions = TransAxisValueToVector(AnfAlgo::GetNodeAttr<ValuePtr>(arg_node, "axis"));
|
||||
auto reduce_dimensions = TransAxisValueToVector(AnfAlgo::GetNodeAttr<ValuePtr>(tmp_node, "axis"));
|
||||
ShapeVector new_dimensions;
|
||||
for (size_t i = 0; i < arg_dimensions.size(); ++i) {
|
||||
for (size_t j = 0; j < reduce_dimensions.size(); ++j) {
|
||||
if (reduce_dimensions[j] >= arg_dimensions[i]) {
|
||||
++reduce_dimensions[j];
|
||||
}
|
||||
}
|
||||
}
|
||||
std::merge(arg_dimensions.begin(), arg_dimensions.end(), reduce_dimensions.begin(), reduce_dimensions.end(),
|
||||
std::back_inserter(new_dimensions));
|
||||
auto new_cnode = NewCNodeWithInfo({NewValueNode(operation), x.GetNode(node)}, node);
|
||||
AnfAlgo::SetNodeAttr("axis", MakeValue(new_dimensions), new_cnode);
|
||||
AnfAlgo::CopyNodeAttr("keep_dims", node, new_cnode);
|
||||
return new_cnode;
|
||||
};
|
||||
auto reshape_reduce_lambda = [&node, &x](PrimitivePtr &operation) -> AnfNodePtr {
|
||||
auto tmp_node = node->cast<CNodePtr>();
|
||||
auto arg_node = tmp_node->input(1);
|
||||
|
@ -643,6 +643,37 @@ AnfNodePtr SimplifyReduce(const AnfNodePtr &node) {
|
|||
}
|
||||
return nullptr;
|
||||
};
|
||||
for (auto op : RedOps) {
|
||||
MATCH_REPLACE_LAMBDA_FLAG(node, PPrimitive(op, PPrimitive(prim::kPrimReshape, x)), reshape_reduce_lambda, op);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
AnfNodePtr SimplifyReduce2(const AnfNodePtr &node) {
|
||||
if (!IsRedOps(node)) {
|
||||
return nullptr;
|
||||
}
|
||||
PatternNode<AnfNodePtr> x;
|
||||
auto reduce_reduce_lambda = [&node, &x](PrimitivePtr &operation) -> AnfNodePtr {
|
||||
auto tmp_node = node->cast<CNodePtr>();
|
||||
auto arg_node = tmp_node->input(1);
|
||||
auto arg_dimensions = TransAxisValueToVector(AnfAlgo::GetNodeAttr<ValuePtr>(arg_node, "axis"));
|
||||
auto reduce_dimensions = TransAxisValueToVector(AnfAlgo::GetNodeAttr<ValuePtr>(tmp_node, "axis"));
|
||||
ShapeVector new_dimensions;
|
||||
for (size_t i = 0; i < arg_dimensions.size(); ++i) {
|
||||
for (size_t j = 0; j < reduce_dimensions.size(); ++j) {
|
||||
if (reduce_dimensions[j] >= arg_dimensions[i]) {
|
||||
++reduce_dimensions[j];
|
||||
}
|
||||
}
|
||||
}
|
||||
std::merge(arg_dimensions.begin(), arg_dimensions.end(), reduce_dimensions.begin(), reduce_dimensions.end(),
|
||||
std::back_inserter(new_dimensions));
|
||||
auto new_cnode = NewCNodeWithInfo({NewValueNode(operation), x.GetNode(node)}, node);
|
||||
AnfAlgo::SetNodeAttr("axis", MakeValue(new_dimensions), new_cnode);
|
||||
AnfAlgo::CopyNodeAttr("keep_dims", node, new_cnode);
|
||||
return new_cnode;
|
||||
};
|
||||
auto neg_reducesum_lambda = [&node, &x]() -> AnfNodePtr {
|
||||
auto arg_node = NewCNodeWithInfo({NewValueNode(prim::kPrimReduceSum), x.GetNode(node)}, node);
|
||||
AnfAlgo::CopyNodeAttr("axis", node, arg_node);
|
||||
|
@ -650,16 +681,9 @@ AnfNodePtr SimplifyReduce(const AnfNodePtr &node) {
|
|||
auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimNeg), arg_node}, node);
|
||||
return new_cnode;
|
||||
};
|
||||
std::list<PrimitivePtr> ReduceOperations = {prim::kPrimReduceSum, prim::kPrimReduceMax, prim::kPrimReduceMin};
|
||||
for (auto operation : ReduceOperations) {
|
||||
// Reduce(Transpose(A)) = Reduce(A) if result is a scalar or vector
|
||||
MATCH_REPLACE_LAMBDA_FLAG(node, PPrimitive(operation, PPrimitive(prim::kPrimTranspose, x)), trans_reduce_lambda,
|
||||
operation);
|
||||
for (auto operation : RedOps) {
|
||||
// Reduce(Reduce(A)) = Reduce(A)
|
||||
MATCH_REPLACE_LAMBDA_FLAG(node, PPrimitive(operation, PPrimitive(operation, x)), reduce_reduce_lambda, operation);
|
||||
// Reduce(Reshape(A)) = Reduce(A) if reduce dimensions is not in reshape dimensions
|
||||
MATCH_REPLACE_LAMBDA_FLAG(node, PPrimitive(operation, PPrimitive(prim::kPrimReshape, x)), reshape_reduce_lambda,
|
||||
operation);
|
||||
}
|
||||
// ReduceSum(Neg(x)) = Neg(ReduceSum(x))
|
||||
MATCH_REPLACE_LAMBDA(node, PPrimitive(prim::kPrimReduceSum, PUnaryOperation(prim::kPrimNeg, x)),
|
||||
|
@ -667,8 +691,41 @@ AnfNodePtr SimplifyReduce(const AnfNodePtr &node) {
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
// Reduce(Transpose(A)) = Reduce(A) if result is a scalar or vector
|
||||
AnfNodePtr SimplifyReduce3(const AnfNodePtr &node) {
|
||||
if (!IsRedOps(node)) {
|
||||
return nullptr;
|
||||
}
|
||||
PatternNode<AnfNodePtr> x;
|
||||
auto trans_reduce_lambda = [&node, &x](PrimitivePtr &operation) -> AnfNodePtr {
|
||||
auto shape = GetNodeShape(node);
|
||||
if (shape.size() != 0 && shape.size() != 1) {
|
||||
return nullptr;
|
||||
} else {
|
||||
auto tmp_node = node->cast<CNodePtr>();
|
||||
auto transpose_node = tmp_node->input(1);
|
||||
auto transpose_dimensions =
|
||||
GetValue<std::vector<int64_t>>(AnfAlgo::GetNodeAttr<ValuePtr>(transpose_node, "perm"));
|
||||
ShapeVector new_dimensions;
|
||||
auto reduce_dimensions = TransAxisValueToVector(AnfAlgo::GetNodeAttr<ValuePtr>(tmp_node, "axis"));
|
||||
std::transform(reduce_dimensions.begin(), reduce_dimensions.end(), std::back_inserter(new_dimensions),
|
||||
[&transpose_dimensions](const int64_t &dim) { return transpose_dimensions[dim]; });
|
||||
std::sort(new_dimensions.begin(), new_dimensions.end());
|
||||
auto new_cnode = NewCNodeWithInfo({NewValueNode(operation), x.GetNode(node)}, node);
|
||||
AnfAlgo::SetNodeAttr("axis", MakeValue(new_dimensions), new_cnode);
|
||||
AnfAlgo::CopyNodeAttr("keep_dims", node, new_cnode);
|
||||
return new_cnode;
|
||||
}
|
||||
};
|
||||
for (auto operation : RedOps) {
|
||||
MATCH_REPLACE_LAMBDA_FLAG(node, PPrimitive(operation, PPrimitive(prim::kPrimTranspose, x)), trans_reduce_lambda,
|
||||
operation);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
AnfNodePtr TrySimplify(const AnfNodePtr &node) {
|
||||
std::list<std::function<AnfNodePtr(AnfNodePtr)>> SimplifyFuncList = {SimplifyReduce};
|
||||
std::list<std::function<AnfNodePtr(const AnfNodePtr &)>> SimplifyFuncList = {SimplifyReduce1};
|
||||
for (auto f : SimplifyFuncList) {
|
||||
auto ret = f(node);
|
||||
if (ret != nullptr) {
|
||||
|
|
Loading…
Reference in New Issue