diff --git a/mindspore/ccsrc/pipeline/pynative/grad/bprop_expander/grad_ops/grad_array_ops.cc b/mindspore/ccsrc/pipeline/pynative/grad/bprop_expander/grad_ops/grad_array_ops.cc index 1878288abfe..efbec85e270 100644 --- a/mindspore/ccsrc/pipeline/pynative/grad/bprop_expander/grad_ops/grad_array_ops.cc +++ b/mindspore/ccsrc/pipeline/pynative/grad/bprop_expander/grad_ops/grad_array_ops.cc @@ -153,7 +153,7 @@ REG_BPROP_BUILDER("GatherD").SetUnusedInputs({i3}).SetBody(BODYFUNC(ib) { return {dx, ib->ZerosLike(dim), ib->ZerosLike(index)}; }); -REG_BPROP_BUILDER("GatherDGrad").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) { +REG_BPROP_BUILDER("GatherDGrad").SetUnusedInputs({i1, i2}).SetBody(BODYFUNC(ib) { auto dim = GetValue(ib->GetAttr("dim")); auto x_shp = GetValue(ib->GetAttr("shape")); auto index = ib->GetInput(kIndex0); @@ -187,7 +187,7 @@ REG_BPROP_BUILDER("GatherDGrad").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) { return {ib->ZerosLike(index), dx}; }); -REG_BPROP_BUILDER("GatherDGradV2").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) { +REG_BPROP_BUILDER("GatherDGradV2").SetUnusedInputs({i1, i2}).SetBody(BODYFUNC(ib) { auto dim = GetValue(ib->GetAttr("dim")); auto index = ib->GetInput(kIndex0); auto x = ib->GetInput(kIndex1); @@ -225,7 +225,7 @@ REG_BPROP_BUILDER("GatherDGradV2").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) { return {ib->ZerosLike(index), dx}; }); -REG_BPROP_BUILDER("SparseGatherV2").SetUnusedInputs({i3}).SetBody(BODYFUNC(ib) { +REG_BPROP_BUILDER("SparseGatherV2").SetUnusedInputs({i0, i3}).SetBody(BODYFUNC(ib) { auto x = ib->GetInput(kIndex0); auto indices = ib->GetInput(kIndex1); auto axis = ib->GetInput(kIndex2); @@ -323,24 +323,20 @@ REG_BPROP_BUILDER("Identity").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) { return {dout}; }); -REG_BPROP_BUILDER("Range").SetUnusedInputs({i3, i4}).SetBody(BODYFUNC(ib) { +REG_BPROP_BUILDER("Range").SetUnusedInputs({i0, i1, i2, i3, i4}).SetBody(BODYFUNC(ib) { auto start = ib->GetInput(kIndex0); auto limit = ib->GetInput(kIndex1); auto delta = ib->GetInput(kIndex2); return {ib->ZerosLike(start), ib->ZerosLike(limit), ib->ZerosLike(delta)}; }); -REG_BPROP_BUILDER("Pack").SetBody(BODYFUNC(ib) { - auto x = ib->GetInput(kIndex0); - auto out = ib->GetInput(kIndex1); +REG_BPROP_BUILDER("Pack").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) { auto dout = ib->GetInput(kIndex2); auto ret = ib->Emit("Unstack", {dout}, {{"num", ib->GetAttr("num")}, {"axis", ib->GetAttr("axis")}}); return {ret}; }); -REG_BPROP_BUILDER("Stack").SetBody(BODYFUNC(ib) { - auto x = ib->GetInput(kIndex0); - auto out = ib->GetInput(kIndex1); +REG_BPROP_BUILDER("Stack").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) { auto dout = ib->GetInput(kIndex2); auto ret = ib->Emit("Unstack", {dout}, {{"num", ib->GetAttr("num")}, {"axis", ib->GetAttr("axis")}}); return {ret}; @@ -352,10 +348,9 @@ REG_BPROP_BUILDER("ReverseV2").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) { return {dx}; }); -REG_BPROP_BUILDER("Unstack").SetUnusedInputs({i0}).SetBody(BODYFUNC(ib) { - auto out = ib->GetInput(kIndex1); +REG_BPROP_BUILDER("Unstack").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) { auto dout = ib->GetInput(kIndex2); - out = ib->Emit("Stack", {dout}, {{"axis", ib->GetAttr("axis")}}); + auto out = ib->Emit("Stack", {dout}, {{"axis", ib->GetAttr("axis")}}); return {out}; }); @@ -378,7 +373,7 @@ REG_BPROP_BUILDER("StridedSlice").SetUnusedInputs({i0, i4}).SetBody(BODYFUNC(ib) return {dx, dbegin, dend, dstrides}; }); -REG_BPROP_BUILDER("StridedSliceGrad").SetUnusedInputs({i0, i5}).SetBody(BODYFUNC(ib) { +REG_BPROP_BUILDER("StridedSliceGrad").SetUnusedInputs({i0, i1, i5}).SetBody(BODYFUNC(ib) { auto shapex = ib->GetInput(kIndex1); auto begin = ib->GetInput(kIndex2); auto end = ib->GetInput(kIndex3); @@ -393,7 +388,7 @@ REG_BPROP_BUILDER("StridedSliceGrad").SetUnusedInputs({i0, i5}).SetBody(BODYFUNC ib->ZerosLike(shapex), ib->ZerosLike(begin), ib->ZerosLike(end), ib->ZerosLike(strides)}; }); -REG_BPROP_BUILDER("Eye").SetUnusedInputs({i3, i4}).SetBody(BODYFUNC(ib) { +REG_BPROP_BUILDER("Eye").SetUnusedInputs({i0, i1, i3, i4}).SetBody(BODYFUNC(ib) { auto n = ib->GetInput(kIndex0); auto m = ib->GetInput(kIndex1); auto t = ib->GetInput(kIndex2); @@ -408,17 +403,17 @@ REG_BPROP_BUILDER("Select").SetUnusedInputs({i3}).SetBody(BODYFUNC(ib) { return {ib->ZerosLike(cond), ib->Select(cond, dout, ib->ZerosLike(x)), ib->Select(cond, ib->ZerosLike(y), dout)}; }); -REG_BPROP_BUILDER("OnesLike").SetUnusedInputs({i1, i2}).SetBody(BODYFUNC(ib) { +REG_BPROP_BUILDER("OnesLike").SetUnusedInputs({i0, i1, i2}).SetBody(BODYFUNC(ib) { auto x = ib->GetInput(kIndex0); return {ib->ZerosLike(x)}; }); -REG_BPROP_BUILDER("ZerosLike").SetUnusedInputs({i1, i2}).SetBody(BODYFUNC(ib) { +REG_BPROP_BUILDER("ZerosLike").SetUnusedInputs({i0, i1, i2}).SetBody(BODYFUNC(ib) { auto x = ib->GetInput(kIndex0); return {ib->ZerosLike(x)}; }); -REG_BPROP_BUILDER("ResizeNearestNeighbor").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib) { +REG_BPROP_BUILDER("ResizeNearestNeighbor").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) { auto x = ib->GetInput(kIndex0); auto dout = ib->GetInput(kIndex2); auto x_shape = ib->GetShape(x); @@ -431,7 +426,7 @@ REG_BPROP_BUILDER("ResizeNearestNeighbor").SetUnusedInputs({i1}).SetBody(BODYFUN return {out}; }); -REG_BPROP_BUILDER("GatherNd").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) { +REG_BPROP_BUILDER("GatherNd").SetUnusedInputs({i0, i2}).SetBody(BODYFUNC(ib) { auto x = ib->GetInput(kIndex0); auto indices = ib->GetInput(kIndex1); auto dout = ib->GetInput(kIndex3); @@ -439,7 +434,7 @@ REG_BPROP_BUILDER("GatherNd").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) { return {ib->Emit("ScatterNd", {indices, dout, shp}), ib->ZerosLike(indices)}; }); -REG_BPROP_BUILDER("ScatterNd").SetUnusedInputs({i1, i3}).SetBody(BODYFUNC(ib) { +REG_BPROP_BUILDER("ScatterNd").SetUnusedInputs({i1, i2, i3}).SetBody(BODYFUNC(ib) { auto indices = ib->GetInput(kIndex0); auto shape = ib->GetInput(kIndex2); auto dout = ib->GetInput(kIndex4); @@ -467,7 +462,7 @@ REG_BPROP_BUILDER("TensorScatterUpdate").SetUnusedInputs({i0, i3}).SetBody(BODYF return {x_grad, ib->ZerosLike(indices), update_grad}; }); -REG_BPROP_BUILDER("Flatten").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib) { +REG_BPROP_BUILDER("Flatten").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) { auto x = ib->GetInput(kIndex0); auto dout = ib->GetInput(kIndex2); auto dx = ib->Reshape(dout, ib->GetShape(x)); @@ -482,7 +477,7 @@ REG_BPROP_BUILDER("Reshape").SetUnusedInputs({i0, i1, i2}).SetBody(BODYFUNC(ib) return {ib->Reshape(dout, shapex), ib->ZerosLike(shp)}; }); -REG_BPROP_BUILDER("NonZero").SetUnusedInputs({i1, i2}).SetBody(BODYFUNC(ib) { +REG_BPROP_BUILDER("NonZero").SetUnusedInputs({i0, i1, i2}).SetBody(BODYFUNC(ib) { auto x = ib->GetInput(kIndex0); return {ib->ZerosLike(x)}; }); @@ -511,12 +506,12 @@ REG_BPROP_BUILDER("BatchMatMul").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) { return BinopGradCommonWithShift(ib, x, w, dx, dw, 2); }); -REG_BPROP_BUILDER("Argmax").SetUnusedInputs({i1, i2}).SetBody(BODYFUNC(ib) { +REG_BPROP_BUILDER("Argmax").SetUnusedInputs({i0, i1, i2}).SetBody(BODYFUNC(ib) { auto x = ib->GetInput(kIndex0); return {ib->ZerosLike(x)}; }); -REG_BPROP_BUILDER("Argmin").SetUnusedInputs({i1, i2}).SetBody(BODYFUNC(ib) { +REG_BPROP_BUILDER("Argmin").SetUnusedInputs({i0, i1, i2}).SetBody(BODYFUNC(ib) { auto x = ib->GetInput(kIndex0); return {ib->ZerosLike(x)}; }); @@ -655,7 +650,7 @@ REG_BPROP_BUILDER("TensorScatterMul").SetUnusedInputs({i3}).SetBody(BODYFUNC(ib) REG_BPROP_BUILDER("TensorScatterMax").SetBody(TensorScatterPossibleReplacement); REG_BPROP_BUILDER("TensorScatterMin").SetBody(TensorScatterPossibleReplacement); -REG_BPROP_BUILDER("IndexFill").SetUnusedInputs({i4}).SetBody(BODYFUNC(ib) { +REG_BPROP_BUILDER("IndexFill").SetUnusedInputs({i0, i4}).SetBody(BODYFUNC(ib) { auto x = ib->GetInput(kIndex0); auto dim = ib->GetInput(kIndex1); auto indices = ib->GetInput(kIndex2); @@ -767,7 +762,7 @@ REG_BPROP_BUILDER("BatchToSpaceND").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(i return {dx}; }); -REG_BPROP_BUILDER("BroadcastTo").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib) { +REG_BPROP_BUILDER("BroadcastTo").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) { auto x = ib->GetInput(kIndex0); auto dout = ib->GetInput(kIndex2); auto broadcast_shape = ib->GetAttr("shape"); @@ -820,13 +815,13 @@ REG_BPROP_BUILDER("ScatterUpdate").SetUnusedInputs({i0, i2, i3}).SetBody(BODYFUN return {dout, ib->ZerosLike(indices), ib->Emit("Gather", {dout, indices, ib->Tensor(0, kInt64)})}; }); -REG_BPROP_BUILDER("Fills").SetUnusedInputs({i2, i3}).SetBody(BODYFUNC(ib) { +REG_BPROP_BUILDER("Fills").SetUnusedInputs({i0, i1, i2, i3}).SetBody(BODYFUNC(ib) { auto x = ib->GetInput(kIndex0); auto value = ib->GetInput(kIndex1); return {ib->ZerosLike(x), ib->ZerosLike(value)}; }); -REG_BPROP_BUILDER("Cast").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) { +REG_BPROP_BUILDER("Cast").SetUnusedInputs({i0, i1, i2}).SetBody(BODYFUNC(ib) { auto x = ib->GetInput(kIndex0); auto t = ib->GetInput(kIndex1); auto dout = ib->GetInput(kIndex3); @@ -835,7 +830,7 @@ REG_BPROP_BUILDER("Cast").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) { return {dx, ib->ZerosLike(t)}; }); -REG_BPROP_BUILDER("ExpandDims").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) { +REG_BPROP_BUILDER("ExpandDims").SetUnusedInputs({i0, i1, i2}).SetBody(BODYFUNC(ib) { auto x = ib->GetInput(kIndex0); auto axis = ib->GetInput(kIndex1); auto dout = ib->GetInput(kIndex3); @@ -843,14 +838,14 @@ REG_BPROP_BUILDER("ExpandDims").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) { return {ib->Reshape(dout, shapex), ib->ZerosLike(axis)}; }); -REG_BPROP_BUILDER("Squeeze").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib) { +REG_BPROP_BUILDER("Squeeze").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) { auto x = ib->GetInput(kIndex0); auto dout = ib->GetInput(kIndex2); auto shapex = ib->GetShape(x); return {ib->Reshape(dout, shapex)}; }); -REG_BPROP_BUILDER("Padding").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib) { +REG_BPROP_BUILDER("Padding").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) { auto x = ib->GetInput(kIndex0); auto dout = ib->GetInput(kIndex2); auto shp = ib->GetShape(x); @@ -886,7 +881,7 @@ REG_BPROP_BUILDER("Split").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) { return {dx}; }); -REG_BPROP_BUILDER("Tile").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) { +REG_BPROP_BUILDER("Tile").SetUnusedInputs({i0, i2}).SetBody(BODYFUNC(ib) { auto x = ib->GetInput(kIndex0); auto input_multiples = ib->GetInput(kIndex1); auto dout = ib->GetInput(kIndex3); @@ -916,14 +911,14 @@ REG_BPROP_BUILDER("Tile").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) { REG_BPROP_BUILDER("Gather").SetUnusedInputs({i3}).SetBody(BinopGatherCommon); REG_BPROP_BUILDER("GatherV2").SetUnusedInputs({i3}).SetBody(BinopGatherCommon); -REG_BPROP_BUILDER("Fill").SetUnusedInputs({i3, i4}).SetBody(BODYFUNC(ib) { +REG_BPROP_BUILDER("Fill").SetUnusedInputs({i0, i1, i2, i3, i4}).SetBody(BODYFUNC(ib) { auto dtype = ib->GetInput(kIndex0); auto dims = ib->GetInput(kIndex1); auto x = ib->GetInput(kIndex2); return {ib->ZerosLike(dtype), ib->ZerosLike(dims), ib->ZerosLike(x)}; }); -REG_BPROP_BUILDER("MatrixDiagV3").SetUnusedInputs({i0, i5}).SetBody(BODYFUNC(ib) { +REG_BPROP_BUILDER("MatrixDiagV3").SetUnusedInputs({i0, i2, i3, i4, i5}).SetBody(BODYFUNC(ib) { auto k = ib->GetInput(kIndex1); auto num_rows = ib->GetInput(kIndex2); auto num_cols = ib->GetInput(kIndex3); @@ -934,7 +929,7 @@ REG_BPROP_BUILDER("MatrixDiagV3").SetUnusedInputs({i0, i5}).SetBody(BODYFUNC(ib) return {part, ib->ZerosLike(k), ib->ZerosLike(num_rows), ib->ZerosLike(num_cols), ib->ZerosLike(padding_value)}; }); -REG_BPROP_BUILDER("MatrixDiagPartV3").SetUnusedInputs({i3}).SetBody(BODYFUNC(ib) { +REG_BPROP_BUILDER("MatrixDiagPartV3").SetUnusedInputs({i0, i2, i3}).SetBody(BODYFUNC(ib) { auto align = ib->GetAttr("align"); auto x = ib->GetInput(kIndex0); auto k = ib->GetInput(kIndex1); @@ -949,7 +944,7 @@ REG_BPROP_BUILDER("MatrixDiagPartV3").SetUnusedInputs({i3}).SetBody(BODYFUNC(ib) return {diag, ib->ZerosLike(k), ib->ZerosLike(padding_value)}; }); -REG_BPROP_BUILDER("MatrixSetDiagV3").SetUnusedInputs({i0, i3}).SetBody(BODYFUNC(ib) { +REG_BPROP_BUILDER("MatrixSetDiagV3").SetUnusedInputs({i0, i1, i3}).SetBody(BODYFUNC(ib) { auto align = ib->GetAttr("align"); auto diagonal = ib->GetInput(kIndex1); auto k = ib->GetInput(kIndex2); @@ -964,37 +959,37 @@ REG_BPROP_BUILDER("MatrixSetDiagV3").SetUnusedInputs({i0, i3}).SetBody(BODYFUNC( return {x_cal, diagonal_cal, ib->ZerosLike(k)}; }); -REG_BPROP_BUILDER("LogNormalReverse").SetUnusedInputs({i1, i2}).SetBody(BODYFUNC(ib) { +REG_BPROP_BUILDER("LogNormalReverse").SetUnusedInputs({i0, i1, i2}).SetBody(BODYFUNC(ib) { auto input_data = ib->GetInput(kIndex0); return {ib->ZerosLike(input_data)}; }); -REG_BPROP_BUILDER("Shape").SetUnusedInputs({i1, i2}).SetBody(BODYFUNC(ib) { +REG_BPROP_BUILDER("Shape").SetUnusedInputs({i0, i1, i2}).SetBody(BODYFUNC(ib) { auto x = ib->GetInput(kIndex0); return {ib->ZerosLike(x)}; }); -REG_BPROP_BUILDER("Rank").SetUnusedInputs({i1, i2}).SetBody(BODYFUNC(ib) { +REG_BPROP_BUILDER("Rank").SetUnusedInputs({i0, i1, i2}).SetBody(BODYFUNC(ib) { auto x = ib->GetInput(kIndex0); return {ib->ZerosLike(x)}; }); -REG_BPROP_BUILDER("DynamicShape").SetUnusedInputs({i1, i2}).SetBody(BODYFUNC(ib) { +REG_BPROP_BUILDER("DynamicShape").SetUnusedInputs({i0, i1, i2}).SetBody(BODYFUNC(ib) { auto x = ib->GetInput(kIndex0); return {ib->ZerosLike(x)}; }); -REG_BPROP_BUILDER("TensorShape").SetUnusedInputs({i1, i2}).SetBody(BODYFUNC(ib) { +REG_BPROP_BUILDER("TensorShape").SetUnusedInputs({i0, i1, i2}).SetBody(BODYFUNC(ib) { auto x = ib->GetInput(kIndex0); return {ib->ZerosLike(x)}; }); -REG_BPROP_BUILDER("DType").SetUnusedInputs({i1, i2}).SetBody(BODYFUNC(ib) { +REG_BPROP_BUILDER("DType").SetUnusedInputs({i0, i1, i2}).SetBody(BODYFUNC(ib) { auto x = ib->GetInput(kIndex0); return {ib->ZerosLike(x)}; }); -REG_BPROP_BUILDER("StridedSliceV2").SetUnusedInputs({i4}).SetBody(BODYFUNC(ib) { +REG_BPROP_BUILDER("StridedSliceV2").SetUnusedInputs({i0, i4}).SetBody(BODYFUNC(ib) { auto x = ib->GetInput(kIndex0); auto begin = ib->GetInput(kIndex1); auto end = ib->GetInput(kIndex2); @@ -1010,7 +1005,7 @@ REG_BPROP_BUILDER("StridedSliceV2").SetUnusedInputs({i4}).SetBody(BODYFUNC(ib) { return {dx, ib->ZerosLike(begin), ib->ZerosLike(end), ib->ZerosLike(strides)}; }); -REG_BPROP_BUILDER("MaskedFill").SetUnusedInputs({i3}).SetBody(BODYFUNC(ib) { +REG_BPROP_BUILDER("MaskedFill").SetUnusedInputs({i2, i3}).SetBody(BODYFUNC(ib) { auto input_data = ib->GetInput(kIndex0); auto mask = ib->GetInput(kIndex1); auto value = ib->GetInput(kIndex2); @@ -1066,7 +1061,7 @@ REG_BPROP_BUILDER("IdentityN").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) { return {dout}; }); -REG_BPROP_BUILDER("ResizeNearestNeighborV2").SetUnusedInputs({i1, i2}).SetBody(BODYFUNC(ib) { +REG_BPROP_BUILDER("ResizeNearestNeighborV2").SetUnusedInputs({i0, i1, i2}).SetBody(BODYFUNC(ib) { auto align_corners = GetValue(ib->GetAttr("align_corners")); auto half_pixel_centers = GetValue(ib->GetAttr("half_pixel_centers")); auto data_format = GetValue(ib->GetAttr("format")); @@ -1106,7 +1101,7 @@ REG_BPROP_BUILDER("SegmentSum").SetUnusedInputs({i0, i2}).SetBody(BODYFUNC(ib) { return {ib->Cast(ib->Emit("Gather", {dout, segment_ids, ib->Tensor(0)}), dout_type), ib->ZerosLike(segment_ids)}; }); -REG_BPROP_BUILDER("EmbeddingLookup").SetUnusedInputs({i3}).SetBody(BODYFUNC(ib) { +REG_BPROP_BUILDER("EmbeddingLookup").SetUnusedInputs({i0, i3}).SetBody(BODYFUNC(ib) { auto x = ib->GetInput(kIndex0); auto indices = ib->GetInput(kIndex1); auto offset = ib->GetInput(kIndex2); @@ -1147,7 +1142,7 @@ REG_BPROP_BUILDER("SplitV").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) { return {dx}; }); -REG_BPROP_BUILDER("Col2Im").SetUnusedInputs({i0, i2}).SetBody(BODYFUNC(ib) { +REG_BPROP_BUILDER("Col2Im").SetUnusedInputs({i0, i1, i2}).SetBody(BODYFUNC(ib) { auto ksizes = GetValue>(ib->GetAttr("kernel_size")); auto dilations = GetValue>(ib->GetAttr("dilation")); auto strides = GetValue>(ib->GetAttr("stride")); @@ -1163,7 +1158,7 @@ REG_BPROP_BUILDER("Col2Im").SetUnusedInputs({i0, i2}).SetBody(BODYFUNC(ib) { return {dx, ib->ZerosLike(output_size)}; }); -REG_BPROP_BUILDER("ExtractVolumePatches").SetBody(BODYFUNC(ib) { +REG_BPROP_BUILDER("ExtractVolumePatches").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) { auto ksize = GetValue>(ib->GetAttr("kernel_size")); auto ksize_d = ksize.at(2); auto ksize_h = ksize.at(3); @@ -1213,7 +1208,7 @@ REG_BPROP_BUILDER("ExtractVolumePatches").SetBody(BODYFUNC(ib) { return {dx}; }); -REG_BPROP_BUILDER("AffineGrid").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) { +REG_BPROP_BUILDER("AffineGrid").SetUnusedInputs({i0, i2}).SetBody(BODYFUNC(ib) { auto align_corners = GetValue(ib->GetAttr("align_corners")); auto theta = ib->GetInput(kIndex0); auto output_size = GetIntList(ib->GetInput(kIndex1)); @@ -1331,7 +1326,7 @@ REG_BPROP_BUILDER("ScatterAddWithAxis").SetUnusedInputs({i0, i2, i3}).SetBody(BO return {dout, ib->ZerosLike(indices), update_grad}; }); -REG_BPROP_BUILDER("Expand").SetUnusedInputs({i1, i2}).SetBody(BODYFUNC(ib) { +REG_BPROP_BUILDER("Expand").SetUnusedInputs({i0, i1, i2}).SetBody(BODYFUNC(ib) { auto x = ib->GetInput(kIndex0); auto dout = ib->GetInput(kIndex3); auto dout_shape = ib->GetShape(dout); diff --git a/mindspore/ccsrc/pipeline/pynative/grad/bprop_expander/grad_ops/grad_sparse_ops.cc b/mindspore/ccsrc/pipeline/pynative/grad/bprop_expander/grad_ops/grad_sparse_ops.cc index 66cd5d41d9f..b8eafcefe63 100644 --- a/mindspore/ccsrc/pipeline/pynative/grad/bprop_expander/grad_ops/grad_sparse_ops.cc +++ b/mindspore/ccsrc/pipeline/pynative/grad/bprop_expander/grad_ops/grad_sparse_ops.cc @@ -151,14 +151,14 @@ NodePtrList CommonSparseSegmentBpropForCpu(const BpropIRBuilder *ib, bool with_s } } // namespace REG_BPROP_BUILDERS_BEGIN(GradSparseOps) -REG_BPROP_BUILDER("SparseToDense").SetUnusedInputs({i1, i3}).SetBody(BODYFUNC(ib) { +REG_BPROP_BUILDER("SparseToDense").SetUnusedInputs({i1, i2, i3}).SetBody(BODYFUNC(ib) { auto indices = ib->GetInput(kIndex0); auto dense_shape = ib->GetInput(kIndex2); auto dout = ib->GetInput(kIndex4); return {ib->ZerosLike(indices), ib->Emit("GatherNd", {dout, indices}), ib->ZerosLike(dense_shape)}; }); -REG_BPROP_BUILDER("SparseToDenseV2").SetUnusedInputs({i2, i3, i4}).SetBody(BODYFUNC(ib) { +REG_BPROP_BUILDER("SparseToDenseV2").SetUnusedInputs({i1, i2, i3, i4}).SetBody(BODYFUNC(ib) { auto indices = ib->GetInput(kIndex0); auto output_shape = ib->GetInput(kIndex1); auto dout = ib->GetInput(kIndex5); @@ -214,7 +214,7 @@ REG_BPROP_BUILDER("SparseTensorDenseMatmul").SetUnusedInputs({i4}).SetBody(BODYF return {ib->ZerosLike(indices), values_grad, ib->ZerosLike(dense_shape), dense_grad}; }); -REG_BPROP_BUILDER("SparseAdd").SetBody(BODYFUNC(ib) { +REG_BPROP_BUILDER("SparseAdd").SetUnusedInputs({i1, i2, i4, i5, i6}).SetBody(BODYFUNC(ib) { auto x1_indices = ib->GetInput(kIndex0); auto x1_values = ib->GetInput(kIndex1); auto x1_shape = ib->GetInput(kIndex2); @@ -359,19 +359,19 @@ REG_BPROP_BUILDER("CSRDiv").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) { dense_grad}; }); -REG_BPROP_BUILDER("CSR2COO").SetUnusedInputs({i2, i3}).SetBody(BODYFUNC(ib) { +REG_BPROP_BUILDER("CSR2COO").SetUnusedInputs({i0, i1, i2, i3}).SetBody(BODYFUNC(ib) { auto indptr = ib->GetInput(kIndex0); auto nnz = ib->GetInput(kIndex1); return {ib->ZerosLike(indptr), ib->ZerosLike(nnz)}; }); -REG_BPROP_BUILDER("COO2CSR").SetUnusedInputs({i2, i3}).SetBody(BODYFUNC(ib) { +REG_BPROP_BUILDER("COO2CSR").SetUnusedInputs({i0, i1, i2, i3}).SetBody(BODYFUNC(ib) { auto row_indices = ib->GetInput(kIndex0); auto height = ib->GetInput(kIndex1); return {ib->ZerosLike(row_indices), ib->ZerosLike(height)}; }); -REG_BPROP_BUILDER("MakeCOOTensor").SetUnusedInputs({i1, i2, i3}).SetBody(BODYFUNC(ib) { +REG_BPROP_BUILDER("MakeCOOTensor").SetUnusedInputs({i0, i1, i2, i3}).SetBody(BODYFUNC(ib) { auto indices = ib->GetInput(kIndex0); auto dout = ib->GetInput(kIndex4); auto dout_values = ib->TupleGetItem(dout, kIndex1); @@ -394,12 +394,12 @@ REG_BPROP_BUILDER("COOTensorGetValues").SetUnusedInputs({i1}).SetBody(BODYFUNC(i return {ib->MakeTuple({ib->ZerosLike(coo_tensor_indices), dout, coo_tensor_shape})}; }); -REG_BPROP_BUILDER("COOTensorGetDenseShape").SetUnusedInputs({i1, i2}).SetBody(BODYFUNC(ib) { +REG_BPROP_BUILDER("COOTensorGetDenseShape").SetUnusedInputs({i0, i1, i2}).SetBody(BODYFUNC(ib) { auto coo_tensor = ib->GetInput(kIndex0); return {ib->ZerosLike(coo_tensor)}; }); -REG_BPROP_BUILDER("MakeCSRTensor").SetUnusedInputs({i2, i3, i4}).SetBody(BODYFUNC(ib) { +REG_BPROP_BUILDER("MakeCSRTensor").SetUnusedInputs({i0, i1, i2, i3, i4}).SetBody(BODYFUNC(ib) { auto indptr = ib->GetInput(kIndex0); auto indices = ib->GetInput(kIndex1); auto dout = ib->GetInput(kIndex5); @@ -435,7 +435,7 @@ REG_BPROP_BUILDER("CSRTensorGetValues").SetUnusedInputs({i1}).SetBody(BODYFUNC(i return {ib->MakeTuple({ib->ZerosLike(csr_tensor_indptr), ib->ZerosLike(csr_tensor_indices), dout, csr_tensor_shape})}; }); -REG_BPROP_BUILDER("CSRTensorGetDenseShape").SetUnusedInputs({i1, i2}).SetBody(BODYFUNC(ib) { +REG_BPROP_BUILDER("CSRTensorGetDenseShape").SetUnusedInputs({i0, i1, i2}).SetBody(BODYFUNC(ib) { auto csr_tensor = ib->GetInput(kIndex0); return {ib->ZerosLike(csr_tensor)}; }); @@ -453,7 +453,7 @@ REG_BPROP_BUILDER("CSRSparseMatrixToDense").SetUnusedInputs({i5}).SetBody(BODYFU ib->TupleGetItem(res, kIndex3), ib->TupleGetItem(res, kIndex4)}; }); -REG_BPROP_BUILDER("DenseToCSRSparseMatrix").SetBody(BODYFUNC(ib) { +REG_BPROP_BUILDER("DenseToCSRSparseMatrix").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) { auto indices = ib->GetInput(kIndex1); auto out = ib->GetInput(kIndex2); auto dout = ib->GetInput(kIndex3); @@ -485,7 +485,7 @@ REG_BPROP_BUILDER("DenseToCSRSparseMatrix").SetBody(BODYFUNC(ib) { return {ib->Emit("CSRSparseMatrixToDense", {shape, batch_ptr, row_ptr, col_ind, dvalue}), ib->ZerosLike(indices)}; }); -REG_BPROP_BUILDER("SparseSoftmax").SetBody(BODYFUNC(ib) { +REG_BPROP_BUILDER("SparseSoftmax").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib) { auto indices = ib->GetInput(kIndex0); auto values = ib->GetInput(kIndex1); auto shape = ib->GetInput(kIndex2); @@ -518,29 +518,29 @@ REG_BPROP_BUILDER("CSRSparseMatrixToSparseTensor").SetUnusedInputs({i0, i1, i2, ib->TupleGetItem(dx, kIndex3), ib->TupleGetItem(dx, kIndex4)}; }); -REG_BPROP_BUILDER("SparseSegmentSqrtN").SetUnusedInputs({i0, i1, i2, i3, i4}).SetBody(BODYFUNC(ib) { +REG_BPROP_BUILDER("SparseSegmentSqrtN").SetUnusedInputs({i0, i3}).SetBody(BODYFUNC(ib) { return CommonSparseSegmentBprop(ib, "SparseSegmentSqrtNGrad", false); }); -REG_BPROP_BUILDER("SparseSegmentSqrtNWithNumSegments").SetUnusedInputs({i0, i1, i2, i3, i4, i5}).SetBody(BODYFUNC(ib) { +REG_BPROP_BUILDER("SparseSegmentSqrtNWithNumSegments").SetUnusedInputs({i0, i3}).SetBody(BODYFUNC(ib) { return CommonSparseSegmentBprop(ib, "SparseSegmentSqrtNGrad", true); }); -REG_BPROP_BUILDER("SparseSegmentSum").SetUnusedInputs({i0, i1, i2, i3, i4}).SetBody(BODYFUNC(ib) { +REG_BPROP_BUILDER("SparseSegmentSum").SetUnusedInputs({i0, i3}).SetBody(BODYFUNC(ib) { if (ib->GetTargetFromContext() == kGPUDevice) { return CommonSparseSegmentBprop(ib, "SparseSegmentSumGrad", false); } return CommonSparseSegmentBpropForCpu(ib, false); }); -REG_BPROP_BUILDER("SparseSegmentSumWithNumSegments").SetUnusedInputs({i0, i1, i2, i3, i4, i5}).SetBody(BODYFUNC(ib) { +REG_BPROP_BUILDER("SparseSegmentSumWithNumSegments").SetUnusedInputs({i0, i3}).SetBody(BODYFUNC(ib) { if (ib->GetTargetFromContext() == kGPUDevice) { return CommonSparseSegmentBprop(ib, "SparseSegmentSumGrad", true); } return CommonSparseSegmentBpropForCpu(ib, true); }); -REG_BPROP_BUILDER("SparseTensorDenseAdd").SetUnusedInputs({i1, i3, i4}).SetBody(BODYFUNC(ib) { +REG_BPROP_BUILDER("SparseTensorDenseAdd").SetUnusedInputs({i1, i2, i3, i4}).SetBody(BODYFUNC(ib) { auto x1_indices = ib->GetInput(kIndex0); auto x1_shape = ib->GetInput(kIndex2); auto dout = ib->GetInput(kIndex5);