forked from mindspore-Ecosystem/mindspore
!48871 Add Some math ops for bpropexpander 2.0
Merge pull request !48871 from ZengZitao/bexp_op_2
This commit is contained in:
commit
e6ea014082
|
@ -141,6 +141,7 @@ void RegMathBpropExpanderOps2() {
|
|||
REGISTER_EXPANDER_BPROP_IMPL(NPUClearFloatStatus);
|
||||
REGISTER_EXPANDER_BPROP_IMPL(ScalarCast);
|
||||
REGISTER_EXPANDER_BPROP_IMPL(Logit);
|
||||
REGISTER_EXPANDER_BPROP_IMPL(LuUnpack);
|
||||
REGISTER_EXPANDER_BPROP_IMPL(Floor);
|
||||
REGISTER_EXPANDER_BPROP_IMPL(Ceil);
|
||||
REGISTER_EXPANDER_BPROP_IMPL(Square);
|
||||
|
@ -148,6 +149,7 @@ void RegMathBpropExpanderOps2() {
|
|||
REGISTER_EXPANDER_BPROP_IMPL(Trunc);
|
||||
REGISTER_EXPANDER_BPROP_IMPL(Ger);
|
||||
REGISTER_EXPANDER_BPROP_IMPL(Cross);
|
||||
REGISTER_EXPANDER_BPROP_IMPL(Median);
|
||||
REGISTER_EXPANDER_BPROP_IMPL(Erfinv);
|
||||
REGISTER_EXPANDER_BPROP_IMPL(Bernoulli);
|
||||
REGISTER_EXPANDER_BPROP_IMPL(ComplexAbs);
|
||||
|
@ -155,6 +157,11 @@ void RegMathBpropExpanderOps2() {
|
|||
REGISTER_EXPANDER_BPROP_IMPL(Imag);
|
||||
REGISTER_EXPANDER_BPROP_IMPL(Complex);
|
||||
REGISTER_EXPANDER_BPROP_IMPL(MinimumGrad);
|
||||
REGISTER_EXPANDER_BPROP_IMPL(AddN);
|
||||
REGISTER_EXPANDER_BPROP_IMPL(Sinc);
|
||||
REGISTER_EXPANDER_BPROP_IMPL(MatrixPower);
|
||||
REGISTER_EXPANDER_BPROP_IMPL(TridiagonalMatMul);
|
||||
REGISTER_EXPANDER_BPROP_IMPL(LpNorm);
|
||||
}
|
||||
|
||||
void RegNNBpropExpanderOps1() {
|
||||
|
|
|
@ -33,8 +33,12 @@ NodePtrList CompareBpropExpander(const BpropIRBuilder *ib) {
|
|||
|
||||
NodePtrList AddnGradFunc(const BpropIRBuilder *ib) {
|
||||
auto dout = ib->GetInput(kIndex2);
|
||||
auto n = LongToSize(ib->GetAttr<int64_t>("n"));
|
||||
NodePtrList result(n, dout);
|
||||
auto x_abs = ib->GetInput(kIndex0)->get()->abstract();
|
||||
auto x_len = x_abs->cast<abstract::AbstractSequencePtr>()->elements().size();
|
||||
NodePtrList result(x_len, dout);
|
||||
if (x_abs->isa<abstract::AbstractList>()) {
|
||||
return {ib->MakeList(result)};
|
||||
}
|
||||
return {ib->MakeTuple(result)};
|
||||
}
|
||||
|
||||
|
|
|
@ -46,6 +46,7 @@ class MS_CORE_API Emitter {
|
|||
NodePtr EmitValue(const ValuePtr &value) const;
|
||||
|
||||
NodePtr MakeTuple(const NodePtrList &inputs) const { return Emit(prim::kMakeTuple, inputs); }
|
||||
NodePtr MakeList(const NodePtrList &inputs) const { return Emit("make_list", inputs); }
|
||||
NodePtr TupleGetItem(const NodePtr &input, size_t i) const {
|
||||
return Emit(prim::kTupleGetItem, {input, Value(static_cast<int64_t>(i))});
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue