!48871 Add Some math ops for bpropexpander 2.0

Merge pull request !48871 from ZengZitao/bexp_op_2
This commit is contained in:
i-robot 2023-02-16 06:50:42 +00:00 committed by Gitee
commit e6ea014082
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
3 changed files with 14 additions and 2 deletions

View File

@ -141,6 +141,7 @@ void RegMathBpropExpanderOps2() {
REGISTER_EXPANDER_BPROP_IMPL(NPUClearFloatStatus);
REGISTER_EXPANDER_BPROP_IMPL(ScalarCast);
REGISTER_EXPANDER_BPROP_IMPL(Logit);
REGISTER_EXPANDER_BPROP_IMPL(LuUnpack);
REGISTER_EXPANDER_BPROP_IMPL(Floor);
REGISTER_EXPANDER_BPROP_IMPL(Ceil);
REGISTER_EXPANDER_BPROP_IMPL(Square);
@ -148,6 +149,7 @@ void RegMathBpropExpanderOps2() {
REGISTER_EXPANDER_BPROP_IMPL(Trunc);
REGISTER_EXPANDER_BPROP_IMPL(Ger);
REGISTER_EXPANDER_BPROP_IMPL(Cross);
REGISTER_EXPANDER_BPROP_IMPL(Median);
REGISTER_EXPANDER_BPROP_IMPL(Erfinv);
REGISTER_EXPANDER_BPROP_IMPL(Bernoulli);
REGISTER_EXPANDER_BPROP_IMPL(ComplexAbs);
@ -155,6 +157,11 @@ void RegMathBpropExpanderOps2() {
REGISTER_EXPANDER_BPROP_IMPL(Imag);
REGISTER_EXPANDER_BPROP_IMPL(Complex);
REGISTER_EXPANDER_BPROP_IMPL(MinimumGrad);
REGISTER_EXPANDER_BPROP_IMPL(AddN);
REGISTER_EXPANDER_BPROP_IMPL(Sinc);
REGISTER_EXPANDER_BPROP_IMPL(MatrixPower);
REGISTER_EXPANDER_BPROP_IMPL(TridiagonalMatMul);
REGISTER_EXPANDER_BPROP_IMPL(LpNorm);
}
void RegNNBpropExpanderOps1() {

View File

@ -33,8 +33,12 @@ NodePtrList CompareBpropExpander(const BpropIRBuilder *ib) {
NodePtrList AddnGradFunc(const BpropIRBuilder *ib) {
auto dout = ib->GetInput(kIndex2);
auto n = LongToSize(ib->GetAttr<int64_t>("n"));
NodePtrList result(n, dout);
auto x_abs = ib->GetInput(kIndex0)->get()->abstract();
auto x_len = x_abs->cast<abstract::AbstractSequencePtr>()->elements().size();
NodePtrList result(x_len, dout);
if (x_abs->isa<abstract::AbstractList>()) {
return {ib->MakeList(result)};
}
return {ib->MakeTuple(result)};
}

View File

@ -46,6 +46,7 @@ class MS_CORE_API Emitter {
NodePtr EmitValue(const ValuePtr &value) const;
NodePtr MakeTuple(const NodePtrList &inputs) const { return Emit(prim::kMakeTuple, inputs); }
NodePtr MakeList(const NodePtrList &inputs) const { return Emit("make_list", inputs); }
NodePtr TupleGetItem(const NodePtr &input, size_t i) const {
return Emit(prim::kTupleGetItem, {input, Value(static_cast<int64_t>(i))});
}