forked from mindspore-Ecosystem/mindspore
!47470 fix unused inputs in grad_array & grad_sparse
Merge pull request !47470 from r1chardf1d0/master
This commit is contained in:
commit
6c9fd1bd63
|
@ -153,7 +153,7 @@ REG_BPROP_BUILDER("GatherD").SetUnusedInputs({i3}).SetBody(BODYFUNC(ib) {
|
|||
return {dx, ib->ZerosLike(dim), ib->ZerosLike(index)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("GatherDGrad").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {
|
||||
REG_BPROP_BUILDER("GatherDGrad").SetUnusedInputs({i1, i2}).SetBody(BODYFUNC(ib) {
|
||||
auto dim = GetValue<int64_t>(ib->GetAttr("dim"));
|
||||
auto x_shp = GetValue<ShapeVector>(ib->GetAttr("shape"));
|
||||
auto index = ib->GetInput(kIndex0);
|
||||
|
@ -187,7 +187,7 @@ REG_BPROP_BUILDER("GatherDGrad").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {
|
|||
return {ib->ZerosLike(index), dx};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("GatherDGradV2").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {
|
||||
REG_BPROP_BUILDER("GatherDGradV2").SetUnusedInputs({i1, i2}).SetBody(BODYFUNC(ib) {
|
||||
auto dim = GetValue<int64_t>(ib->GetAttr("dim"));
|
||||
auto index = ib->GetInput(kIndex0);
|
||||
auto x = ib->GetInput(kIndex1);
|
||||
|
@ -225,7 +225,7 @@ REG_BPROP_BUILDER("GatherDGradV2").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {
|
|||
return {ib->ZerosLike(index), dx};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("SparseGatherV2").SetUnusedInputs({i3}).SetBody(BODYFUNC(ib) {
|
||||
REG_BPROP_BUILDER("SparseGatherV2").SetUnusedInputs({i0, i3}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto indices = ib->GetInput(kIndex1);
|
||||
auto axis = ib->GetInput(kIndex2);
|
||||
|
@ -323,24 +323,20 @@ REG_BPROP_BUILDER("Identity").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) {
|
|||
return {dout};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("Range").SetUnusedInputs({i3, i4}).SetBody(BODYFUNC(ib) {
|
||||
REG_BPROP_BUILDER("Range").SetUnusedInputs({i0, i1, i2, i3, i4}).SetBody(BODYFUNC(ib) {
|
||||
auto start = ib->GetInput(kIndex0);
|
||||
auto limit = ib->GetInput(kIndex1);
|
||||
auto delta = ib->GetInput(kIndex2);
|
||||
return {ib->ZerosLike(start), ib->ZerosLike(limit), ib->ZerosLike(delta)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("Pack").SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto out = ib->GetInput(kIndex1);
|
||||
REG_BPROP_BUILDER("Pack").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) {
|
||||
auto dout = ib->GetInput(kIndex2);
|
||||
auto ret = ib->Emit("Unstack", {dout}, {{"num", ib->GetAttr("num")}, {"axis", ib->GetAttr("axis")}});
|
||||
return {ret};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("Stack").SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto out = ib->GetInput(kIndex1);
|
||||
REG_BPROP_BUILDER("Stack").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) {
|
||||
auto dout = ib->GetInput(kIndex2);
|
||||
auto ret = ib->Emit("Unstack", {dout}, {{"num", ib->GetAttr("num")}, {"axis", ib->GetAttr("axis")}});
|
||||
return {ret};
|
||||
|
@ -352,10 +348,9 @@ REG_BPROP_BUILDER("ReverseV2").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) {
|
|||
return {dx};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("Unstack").SetUnusedInputs({i0}).SetBody(BODYFUNC(ib) {
|
||||
auto out = ib->GetInput(kIndex1);
|
||||
REG_BPROP_BUILDER("Unstack").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) {
|
||||
auto dout = ib->GetInput(kIndex2);
|
||||
out = ib->Emit("Stack", {dout}, {{"axis", ib->GetAttr("axis")}});
|
||||
auto out = ib->Emit("Stack", {dout}, {{"axis", ib->GetAttr("axis")}});
|
||||
return {out};
|
||||
});
|
||||
|
||||
|
@ -378,7 +373,7 @@ REG_BPROP_BUILDER("StridedSlice").SetUnusedInputs({i0, i4}).SetBody(BODYFUNC(ib)
|
|||
return {dx, dbegin, dend, dstrides};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("StridedSliceGrad").SetUnusedInputs({i0, i5}).SetBody(BODYFUNC(ib) {
|
||||
REG_BPROP_BUILDER("StridedSliceGrad").SetUnusedInputs({i0, i1, i5}).SetBody(BODYFUNC(ib) {
|
||||
auto shapex = ib->GetInput(kIndex1);
|
||||
auto begin = ib->GetInput(kIndex2);
|
||||
auto end = ib->GetInput(kIndex3);
|
||||
|
@ -393,7 +388,7 @@ REG_BPROP_BUILDER("StridedSliceGrad").SetUnusedInputs({i0, i5}).SetBody(BODYFUNC
|
|||
ib->ZerosLike(shapex), ib->ZerosLike(begin), ib->ZerosLike(end), ib->ZerosLike(strides)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("Eye").SetUnusedInputs({i3, i4}).SetBody(BODYFUNC(ib) {
|
||||
REG_BPROP_BUILDER("Eye").SetUnusedInputs({i0, i1, i3, i4}).SetBody(BODYFUNC(ib) {
|
||||
auto n = ib->GetInput(kIndex0);
|
||||
auto m = ib->GetInput(kIndex1);
|
||||
auto t = ib->GetInput(kIndex2);
|
||||
|
@ -408,17 +403,17 @@ REG_BPROP_BUILDER("Select").SetUnusedInputs({i3}).SetBody(BODYFUNC(ib) {
|
|||
return {ib->ZerosLike(cond), ib->Select(cond, dout, ib->ZerosLike(x)), ib->Select(cond, ib->ZerosLike(y), dout)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("OnesLike").SetUnusedInputs({i1, i2}).SetBody(BODYFUNC(ib) {
|
||||
REG_BPROP_BUILDER("OnesLike").SetUnusedInputs({i0, i1, i2}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
return {ib->ZerosLike(x)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("ZerosLike").SetUnusedInputs({i1, i2}).SetBody(BODYFUNC(ib) {
|
||||
REG_BPROP_BUILDER("ZerosLike").SetUnusedInputs({i0, i1, i2}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
return {ib->ZerosLike(x)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("ResizeNearestNeighbor").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib) {
|
||||
REG_BPROP_BUILDER("ResizeNearestNeighbor").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto dout = ib->GetInput(kIndex2);
|
||||
auto x_shape = ib->GetShape(x);
|
||||
|
@ -431,7 +426,7 @@ REG_BPROP_BUILDER("ResizeNearestNeighbor").SetUnusedInputs({i1}).SetBody(BODYFUN
|
|||
return {out};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("GatherNd").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {
|
||||
REG_BPROP_BUILDER("GatherNd").SetUnusedInputs({i0, i2}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto indices = ib->GetInput(kIndex1);
|
||||
auto dout = ib->GetInput(kIndex3);
|
||||
|
@ -439,7 +434,7 @@ REG_BPROP_BUILDER("GatherNd").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {
|
|||
return {ib->Emit("ScatterNd", {indices, dout, shp}), ib->ZerosLike(indices)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("ScatterNd").SetUnusedInputs({i1, i3}).SetBody(BODYFUNC(ib) {
|
||||
REG_BPROP_BUILDER("ScatterNd").SetUnusedInputs({i1, i2, i3}).SetBody(BODYFUNC(ib) {
|
||||
auto indices = ib->GetInput(kIndex0);
|
||||
auto shape = ib->GetInput(kIndex2);
|
||||
auto dout = ib->GetInput(kIndex4);
|
||||
|
@ -467,7 +462,7 @@ REG_BPROP_BUILDER("TensorScatterUpdate").SetUnusedInputs({i0, i3}).SetBody(BODYF
|
|||
return {x_grad, ib->ZerosLike(indices), update_grad};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("Flatten").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib) {
|
||||
REG_BPROP_BUILDER("Flatten").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto dout = ib->GetInput(kIndex2);
|
||||
auto dx = ib->Reshape(dout, ib->GetShape(x));
|
||||
|
@ -482,7 +477,7 @@ REG_BPROP_BUILDER("Reshape").SetUnusedInputs({i0, i1, i2}).SetBody(BODYFUNC(ib)
|
|||
return {ib->Reshape(dout, shapex), ib->ZerosLike(shp)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("NonZero").SetUnusedInputs({i1, i2}).SetBody(BODYFUNC(ib) {
|
||||
REG_BPROP_BUILDER("NonZero").SetUnusedInputs({i0, i1, i2}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
return {ib->ZerosLike(x)};
|
||||
});
|
||||
|
@ -511,12 +506,12 @@ REG_BPROP_BUILDER("BatchMatMul").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {
|
|||
return BinopGradCommonWithShift(ib, x, w, dx, dw, 2);
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("Argmax").SetUnusedInputs({i1, i2}).SetBody(BODYFUNC(ib) {
|
||||
REG_BPROP_BUILDER("Argmax").SetUnusedInputs({i0, i1, i2}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
return {ib->ZerosLike(x)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("Argmin").SetUnusedInputs({i1, i2}).SetBody(BODYFUNC(ib) {
|
||||
REG_BPROP_BUILDER("Argmin").SetUnusedInputs({i0, i1, i2}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
return {ib->ZerosLike(x)};
|
||||
});
|
||||
|
@ -655,7 +650,7 @@ REG_BPROP_BUILDER("TensorScatterMul").SetUnusedInputs({i3}).SetBody(BODYFUNC(ib)
|
|||
REG_BPROP_BUILDER("TensorScatterMax").SetBody(TensorScatterPossibleReplacement);
|
||||
REG_BPROP_BUILDER("TensorScatterMin").SetBody(TensorScatterPossibleReplacement);
|
||||
|
||||
REG_BPROP_BUILDER("IndexFill").SetUnusedInputs({i4}).SetBody(BODYFUNC(ib) {
|
||||
REG_BPROP_BUILDER("IndexFill").SetUnusedInputs({i0, i4}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto dim = ib->GetInput(kIndex1);
|
||||
auto indices = ib->GetInput(kIndex2);
|
||||
|
@ -767,7 +762,7 @@ REG_BPROP_BUILDER("BatchToSpaceND").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(i
|
|||
return {dx};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("BroadcastTo").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib) {
|
||||
REG_BPROP_BUILDER("BroadcastTo").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto dout = ib->GetInput(kIndex2);
|
||||
auto broadcast_shape = ib->GetAttr<ShapeVector>("shape");
|
||||
|
@ -820,13 +815,13 @@ REG_BPROP_BUILDER("ScatterUpdate").SetUnusedInputs({i0, i2, i3}).SetBody(BODYFUN
|
|||
return {dout, ib->ZerosLike(indices), ib->Emit("Gather", {dout, indices, ib->Tensor(0, kInt64)})};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("Fills").SetUnusedInputs({i2, i3}).SetBody(BODYFUNC(ib) {
|
||||
REG_BPROP_BUILDER("Fills").SetUnusedInputs({i0, i1, i2, i3}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto value = ib->GetInput(kIndex1);
|
||||
return {ib->ZerosLike(x), ib->ZerosLike(value)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("Cast").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {
|
||||
REG_BPROP_BUILDER("Cast").SetUnusedInputs({i0, i1, i2}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto t = ib->GetInput(kIndex1);
|
||||
auto dout = ib->GetInput(kIndex3);
|
||||
|
@ -835,7 +830,7 @@ REG_BPROP_BUILDER("Cast").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {
|
|||
return {dx, ib->ZerosLike(t)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("ExpandDims").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {
|
||||
REG_BPROP_BUILDER("ExpandDims").SetUnusedInputs({i0, i1, i2}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto axis = ib->GetInput(kIndex1);
|
||||
auto dout = ib->GetInput(kIndex3);
|
||||
|
@ -843,14 +838,14 @@ REG_BPROP_BUILDER("ExpandDims").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {
|
|||
return {ib->Reshape(dout, shapex), ib->ZerosLike(axis)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("Squeeze").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib) {
|
||||
REG_BPROP_BUILDER("Squeeze").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto dout = ib->GetInput(kIndex2);
|
||||
auto shapex = ib->GetShape(x);
|
||||
return {ib->Reshape(dout, shapex)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("Padding").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib) {
|
||||
REG_BPROP_BUILDER("Padding").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto dout = ib->GetInput(kIndex2);
|
||||
auto shp = ib->GetShape(x);
|
||||
|
@ -886,7 +881,7 @@ REG_BPROP_BUILDER("Split").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) {
|
|||
return {dx};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("Tile").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {
|
||||
REG_BPROP_BUILDER("Tile").SetUnusedInputs({i0, i2}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto input_multiples = ib->GetInput(kIndex1);
|
||||
auto dout = ib->GetInput(kIndex3);
|
||||
|
@ -916,14 +911,14 @@ REG_BPROP_BUILDER("Tile").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {
|
|||
REG_BPROP_BUILDER("Gather").SetUnusedInputs({i3}).SetBody(BinopGatherCommon);
|
||||
REG_BPROP_BUILDER("GatherV2").SetUnusedInputs({i3}).SetBody(BinopGatherCommon);
|
||||
|
||||
REG_BPROP_BUILDER("Fill").SetUnusedInputs({i3, i4}).SetBody(BODYFUNC(ib) {
|
||||
REG_BPROP_BUILDER("Fill").SetUnusedInputs({i0, i1, i2, i3, i4}).SetBody(BODYFUNC(ib) {
|
||||
auto dtype = ib->GetInput(kIndex0);
|
||||
auto dims = ib->GetInput(kIndex1);
|
||||
auto x = ib->GetInput(kIndex2);
|
||||
return {ib->ZerosLike(dtype), ib->ZerosLike(dims), ib->ZerosLike(x)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("MatrixDiagV3").SetUnusedInputs({i0, i5}).SetBody(BODYFUNC(ib) {
|
||||
REG_BPROP_BUILDER("MatrixDiagV3").SetUnusedInputs({i0, i2, i3, i4, i5}).SetBody(BODYFUNC(ib) {
|
||||
auto k = ib->GetInput(kIndex1);
|
||||
auto num_rows = ib->GetInput(kIndex2);
|
||||
auto num_cols = ib->GetInput(kIndex3);
|
||||
|
@ -934,7 +929,7 @@ REG_BPROP_BUILDER("MatrixDiagV3").SetUnusedInputs({i0, i5}).SetBody(BODYFUNC(ib)
|
|||
return {part, ib->ZerosLike(k), ib->ZerosLike(num_rows), ib->ZerosLike(num_cols), ib->ZerosLike(padding_value)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("MatrixDiagPartV3").SetUnusedInputs({i3}).SetBody(BODYFUNC(ib) {
|
||||
REG_BPROP_BUILDER("MatrixDiagPartV3").SetUnusedInputs({i0, i2, i3}).SetBody(BODYFUNC(ib) {
|
||||
auto align = ib->GetAttr("align");
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto k = ib->GetInput(kIndex1);
|
||||
|
@ -949,7 +944,7 @@ REG_BPROP_BUILDER("MatrixDiagPartV3").SetUnusedInputs({i3}).SetBody(BODYFUNC(ib)
|
|||
return {diag, ib->ZerosLike(k), ib->ZerosLike(padding_value)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("MatrixSetDiagV3").SetUnusedInputs({i0, i3}).SetBody(BODYFUNC(ib) {
|
||||
REG_BPROP_BUILDER("MatrixSetDiagV3").SetUnusedInputs({i0, i1, i3}).SetBody(BODYFUNC(ib) {
|
||||
auto align = ib->GetAttr("align");
|
||||
auto diagonal = ib->GetInput(kIndex1);
|
||||
auto k = ib->GetInput(kIndex2);
|
||||
|
@ -964,37 +959,37 @@ REG_BPROP_BUILDER("MatrixSetDiagV3").SetUnusedInputs({i0, i3}).SetBody(BODYFUNC(
|
|||
return {x_cal, diagonal_cal, ib->ZerosLike(k)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("LogNormalReverse").SetUnusedInputs({i1, i2}).SetBody(BODYFUNC(ib) {
|
||||
REG_BPROP_BUILDER("LogNormalReverse").SetUnusedInputs({i0, i1, i2}).SetBody(BODYFUNC(ib) {
|
||||
auto input_data = ib->GetInput(kIndex0);
|
||||
return {ib->ZerosLike(input_data)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("Shape").SetUnusedInputs({i1, i2}).SetBody(BODYFUNC(ib) {
|
||||
REG_BPROP_BUILDER("Shape").SetUnusedInputs({i0, i1, i2}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
return {ib->ZerosLike(x)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("Rank").SetUnusedInputs({i1, i2}).SetBody(BODYFUNC(ib) {
|
||||
REG_BPROP_BUILDER("Rank").SetUnusedInputs({i0, i1, i2}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
return {ib->ZerosLike(x)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("DynamicShape").SetUnusedInputs({i1, i2}).SetBody(BODYFUNC(ib) {
|
||||
REG_BPROP_BUILDER("DynamicShape").SetUnusedInputs({i0, i1, i2}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
return {ib->ZerosLike(x)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("TensorShape").SetUnusedInputs({i1, i2}).SetBody(BODYFUNC(ib) {
|
||||
REG_BPROP_BUILDER("TensorShape").SetUnusedInputs({i0, i1, i2}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
return {ib->ZerosLike(x)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("DType").SetUnusedInputs({i1, i2}).SetBody(BODYFUNC(ib) {
|
||||
REG_BPROP_BUILDER("DType").SetUnusedInputs({i0, i1, i2}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
return {ib->ZerosLike(x)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("StridedSliceV2").SetUnusedInputs({i4}).SetBody(BODYFUNC(ib) {
|
||||
REG_BPROP_BUILDER("StridedSliceV2").SetUnusedInputs({i0, i4}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto begin = ib->GetInput(kIndex1);
|
||||
auto end = ib->GetInput(kIndex2);
|
||||
|
@ -1010,7 +1005,7 @@ REG_BPROP_BUILDER("StridedSliceV2").SetUnusedInputs({i4}).SetBody(BODYFUNC(ib) {
|
|||
return {dx, ib->ZerosLike(begin), ib->ZerosLike(end), ib->ZerosLike(strides)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("MaskedFill").SetUnusedInputs({i3}).SetBody(BODYFUNC(ib) {
|
||||
REG_BPROP_BUILDER("MaskedFill").SetUnusedInputs({i2, i3}).SetBody(BODYFUNC(ib) {
|
||||
auto input_data = ib->GetInput(kIndex0);
|
||||
auto mask = ib->GetInput(kIndex1);
|
||||
auto value = ib->GetInput(kIndex2);
|
||||
|
@ -1066,7 +1061,7 @@ REG_BPROP_BUILDER("IdentityN").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) {
|
|||
return {dout};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("ResizeNearestNeighborV2").SetUnusedInputs({i1, i2}).SetBody(BODYFUNC(ib) {
|
||||
REG_BPROP_BUILDER("ResizeNearestNeighborV2").SetUnusedInputs({i0, i1, i2}).SetBody(BODYFUNC(ib) {
|
||||
auto align_corners = GetValue<bool>(ib->GetAttr("align_corners"));
|
||||
auto half_pixel_centers = GetValue<bool>(ib->GetAttr("half_pixel_centers"));
|
||||
auto data_format = GetValue<std::string>(ib->GetAttr("format"));
|
||||
|
@ -1106,7 +1101,7 @@ REG_BPROP_BUILDER("SegmentSum").SetUnusedInputs({i0, i2}).SetBody(BODYFUNC(ib) {
|
|||
return {ib->Cast(ib->Emit("Gather", {dout, segment_ids, ib->Tensor(0)}), dout_type), ib->ZerosLike(segment_ids)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("EmbeddingLookup").SetUnusedInputs({i3}).SetBody(BODYFUNC(ib) {
|
||||
REG_BPROP_BUILDER("EmbeddingLookup").SetUnusedInputs({i0, i3}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto indices = ib->GetInput(kIndex1);
|
||||
auto offset = ib->GetInput(kIndex2);
|
||||
|
@ -1147,7 +1142,7 @@ REG_BPROP_BUILDER("SplitV").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) {
|
|||
return {dx};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("Col2Im").SetUnusedInputs({i0, i2}).SetBody(BODYFUNC(ib) {
|
||||
REG_BPROP_BUILDER("Col2Im").SetUnusedInputs({i0, i1, i2}).SetBody(BODYFUNC(ib) {
|
||||
auto ksizes = GetValue<std::vector<int64_t>>(ib->GetAttr("kernel_size"));
|
||||
auto dilations = GetValue<std::vector<int64_t>>(ib->GetAttr("dilation"));
|
||||
auto strides = GetValue<std::vector<int64_t>>(ib->GetAttr("stride"));
|
||||
|
@ -1163,7 +1158,7 @@ REG_BPROP_BUILDER("Col2Im").SetUnusedInputs({i0, i2}).SetBody(BODYFUNC(ib) {
|
|||
return {dx, ib->ZerosLike(output_size)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("ExtractVolumePatches").SetBody(BODYFUNC(ib) {
|
||||
REG_BPROP_BUILDER("ExtractVolumePatches").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) {
|
||||
auto ksize = GetValue<std::vector<int64_t>>(ib->GetAttr("kernel_size"));
|
||||
auto ksize_d = ksize.at(2);
|
||||
auto ksize_h = ksize.at(3);
|
||||
|
@ -1213,7 +1208,7 @@ REG_BPROP_BUILDER("ExtractVolumePatches").SetBody(BODYFUNC(ib) {
|
|||
return {dx};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("AffineGrid").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {
|
||||
REG_BPROP_BUILDER("AffineGrid").SetUnusedInputs({i0, i2}).SetBody(BODYFUNC(ib) {
|
||||
auto align_corners = GetValue<bool>(ib->GetAttr("align_corners"));
|
||||
auto theta = ib->GetInput(kIndex0);
|
||||
auto output_size = GetIntList(ib->GetInput(kIndex1));
|
||||
|
@ -1331,7 +1326,7 @@ REG_BPROP_BUILDER("ScatterAddWithAxis").SetUnusedInputs({i0, i2, i3}).SetBody(BO
|
|||
return {dout, ib->ZerosLike(indices), update_grad};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("Expand").SetUnusedInputs({i1, i2}).SetBody(BODYFUNC(ib) {
|
||||
REG_BPROP_BUILDER("Expand").SetUnusedInputs({i0, i1, i2}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto dout = ib->GetInput(kIndex3);
|
||||
auto dout_shape = ib->GetShape(dout);
|
||||
|
|
|
@ -151,14 +151,14 @@ NodePtrList CommonSparseSegmentBpropForCpu(const BpropIRBuilder *ib, bool with_s
|
|||
}
|
||||
} // namespace
|
||||
REG_BPROP_BUILDERS_BEGIN(GradSparseOps)
|
||||
REG_BPROP_BUILDER("SparseToDense").SetUnusedInputs({i1, i3}).SetBody(BODYFUNC(ib) {
|
||||
REG_BPROP_BUILDER("SparseToDense").SetUnusedInputs({i1, i2, i3}).SetBody(BODYFUNC(ib) {
|
||||
auto indices = ib->GetInput(kIndex0);
|
||||
auto dense_shape = ib->GetInput(kIndex2);
|
||||
auto dout = ib->GetInput(kIndex4);
|
||||
return {ib->ZerosLike(indices), ib->Emit("GatherNd", {dout, indices}), ib->ZerosLike(dense_shape)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("SparseToDenseV2").SetUnusedInputs({i2, i3, i4}).SetBody(BODYFUNC(ib) {
|
||||
REG_BPROP_BUILDER("SparseToDenseV2").SetUnusedInputs({i1, i2, i3, i4}).SetBody(BODYFUNC(ib) {
|
||||
auto indices = ib->GetInput(kIndex0);
|
||||
auto output_shape = ib->GetInput(kIndex1);
|
||||
auto dout = ib->GetInput(kIndex5);
|
||||
|
@ -214,7 +214,7 @@ REG_BPROP_BUILDER("SparseTensorDenseMatmul").SetUnusedInputs({i4}).SetBody(BODYF
|
|||
return {ib->ZerosLike(indices), values_grad, ib->ZerosLike(dense_shape), dense_grad};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("SparseAdd").SetBody(BODYFUNC(ib) {
|
||||
REG_BPROP_BUILDER("SparseAdd").SetUnusedInputs({i1, i2, i4, i5, i6}).SetBody(BODYFUNC(ib) {
|
||||
auto x1_indices = ib->GetInput(kIndex0);
|
||||
auto x1_values = ib->GetInput(kIndex1);
|
||||
auto x1_shape = ib->GetInput(kIndex2);
|
||||
|
@ -359,19 +359,19 @@ REG_BPROP_BUILDER("CSRDiv").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {
|
|||
dense_grad};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("CSR2COO").SetUnusedInputs({i2, i3}).SetBody(BODYFUNC(ib) {
|
||||
REG_BPROP_BUILDER("CSR2COO").SetUnusedInputs({i0, i1, i2, i3}).SetBody(BODYFUNC(ib) {
|
||||
auto indptr = ib->GetInput(kIndex0);
|
||||
auto nnz = ib->GetInput(kIndex1);
|
||||
return {ib->ZerosLike(indptr), ib->ZerosLike(nnz)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("COO2CSR").SetUnusedInputs({i2, i3}).SetBody(BODYFUNC(ib) {
|
||||
REG_BPROP_BUILDER("COO2CSR").SetUnusedInputs({i0, i1, i2, i3}).SetBody(BODYFUNC(ib) {
|
||||
auto row_indices = ib->GetInput(kIndex0);
|
||||
auto height = ib->GetInput(kIndex1);
|
||||
return {ib->ZerosLike(row_indices), ib->ZerosLike(height)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("MakeCOOTensor").SetUnusedInputs({i1, i2, i3}).SetBody(BODYFUNC(ib) {
|
||||
REG_BPROP_BUILDER("MakeCOOTensor").SetUnusedInputs({i0, i1, i2, i3}).SetBody(BODYFUNC(ib) {
|
||||
auto indices = ib->GetInput(kIndex0);
|
||||
auto dout = ib->GetInput(kIndex4);
|
||||
auto dout_values = ib->TupleGetItem(dout, kIndex1);
|
||||
|
@ -394,12 +394,12 @@ REG_BPROP_BUILDER("COOTensorGetValues").SetUnusedInputs({i1}).SetBody(BODYFUNC(i
|
|||
return {ib->MakeTuple({ib->ZerosLike(coo_tensor_indices), dout, coo_tensor_shape})};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("COOTensorGetDenseShape").SetUnusedInputs({i1, i2}).SetBody(BODYFUNC(ib) {
|
||||
REG_BPROP_BUILDER("COOTensorGetDenseShape").SetUnusedInputs({i0, i1, i2}).SetBody(BODYFUNC(ib) {
|
||||
auto coo_tensor = ib->GetInput(kIndex0);
|
||||
return {ib->ZerosLike(coo_tensor)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("MakeCSRTensor").SetUnusedInputs({i2, i3, i4}).SetBody(BODYFUNC(ib) {
|
||||
REG_BPROP_BUILDER("MakeCSRTensor").SetUnusedInputs({i0, i1, i2, i3, i4}).SetBody(BODYFUNC(ib) {
|
||||
auto indptr = ib->GetInput(kIndex0);
|
||||
auto indices = ib->GetInput(kIndex1);
|
||||
auto dout = ib->GetInput(kIndex5);
|
||||
|
@ -435,7 +435,7 @@ REG_BPROP_BUILDER("CSRTensorGetValues").SetUnusedInputs({i1}).SetBody(BODYFUNC(i
|
|||
return {ib->MakeTuple({ib->ZerosLike(csr_tensor_indptr), ib->ZerosLike(csr_tensor_indices), dout, csr_tensor_shape})};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("CSRTensorGetDenseShape").SetUnusedInputs({i1, i2}).SetBody(BODYFUNC(ib) {
|
||||
REG_BPROP_BUILDER("CSRTensorGetDenseShape").SetUnusedInputs({i0, i1, i2}).SetBody(BODYFUNC(ib) {
|
||||
auto csr_tensor = ib->GetInput(kIndex0);
|
||||
return {ib->ZerosLike(csr_tensor)};
|
||||
});
|
||||
|
@ -453,7 +453,7 @@ REG_BPROP_BUILDER("CSRSparseMatrixToDense").SetUnusedInputs({i5}).SetBody(BODYFU
|
|||
ib->TupleGetItem(res, kIndex3), ib->TupleGetItem(res, kIndex4)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("DenseToCSRSparseMatrix").SetBody(BODYFUNC(ib) {
|
||||
REG_BPROP_BUILDER("DenseToCSRSparseMatrix").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) {
|
||||
auto indices = ib->GetInput(kIndex1);
|
||||
auto out = ib->GetInput(kIndex2);
|
||||
auto dout = ib->GetInput(kIndex3);
|
||||
|
@ -485,7 +485,7 @@ REG_BPROP_BUILDER("DenseToCSRSparseMatrix").SetBody(BODYFUNC(ib) {
|
|||
return {ib->Emit("CSRSparseMatrixToDense", {shape, batch_ptr, row_ptr, col_ind, dvalue}), ib->ZerosLike(indices)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("SparseSoftmax").SetBody(BODYFUNC(ib) {
|
||||
REG_BPROP_BUILDER("SparseSoftmax").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib) {
|
||||
auto indices = ib->GetInput(kIndex0);
|
||||
auto values = ib->GetInput(kIndex1);
|
||||
auto shape = ib->GetInput(kIndex2);
|
||||
|
@ -518,29 +518,29 @@ REG_BPROP_BUILDER("CSRSparseMatrixToSparseTensor").SetUnusedInputs({i0, i1, i2,
|
|||
ib->TupleGetItem(dx, kIndex3), ib->TupleGetItem(dx, kIndex4)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("SparseSegmentSqrtN").SetUnusedInputs({i0, i1, i2, i3, i4}).SetBody(BODYFUNC(ib) {
|
||||
REG_BPROP_BUILDER("SparseSegmentSqrtN").SetUnusedInputs({i0, i3}).SetBody(BODYFUNC(ib) {
|
||||
return CommonSparseSegmentBprop(ib, "SparseSegmentSqrtNGrad", false);
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("SparseSegmentSqrtNWithNumSegments").SetUnusedInputs({i0, i1, i2, i3, i4, i5}).SetBody(BODYFUNC(ib) {
|
||||
REG_BPROP_BUILDER("SparseSegmentSqrtNWithNumSegments").SetUnusedInputs({i0, i3}).SetBody(BODYFUNC(ib) {
|
||||
return CommonSparseSegmentBprop(ib, "SparseSegmentSqrtNGrad", true);
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("SparseSegmentSum").SetUnusedInputs({i0, i1, i2, i3, i4}).SetBody(BODYFUNC(ib) {
|
||||
REG_BPROP_BUILDER("SparseSegmentSum").SetUnusedInputs({i0, i3}).SetBody(BODYFUNC(ib) {
|
||||
if (ib->GetTargetFromContext() == kGPUDevice) {
|
||||
return CommonSparseSegmentBprop(ib, "SparseSegmentSumGrad", false);
|
||||
}
|
||||
return CommonSparseSegmentBpropForCpu(ib, false);
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("SparseSegmentSumWithNumSegments").SetUnusedInputs({i0, i1, i2, i3, i4, i5}).SetBody(BODYFUNC(ib) {
|
||||
REG_BPROP_BUILDER("SparseSegmentSumWithNumSegments").SetUnusedInputs({i0, i3}).SetBody(BODYFUNC(ib) {
|
||||
if (ib->GetTargetFromContext() == kGPUDevice) {
|
||||
return CommonSparseSegmentBprop(ib, "SparseSegmentSumGrad", true);
|
||||
}
|
||||
return CommonSparseSegmentBpropForCpu(ib, true);
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("SparseTensorDenseAdd").SetUnusedInputs({i1, i3, i4}).SetBody(BODYFUNC(ib) {
|
||||
REG_BPROP_BUILDER("SparseTensorDenseAdd").SetUnusedInputs({i1, i2, i3, i4}).SetBody(BODYFUNC(ib) {
|
||||
auto x1_indices = ib->GetInput(kIndex0);
|
||||
auto x1_shape = ib->GetInput(kIndex2);
|
||||
auto dout = ib->GetInput(kIndex5);
|
||||
|
|
Loading…
Reference in New Issue