!47470 fix unused inputs in grad_array & grad_sparse

Merge pull request !47470 from r1chardf1d0/master
This commit is contained in:
i-robot 2023-01-04 07:12:57 +00:00 committed by Gitee
commit 6c9fd1bd63
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 61 additions and 66 deletions

View File

@ -153,7 +153,7 @@ REG_BPROP_BUILDER("GatherD").SetUnusedInputs({i3}).SetBody(BODYFUNC(ib) {
return {dx, ib->ZerosLike(dim), ib->ZerosLike(index)};
});
REG_BPROP_BUILDER("GatherDGrad").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {
REG_BPROP_BUILDER("GatherDGrad").SetUnusedInputs({i1, i2}).SetBody(BODYFUNC(ib) {
auto dim = GetValue<int64_t>(ib->GetAttr("dim"));
auto x_shp = GetValue<ShapeVector>(ib->GetAttr("shape"));
auto index = ib->GetInput(kIndex0);
@ -187,7 +187,7 @@ REG_BPROP_BUILDER("GatherDGrad").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {
return {ib->ZerosLike(index), dx};
});
REG_BPROP_BUILDER("GatherDGradV2").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {
REG_BPROP_BUILDER("GatherDGradV2").SetUnusedInputs({i1, i2}).SetBody(BODYFUNC(ib) {
auto dim = GetValue<int64_t>(ib->GetAttr("dim"));
auto index = ib->GetInput(kIndex0);
auto x = ib->GetInput(kIndex1);
@ -225,7 +225,7 @@ REG_BPROP_BUILDER("GatherDGradV2").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {
return {ib->ZerosLike(index), dx};
});
REG_BPROP_BUILDER("SparseGatherV2").SetUnusedInputs({i3}).SetBody(BODYFUNC(ib) {
REG_BPROP_BUILDER("SparseGatherV2").SetUnusedInputs({i0, i3}).SetBody(BODYFUNC(ib) {
auto x = ib->GetInput(kIndex0);
auto indices = ib->GetInput(kIndex1);
auto axis = ib->GetInput(kIndex2);
@ -323,24 +323,20 @@ REG_BPROP_BUILDER("Identity").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) {
return {dout};
});
REG_BPROP_BUILDER("Range").SetUnusedInputs({i3, i4}).SetBody(BODYFUNC(ib) {
REG_BPROP_BUILDER("Range").SetUnusedInputs({i0, i1, i2, i3, i4}).SetBody(BODYFUNC(ib) {
auto start = ib->GetInput(kIndex0);
auto limit = ib->GetInput(kIndex1);
auto delta = ib->GetInput(kIndex2);
return {ib->ZerosLike(start), ib->ZerosLike(limit), ib->ZerosLike(delta)};
});
REG_BPROP_BUILDER("Pack").SetBody(BODYFUNC(ib) {
auto x = ib->GetInput(kIndex0);
auto out = ib->GetInput(kIndex1);
REG_BPROP_BUILDER("Pack").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) {
auto dout = ib->GetInput(kIndex2);
auto ret = ib->Emit("Unstack", {dout}, {{"num", ib->GetAttr("num")}, {"axis", ib->GetAttr("axis")}});
return {ret};
});
REG_BPROP_BUILDER("Stack").SetBody(BODYFUNC(ib) {
auto x = ib->GetInput(kIndex0);
auto out = ib->GetInput(kIndex1);
REG_BPROP_BUILDER("Stack").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) {
auto dout = ib->GetInput(kIndex2);
auto ret = ib->Emit("Unstack", {dout}, {{"num", ib->GetAttr("num")}, {"axis", ib->GetAttr("axis")}});
return {ret};
@ -352,10 +348,9 @@ REG_BPROP_BUILDER("ReverseV2").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) {
return {dx};
});
REG_BPROP_BUILDER("Unstack").SetUnusedInputs({i0}).SetBody(BODYFUNC(ib) {
auto out = ib->GetInput(kIndex1);
REG_BPROP_BUILDER("Unstack").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) {
auto dout = ib->GetInput(kIndex2);
out = ib->Emit("Stack", {dout}, {{"axis", ib->GetAttr("axis")}});
auto out = ib->Emit("Stack", {dout}, {{"axis", ib->GetAttr("axis")}});
return {out};
});
@ -378,7 +373,7 @@ REG_BPROP_BUILDER("StridedSlice").SetUnusedInputs({i0, i4}).SetBody(BODYFUNC(ib)
return {dx, dbegin, dend, dstrides};
});
REG_BPROP_BUILDER("StridedSliceGrad").SetUnusedInputs({i0, i5}).SetBody(BODYFUNC(ib) {
REG_BPROP_BUILDER("StridedSliceGrad").SetUnusedInputs({i0, i1, i5}).SetBody(BODYFUNC(ib) {
auto shapex = ib->GetInput(kIndex1);
auto begin = ib->GetInput(kIndex2);
auto end = ib->GetInput(kIndex3);
@ -393,7 +388,7 @@ REG_BPROP_BUILDER("StridedSliceGrad").SetUnusedInputs({i0, i5}).SetBody(BODYFUNC
ib->ZerosLike(shapex), ib->ZerosLike(begin), ib->ZerosLike(end), ib->ZerosLike(strides)};
});
REG_BPROP_BUILDER("Eye").SetUnusedInputs({i3, i4}).SetBody(BODYFUNC(ib) {
REG_BPROP_BUILDER("Eye").SetUnusedInputs({i0, i1, i3, i4}).SetBody(BODYFUNC(ib) {
auto n = ib->GetInput(kIndex0);
auto m = ib->GetInput(kIndex1);
auto t = ib->GetInput(kIndex2);
@ -408,17 +403,17 @@ REG_BPROP_BUILDER("Select").SetUnusedInputs({i3}).SetBody(BODYFUNC(ib) {
return {ib->ZerosLike(cond), ib->Select(cond, dout, ib->ZerosLike(x)), ib->Select(cond, ib->ZerosLike(y), dout)};
});
REG_BPROP_BUILDER("OnesLike").SetUnusedInputs({i1, i2}).SetBody(BODYFUNC(ib) {
REG_BPROP_BUILDER("OnesLike").SetUnusedInputs({i0, i1, i2}).SetBody(BODYFUNC(ib) {
auto x = ib->GetInput(kIndex0);
return {ib->ZerosLike(x)};
});
REG_BPROP_BUILDER("ZerosLike").SetUnusedInputs({i1, i2}).SetBody(BODYFUNC(ib) {
REG_BPROP_BUILDER("ZerosLike").SetUnusedInputs({i0, i1, i2}).SetBody(BODYFUNC(ib) {
auto x = ib->GetInput(kIndex0);
return {ib->ZerosLike(x)};
});
REG_BPROP_BUILDER("ResizeNearestNeighbor").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib) {
REG_BPROP_BUILDER("ResizeNearestNeighbor").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) {
auto x = ib->GetInput(kIndex0);
auto dout = ib->GetInput(kIndex2);
auto x_shape = ib->GetShape(x);
@ -431,7 +426,7 @@ REG_BPROP_BUILDER("ResizeNearestNeighbor").SetUnusedInputs({i1}).SetBody(BODYFUN
return {out};
});
REG_BPROP_BUILDER("GatherNd").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {
REG_BPROP_BUILDER("GatherNd").SetUnusedInputs({i0, i2}).SetBody(BODYFUNC(ib) {
auto x = ib->GetInput(kIndex0);
auto indices = ib->GetInput(kIndex1);
auto dout = ib->GetInput(kIndex3);
@ -439,7 +434,7 @@ REG_BPROP_BUILDER("GatherNd").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {
return {ib->Emit("ScatterNd", {indices, dout, shp}), ib->ZerosLike(indices)};
});
REG_BPROP_BUILDER("ScatterNd").SetUnusedInputs({i1, i3}).SetBody(BODYFUNC(ib) {
REG_BPROP_BUILDER("ScatterNd").SetUnusedInputs({i1, i2, i3}).SetBody(BODYFUNC(ib) {
auto indices = ib->GetInput(kIndex0);
auto shape = ib->GetInput(kIndex2);
auto dout = ib->GetInput(kIndex4);
@ -467,7 +462,7 @@ REG_BPROP_BUILDER("TensorScatterUpdate").SetUnusedInputs({i0, i3}).SetBody(BODYF
return {x_grad, ib->ZerosLike(indices), update_grad};
});
REG_BPROP_BUILDER("Flatten").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib) {
REG_BPROP_BUILDER("Flatten").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) {
auto x = ib->GetInput(kIndex0);
auto dout = ib->GetInput(kIndex2);
auto dx = ib->Reshape(dout, ib->GetShape(x));
@ -482,7 +477,7 @@ REG_BPROP_BUILDER("Reshape").SetUnusedInputs({i0, i1, i2}).SetBody(BODYFUNC(ib)
return {ib->Reshape(dout, shapex), ib->ZerosLike(shp)};
});
REG_BPROP_BUILDER("NonZero").SetUnusedInputs({i1, i2}).SetBody(BODYFUNC(ib) {
REG_BPROP_BUILDER("NonZero").SetUnusedInputs({i0, i1, i2}).SetBody(BODYFUNC(ib) {
auto x = ib->GetInput(kIndex0);
return {ib->ZerosLike(x)};
});
@ -511,12 +506,12 @@ REG_BPROP_BUILDER("BatchMatMul").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {
return BinopGradCommonWithShift(ib, x, w, dx, dw, 2);
});
REG_BPROP_BUILDER("Argmax").SetUnusedInputs({i1, i2}).SetBody(BODYFUNC(ib) {
REG_BPROP_BUILDER("Argmax").SetUnusedInputs({i0, i1, i2}).SetBody(BODYFUNC(ib) {
auto x = ib->GetInput(kIndex0);
return {ib->ZerosLike(x)};
});
REG_BPROP_BUILDER("Argmin").SetUnusedInputs({i1, i2}).SetBody(BODYFUNC(ib) {
REG_BPROP_BUILDER("Argmin").SetUnusedInputs({i0, i1, i2}).SetBody(BODYFUNC(ib) {
auto x = ib->GetInput(kIndex0);
return {ib->ZerosLike(x)};
});
@ -655,7 +650,7 @@ REG_BPROP_BUILDER("TensorScatterMul").SetUnusedInputs({i3}).SetBody(BODYFUNC(ib)
REG_BPROP_BUILDER("TensorScatterMax").SetBody(TensorScatterPossibleReplacement);
REG_BPROP_BUILDER("TensorScatterMin").SetBody(TensorScatterPossibleReplacement);
REG_BPROP_BUILDER("IndexFill").SetUnusedInputs({i4}).SetBody(BODYFUNC(ib) {
REG_BPROP_BUILDER("IndexFill").SetUnusedInputs({i0, i4}).SetBody(BODYFUNC(ib) {
auto x = ib->GetInput(kIndex0);
auto dim = ib->GetInput(kIndex1);
auto indices = ib->GetInput(kIndex2);
@ -767,7 +762,7 @@ REG_BPROP_BUILDER("BatchToSpaceND").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(i
return {dx};
});
REG_BPROP_BUILDER("BroadcastTo").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib) {
REG_BPROP_BUILDER("BroadcastTo").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) {
auto x = ib->GetInput(kIndex0);
auto dout = ib->GetInput(kIndex2);
auto broadcast_shape = ib->GetAttr<ShapeVector>("shape");
@ -820,13 +815,13 @@ REG_BPROP_BUILDER("ScatterUpdate").SetUnusedInputs({i0, i2, i3}).SetBody(BODYFUN
return {dout, ib->ZerosLike(indices), ib->Emit("Gather", {dout, indices, ib->Tensor(0, kInt64)})};
});
REG_BPROP_BUILDER("Fills").SetUnusedInputs({i2, i3}).SetBody(BODYFUNC(ib) {
REG_BPROP_BUILDER("Fills").SetUnusedInputs({i0, i1, i2, i3}).SetBody(BODYFUNC(ib) {
auto x = ib->GetInput(kIndex0);
auto value = ib->GetInput(kIndex1);
return {ib->ZerosLike(x), ib->ZerosLike(value)};
});
REG_BPROP_BUILDER("Cast").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {
REG_BPROP_BUILDER("Cast").SetUnusedInputs({i0, i1, i2}).SetBody(BODYFUNC(ib) {
auto x = ib->GetInput(kIndex0);
auto t = ib->GetInput(kIndex1);
auto dout = ib->GetInput(kIndex3);
@ -835,7 +830,7 @@ REG_BPROP_BUILDER("Cast").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {
return {dx, ib->ZerosLike(t)};
});
REG_BPROP_BUILDER("ExpandDims").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {
REG_BPROP_BUILDER("ExpandDims").SetUnusedInputs({i0, i1, i2}).SetBody(BODYFUNC(ib) {
auto x = ib->GetInput(kIndex0);
auto axis = ib->GetInput(kIndex1);
auto dout = ib->GetInput(kIndex3);
@ -843,14 +838,14 @@ REG_BPROP_BUILDER("ExpandDims").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {
return {ib->Reshape(dout, shapex), ib->ZerosLike(axis)};
});
REG_BPROP_BUILDER("Squeeze").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib) {
REG_BPROP_BUILDER("Squeeze").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) {
auto x = ib->GetInput(kIndex0);
auto dout = ib->GetInput(kIndex2);
auto shapex = ib->GetShape(x);
return {ib->Reshape(dout, shapex)};
});
REG_BPROP_BUILDER("Padding").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib) {
REG_BPROP_BUILDER("Padding").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) {
auto x = ib->GetInput(kIndex0);
auto dout = ib->GetInput(kIndex2);
auto shp = ib->GetShape(x);
@ -886,7 +881,7 @@ REG_BPROP_BUILDER("Split").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) {
return {dx};
});
REG_BPROP_BUILDER("Tile").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {
REG_BPROP_BUILDER("Tile").SetUnusedInputs({i0, i2}).SetBody(BODYFUNC(ib) {
auto x = ib->GetInput(kIndex0);
auto input_multiples = ib->GetInput(kIndex1);
auto dout = ib->GetInput(kIndex3);
@ -916,14 +911,14 @@ REG_BPROP_BUILDER("Tile").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {
REG_BPROP_BUILDER("Gather").SetUnusedInputs({i3}).SetBody(BinopGatherCommon);
REG_BPROP_BUILDER("GatherV2").SetUnusedInputs({i3}).SetBody(BinopGatherCommon);
REG_BPROP_BUILDER("Fill").SetUnusedInputs({i3, i4}).SetBody(BODYFUNC(ib) {
REG_BPROP_BUILDER("Fill").SetUnusedInputs({i0, i1, i2, i3, i4}).SetBody(BODYFUNC(ib) {
auto dtype = ib->GetInput(kIndex0);
auto dims = ib->GetInput(kIndex1);
auto x = ib->GetInput(kIndex2);
return {ib->ZerosLike(dtype), ib->ZerosLike(dims), ib->ZerosLike(x)};
});
REG_BPROP_BUILDER("MatrixDiagV3").SetUnusedInputs({i0, i5}).SetBody(BODYFUNC(ib) {
REG_BPROP_BUILDER("MatrixDiagV3").SetUnusedInputs({i0, i2, i3, i4, i5}).SetBody(BODYFUNC(ib) {
auto k = ib->GetInput(kIndex1);
auto num_rows = ib->GetInput(kIndex2);
auto num_cols = ib->GetInput(kIndex3);
@ -934,7 +929,7 @@ REG_BPROP_BUILDER("MatrixDiagV3").SetUnusedInputs({i0, i5}).SetBody(BODYFUNC(ib)
return {part, ib->ZerosLike(k), ib->ZerosLike(num_rows), ib->ZerosLike(num_cols), ib->ZerosLike(padding_value)};
});
REG_BPROP_BUILDER("MatrixDiagPartV3").SetUnusedInputs({i3}).SetBody(BODYFUNC(ib) {
REG_BPROP_BUILDER("MatrixDiagPartV3").SetUnusedInputs({i0, i2, i3}).SetBody(BODYFUNC(ib) {
auto align = ib->GetAttr("align");
auto x = ib->GetInput(kIndex0);
auto k = ib->GetInput(kIndex1);
@ -949,7 +944,7 @@ REG_BPROP_BUILDER("MatrixDiagPartV3").SetUnusedInputs({i3}).SetBody(BODYFUNC(ib)
return {diag, ib->ZerosLike(k), ib->ZerosLike(padding_value)};
});
REG_BPROP_BUILDER("MatrixSetDiagV3").SetUnusedInputs({i0, i3}).SetBody(BODYFUNC(ib) {
REG_BPROP_BUILDER("MatrixSetDiagV3").SetUnusedInputs({i0, i1, i3}).SetBody(BODYFUNC(ib) {
auto align = ib->GetAttr("align");
auto diagonal = ib->GetInput(kIndex1);
auto k = ib->GetInput(kIndex2);
@ -964,37 +959,37 @@ REG_BPROP_BUILDER("MatrixSetDiagV3").SetUnusedInputs({i0, i3}).SetBody(BODYFUNC(
return {x_cal, diagonal_cal, ib->ZerosLike(k)};
});
REG_BPROP_BUILDER("LogNormalReverse").SetUnusedInputs({i1, i2}).SetBody(BODYFUNC(ib) {
REG_BPROP_BUILDER("LogNormalReverse").SetUnusedInputs({i0, i1, i2}).SetBody(BODYFUNC(ib) {
auto input_data = ib->GetInput(kIndex0);
return {ib->ZerosLike(input_data)};
});
REG_BPROP_BUILDER("Shape").SetUnusedInputs({i1, i2}).SetBody(BODYFUNC(ib) {
REG_BPROP_BUILDER("Shape").SetUnusedInputs({i0, i1, i2}).SetBody(BODYFUNC(ib) {
auto x = ib->GetInput(kIndex0);
return {ib->ZerosLike(x)};
});
REG_BPROP_BUILDER("Rank").SetUnusedInputs({i1, i2}).SetBody(BODYFUNC(ib) {
REG_BPROP_BUILDER("Rank").SetUnusedInputs({i0, i1, i2}).SetBody(BODYFUNC(ib) {
auto x = ib->GetInput(kIndex0);
return {ib->ZerosLike(x)};
});
REG_BPROP_BUILDER("DynamicShape").SetUnusedInputs({i1, i2}).SetBody(BODYFUNC(ib) {
REG_BPROP_BUILDER("DynamicShape").SetUnusedInputs({i0, i1, i2}).SetBody(BODYFUNC(ib) {
auto x = ib->GetInput(kIndex0);
return {ib->ZerosLike(x)};
});
REG_BPROP_BUILDER("TensorShape").SetUnusedInputs({i1, i2}).SetBody(BODYFUNC(ib) {
REG_BPROP_BUILDER("TensorShape").SetUnusedInputs({i0, i1, i2}).SetBody(BODYFUNC(ib) {
auto x = ib->GetInput(kIndex0);
return {ib->ZerosLike(x)};
});
REG_BPROP_BUILDER("DType").SetUnusedInputs({i1, i2}).SetBody(BODYFUNC(ib) {
REG_BPROP_BUILDER("DType").SetUnusedInputs({i0, i1, i2}).SetBody(BODYFUNC(ib) {
auto x = ib->GetInput(kIndex0);
return {ib->ZerosLike(x)};
});
REG_BPROP_BUILDER("StridedSliceV2").SetUnusedInputs({i4}).SetBody(BODYFUNC(ib) {
REG_BPROP_BUILDER("StridedSliceV2").SetUnusedInputs({i0, i4}).SetBody(BODYFUNC(ib) {
auto x = ib->GetInput(kIndex0);
auto begin = ib->GetInput(kIndex1);
auto end = ib->GetInput(kIndex2);
@ -1010,7 +1005,7 @@ REG_BPROP_BUILDER("StridedSliceV2").SetUnusedInputs({i4}).SetBody(BODYFUNC(ib) {
return {dx, ib->ZerosLike(begin), ib->ZerosLike(end), ib->ZerosLike(strides)};
});
REG_BPROP_BUILDER("MaskedFill").SetUnusedInputs({i3}).SetBody(BODYFUNC(ib) {
REG_BPROP_BUILDER("MaskedFill").SetUnusedInputs({i2, i3}).SetBody(BODYFUNC(ib) {
auto input_data = ib->GetInput(kIndex0);
auto mask = ib->GetInput(kIndex1);
auto value = ib->GetInput(kIndex2);
@ -1066,7 +1061,7 @@ REG_BPROP_BUILDER("IdentityN").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) {
return {dout};
});
REG_BPROP_BUILDER("ResizeNearestNeighborV2").SetUnusedInputs({i1, i2}).SetBody(BODYFUNC(ib) {
REG_BPROP_BUILDER("ResizeNearestNeighborV2").SetUnusedInputs({i0, i1, i2}).SetBody(BODYFUNC(ib) {
auto align_corners = GetValue<bool>(ib->GetAttr("align_corners"));
auto half_pixel_centers = GetValue<bool>(ib->GetAttr("half_pixel_centers"));
auto data_format = GetValue<std::string>(ib->GetAttr("format"));
@ -1106,7 +1101,7 @@ REG_BPROP_BUILDER("SegmentSum").SetUnusedInputs({i0, i2}).SetBody(BODYFUNC(ib) {
return {ib->Cast(ib->Emit("Gather", {dout, segment_ids, ib->Tensor(0)}), dout_type), ib->ZerosLike(segment_ids)};
});
REG_BPROP_BUILDER("EmbeddingLookup").SetUnusedInputs({i3}).SetBody(BODYFUNC(ib) {
REG_BPROP_BUILDER("EmbeddingLookup").SetUnusedInputs({i0, i3}).SetBody(BODYFUNC(ib) {
auto x = ib->GetInput(kIndex0);
auto indices = ib->GetInput(kIndex1);
auto offset = ib->GetInput(kIndex2);
@ -1147,7 +1142,7 @@ REG_BPROP_BUILDER("SplitV").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) {
return {dx};
});
REG_BPROP_BUILDER("Col2Im").SetUnusedInputs({i0, i2}).SetBody(BODYFUNC(ib) {
REG_BPROP_BUILDER("Col2Im").SetUnusedInputs({i0, i1, i2}).SetBody(BODYFUNC(ib) {
auto ksizes = GetValue<std::vector<int64_t>>(ib->GetAttr("kernel_size"));
auto dilations = GetValue<std::vector<int64_t>>(ib->GetAttr("dilation"));
auto strides = GetValue<std::vector<int64_t>>(ib->GetAttr("stride"));
@ -1163,7 +1158,7 @@ REG_BPROP_BUILDER("Col2Im").SetUnusedInputs({i0, i2}).SetBody(BODYFUNC(ib) {
return {dx, ib->ZerosLike(output_size)};
});
REG_BPROP_BUILDER("ExtractVolumePatches").SetBody(BODYFUNC(ib) {
REG_BPROP_BUILDER("ExtractVolumePatches").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) {
auto ksize = GetValue<std::vector<int64_t>>(ib->GetAttr("kernel_size"));
auto ksize_d = ksize.at(2);
auto ksize_h = ksize.at(3);
@ -1213,7 +1208,7 @@ REG_BPROP_BUILDER("ExtractVolumePatches").SetBody(BODYFUNC(ib) {
return {dx};
});
REG_BPROP_BUILDER("AffineGrid").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {
REG_BPROP_BUILDER("AffineGrid").SetUnusedInputs({i0, i2}).SetBody(BODYFUNC(ib) {
auto align_corners = GetValue<bool>(ib->GetAttr("align_corners"));
auto theta = ib->GetInput(kIndex0);
auto output_size = GetIntList(ib->GetInput(kIndex1));
@ -1331,7 +1326,7 @@ REG_BPROP_BUILDER("ScatterAddWithAxis").SetUnusedInputs({i0, i2, i3}).SetBody(BO
return {dout, ib->ZerosLike(indices), update_grad};
});
REG_BPROP_BUILDER("Expand").SetUnusedInputs({i1, i2}).SetBody(BODYFUNC(ib) {
REG_BPROP_BUILDER("Expand").SetUnusedInputs({i0, i1, i2}).SetBody(BODYFUNC(ib) {
auto x = ib->GetInput(kIndex0);
auto dout = ib->GetInput(kIndex3);
auto dout_shape = ib->GetShape(dout);

View File

@ -151,14 +151,14 @@ NodePtrList CommonSparseSegmentBpropForCpu(const BpropIRBuilder *ib, bool with_s
}
} // namespace
REG_BPROP_BUILDERS_BEGIN(GradSparseOps)
REG_BPROP_BUILDER("SparseToDense").SetUnusedInputs({i1, i3}).SetBody(BODYFUNC(ib) {
REG_BPROP_BUILDER("SparseToDense").SetUnusedInputs({i1, i2, i3}).SetBody(BODYFUNC(ib) {
auto indices = ib->GetInput(kIndex0);
auto dense_shape = ib->GetInput(kIndex2);
auto dout = ib->GetInput(kIndex4);
return {ib->ZerosLike(indices), ib->Emit("GatherNd", {dout, indices}), ib->ZerosLike(dense_shape)};
});
REG_BPROP_BUILDER("SparseToDenseV2").SetUnusedInputs({i2, i3, i4}).SetBody(BODYFUNC(ib) {
REG_BPROP_BUILDER("SparseToDenseV2").SetUnusedInputs({i1, i2, i3, i4}).SetBody(BODYFUNC(ib) {
auto indices = ib->GetInput(kIndex0);
auto output_shape = ib->GetInput(kIndex1);
auto dout = ib->GetInput(kIndex5);
@ -214,7 +214,7 @@ REG_BPROP_BUILDER("SparseTensorDenseMatmul").SetUnusedInputs({i4}).SetBody(BODYF
return {ib->ZerosLike(indices), values_grad, ib->ZerosLike(dense_shape), dense_grad};
});
REG_BPROP_BUILDER("SparseAdd").SetBody(BODYFUNC(ib) {
REG_BPROP_BUILDER("SparseAdd").SetUnusedInputs({i1, i2, i4, i5, i6}).SetBody(BODYFUNC(ib) {
auto x1_indices = ib->GetInput(kIndex0);
auto x1_values = ib->GetInput(kIndex1);
auto x1_shape = ib->GetInput(kIndex2);
@ -359,19 +359,19 @@ REG_BPROP_BUILDER("CSRDiv").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {
dense_grad};
});
REG_BPROP_BUILDER("CSR2COO").SetUnusedInputs({i2, i3}).SetBody(BODYFUNC(ib) {
REG_BPROP_BUILDER("CSR2COO").SetUnusedInputs({i0, i1, i2, i3}).SetBody(BODYFUNC(ib) {
auto indptr = ib->GetInput(kIndex0);
auto nnz = ib->GetInput(kIndex1);
return {ib->ZerosLike(indptr), ib->ZerosLike(nnz)};
});
REG_BPROP_BUILDER("COO2CSR").SetUnusedInputs({i2, i3}).SetBody(BODYFUNC(ib) {
REG_BPROP_BUILDER("COO2CSR").SetUnusedInputs({i0, i1, i2, i3}).SetBody(BODYFUNC(ib) {
auto row_indices = ib->GetInput(kIndex0);
auto height = ib->GetInput(kIndex1);
return {ib->ZerosLike(row_indices), ib->ZerosLike(height)};
});
REG_BPROP_BUILDER("MakeCOOTensor").SetUnusedInputs({i1, i2, i3}).SetBody(BODYFUNC(ib) {
REG_BPROP_BUILDER("MakeCOOTensor").SetUnusedInputs({i0, i1, i2, i3}).SetBody(BODYFUNC(ib) {
auto indices = ib->GetInput(kIndex0);
auto dout = ib->GetInput(kIndex4);
auto dout_values = ib->TupleGetItem(dout, kIndex1);
@ -394,12 +394,12 @@ REG_BPROP_BUILDER("COOTensorGetValues").SetUnusedInputs({i1}).SetBody(BODYFUNC(i
return {ib->MakeTuple({ib->ZerosLike(coo_tensor_indices), dout, coo_tensor_shape})};
});
REG_BPROP_BUILDER("COOTensorGetDenseShape").SetUnusedInputs({i1, i2}).SetBody(BODYFUNC(ib) {
REG_BPROP_BUILDER("COOTensorGetDenseShape").SetUnusedInputs({i0, i1, i2}).SetBody(BODYFUNC(ib) {
auto coo_tensor = ib->GetInput(kIndex0);
return {ib->ZerosLike(coo_tensor)};
});
REG_BPROP_BUILDER("MakeCSRTensor").SetUnusedInputs({i2, i3, i4}).SetBody(BODYFUNC(ib) {
REG_BPROP_BUILDER("MakeCSRTensor").SetUnusedInputs({i0, i1, i2, i3, i4}).SetBody(BODYFUNC(ib) {
auto indptr = ib->GetInput(kIndex0);
auto indices = ib->GetInput(kIndex1);
auto dout = ib->GetInput(kIndex5);
@ -435,7 +435,7 @@ REG_BPROP_BUILDER("CSRTensorGetValues").SetUnusedInputs({i1}).SetBody(BODYFUNC(i
return {ib->MakeTuple({ib->ZerosLike(csr_tensor_indptr), ib->ZerosLike(csr_tensor_indices), dout, csr_tensor_shape})};
});
REG_BPROP_BUILDER("CSRTensorGetDenseShape").SetUnusedInputs({i1, i2}).SetBody(BODYFUNC(ib) {
REG_BPROP_BUILDER("CSRTensorGetDenseShape").SetUnusedInputs({i0, i1, i2}).SetBody(BODYFUNC(ib) {
auto csr_tensor = ib->GetInput(kIndex0);
return {ib->ZerosLike(csr_tensor)};
});
@ -453,7 +453,7 @@ REG_BPROP_BUILDER("CSRSparseMatrixToDense").SetUnusedInputs({i5}).SetBody(BODYFU
ib->TupleGetItem(res, kIndex3), ib->TupleGetItem(res, kIndex4)};
});
REG_BPROP_BUILDER("DenseToCSRSparseMatrix").SetBody(BODYFUNC(ib) {
REG_BPROP_BUILDER("DenseToCSRSparseMatrix").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) {
auto indices = ib->GetInput(kIndex1);
auto out = ib->GetInput(kIndex2);
auto dout = ib->GetInput(kIndex3);
@ -485,7 +485,7 @@ REG_BPROP_BUILDER("DenseToCSRSparseMatrix").SetBody(BODYFUNC(ib) {
return {ib->Emit("CSRSparseMatrixToDense", {shape, batch_ptr, row_ptr, col_ind, dvalue}), ib->ZerosLike(indices)};
});
REG_BPROP_BUILDER("SparseSoftmax").SetBody(BODYFUNC(ib) {
REG_BPROP_BUILDER("SparseSoftmax").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib) {
auto indices = ib->GetInput(kIndex0);
auto values = ib->GetInput(kIndex1);
auto shape = ib->GetInput(kIndex2);
@ -518,29 +518,29 @@ REG_BPROP_BUILDER("CSRSparseMatrixToSparseTensor").SetUnusedInputs({i0, i1, i2,
ib->TupleGetItem(dx, kIndex3), ib->TupleGetItem(dx, kIndex4)};
});
REG_BPROP_BUILDER("SparseSegmentSqrtN").SetUnusedInputs({i0, i1, i2, i3, i4}).SetBody(BODYFUNC(ib) {
REG_BPROP_BUILDER("SparseSegmentSqrtN").SetUnusedInputs({i0, i3}).SetBody(BODYFUNC(ib) {
return CommonSparseSegmentBprop(ib, "SparseSegmentSqrtNGrad", false);
});
REG_BPROP_BUILDER("SparseSegmentSqrtNWithNumSegments").SetUnusedInputs({i0, i1, i2, i3, i4, i5}).SetBody(BODYFUNC(ib) {
REG_BPROP_BUILDER("SparseSegmentSqrtNWithNumSegments").SetUnusedInputs({i0, i3}).SetBody(BODYFUNC(ib) {
return CommonSparseSegmentBprop(ib, "SparseSegmentSqrtNGrad", true);
});
REG_BPROP_BUILDER("SparseSegmentSum").SetUnusedInputs({i0, i1, i2, i3, i4}).SetBody(BODYFUNC(ib) {
REG_BPROP_BUILDER("SparseSegmentSum").SetUnusedInputs({i0, i3}).SetBody(BODYFUNC(ib) {
if (ib->GetTargetFromContext() == kGPUDevice) {
return CommonSparseSegmentBprop(ib, "SparseSegmentSumGrad", false);
}
return CommonSparseSegmentBpropForCpu(ib, false);
});
REG_BPROP_BUILDER("SparseSegmentSumWithNumSegments").SetUnusedInputs({i0, i1, i2, i3, i4, i5}).SetBody(BODYFUNC(ib) {
REG_BPROP_BUILDER("SparseSegmentSumWithNumSegments").SetUnusedInputs({i0, i3}).SetBody(BODYFUNC(ib) {
if (ib->GetTargetFromContext() == kGPUDevice) {
return CommonSparseSegmentBprop(ib, "SparseSegmentSumGrad", true);
}
return CommonSparseSegmentBpropForCpu(ib, true);
});
REG_BPROP_BUILDER("SparseTensorDenseAdd").SetUnusedInputs({i1, i3, i4}).SetBody(BODYFUNC(ib) {
REG_BPROP_BUILDER("SparseTensorDenseAdd").SetUnusedInputs({i1, i2, i3, i4}).SetBody(BODYFUNC(ib) {
auto x1_indices = ib->GetInput(kIndex0);
auto x1_shape = ib->GetInput(kIndex2);
auto dout = ib->GetInput(kIndex5);