diff --git a/mindspore/ccsrc/pipeline/pynative/grad/bprop_expander/grad_ops/grad_nn_ops.cc b/mindspore/ccsrc/pipeline/pynative/grad/bprop_expander/grad_ops/grad_nn_ops.cc index da1d058e2b5..a2be9415b27 100644 --- a/mindspore/ccsrc/pipeline/pynative/grad/bprop_expander/grad_ops/grad_nn_ops.cc +++ b/mindspore/ccsrc/pipeline/pynative/grad/bprop_expander/grad_ops/grad_nn_ops.cc @@ -174,7 +174,7 @@ REG_BPROP_BUILDER("ReLU").SetUnusedInputs({i0}).SetBody(BODYFUNC(ib) { return {dx}; }); -REG_BPROP_BUILDER("TopK").SetBody(BODYFUNC(ib) { +REG_BPROP_BUILDER("TopK").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) { auto input_x = ib->GetInput(kIndex0); auto out = ib->GetInput(kIndex2); auto dout = ib->GetInput(kIndex3); @@ -225,7 +225,7 @@ REG_BPROP_BUILDER("SigmoidCrossEntropyWithLogits").SetUnusedInputs({i2}).SetBody return {dx, ib->ZerosLike(y)}; }); -REG_BPROP_BUILDER("Pad").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib) { +REG_BPROP_BUILDER("Pad").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) { auto paddings = ib->GetAttr>>("paddings"); auto x = ib->GetInput(kIndex0); auto dout = ib->GetInput(kIndex2); @@ -238,7 +238,7 @@ REG_BPROP_BUILDER("Pad").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib) { return {dx}; }); -REG_BPROP_BUILDER("ROIAlign").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) { +REG_BPROP_BUILDER("ROIAlign").SetUnusedInputs({i0, i2}).SetBody(BODYFUNC(ib) { auto inputs = ib->GetInput(kIndex0); auto rois = ib->GetInput(kIndex1); auto dout = ib->GetInput(kIndex3); @@ -450,7 +450,7 @@ REG_BPROP_BUILDER("L2Normalize").SetBody(BODYFUNC(ib) { return {dx}; }); -REG_BPROP_BUILDER("SoftmaxCrossEntropyWithLogits").SetUnusedInputs({i0}).SetBody(BODYFUNC(ib) { +REG_BPROP_BUILDER("SoftmaxCrossEntropyWithLogits").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) { auto labels = ib->GetInput(kIndex1); auto out = ib->GetInput(kIndex2); auto dout = ib->GetInput(kIndex3); @@ -481,7 +481,7 @@ REG_BPROP_BUILDER("ResizeBilinear").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib) { return {dx}; }); -REG_BPROP_BUILDER("OneHot").SetUnusedInputs({i4, i5}).SetBody(BODYFUNC(ib) { +REG_BPROP_BUILDER("OneHot").SetUnusedInputs({i0, i1, i2, i3, i4, i5}).SetBody(BODYFUNC(ib) { auto indices = ib->GetInput(kIndex0); auto depth = ib->GetInput(kIndex1); auto on_value = ib->GetInput(kIndex2); @@ -508,7 +508,7 @@ REG_BPROP_BUILDER("L2Loss").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib) { return {dx}; }); -REG_BPROP_BUILDER("RNNTLoss").SetUnusedInputs({i0, i5}).SetBody(BODYFUNC(ib) { +REG_BPROP_BUILDER("RNNTLoss").SetUnusedInputs({i0, i1, i2, i3, i5}).SetBody(BODYFUNC(ib) { auto labels = ib->GetInput(kIndex1); auto act_lens = ib->GetInput(kIndex2); auto label_lens = ib->GetInput(kIndex3); @@ -678,7 +678,7 @@ REG_BPROP_BUILDER("MaxPoolGrad").SetUnusedInputs({i2, i3}).SetBody(BODYFUNC(ib) return {dx1, dx2, dgrad}; }); -REG_BPROP_BUILDER("UpsampleNearest3D").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib) { +REG_BPROP_BUILDER("UpsampleNearest3D").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) { auto input_x = ib->GetInput(kIndex0); auto dout = ib->GetInput(kIndex2); auto dx = ib->Emit("UpsampleNearest3DGrad", {dout}, @@ -688,7 +688,7 @@ REG_BPROP_BUILDER("UpsampleNearest3D").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib return {dx}; }); -REG_BPROP_BUILDER("UpsampleTrilinear3D").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib) { +REG_BPROP_BUILDER("UpsampleTrilinear3D").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) { auto x = ib->GetInput(kIndex0); auto dout = ib->GetInput(kIndex2); auto dx = ib->Emit("UpsampleTrilinear3DGrad", {dout}, @@ -702,7 +702,7 @@ REG_BPROP_BUILDER("UpsampleTrilinear3D").SetUnusedInputs({i1}).SetBody(BODYFUNC( REG_BPROP_BUILDER("Dropout2D").SetBody(Dropout2DBpropExpander); REG_BPROP_BUILDER("Dropout3D").SetBody(Dropout2DBpropExpander); -REG_BPROP_BUILDER("CTCLoss").SetUnusedInputs({i0}).SetBody(BODYFUNC(ib) { +REG_BPROP_BUILDER("CTCLoss").SetUnusedInputs({i0, i1, i2, i3}).SetBody(BODYFUNC(ib) { auto labels_indices = ib->GetInput(kIndex1); auto labels_values = ib->GetInput(kIndex2); auto sequence_length = ib->GetInput(kIndex3); @@ -772,7 +772,7 @@ REG_BPROP_BUILDER("AvgPool").SetBody(BODYFUNC(ib) { return {dx}; }); -REG_BPROP_BUILDER("AvgPool3D").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib) { +REG_BPROP_BUILDER("AvgPool3D").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) { auto x = ib->GetInput(kIndex0); auto dout = ib->GetInput(kIndex2); auto x_shape = ib->GetShape(x); @@ -822,7 +822,7 @@ REG_BPROP_BUILDER("ReLUV2").SetUnusedInputs({i0}).SetBody(BODYFUNC(ib) { return {dx}; }); -REG_BPROP_BUILDER("BiasAddGrad").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib) { +REG_BPROP_BUILDER("BiasAddGrad").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) { auto data_format = GetValue(ib->GetAttr("format")); auto dy = ib->GetInput(kIndex0); auto dout = ib->GetInput(kIndex2); @@ -850,7 +850,7 @@ REG_BPROP_BUILDER("BiasAddGrad").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib) { return {tiled_grad}; }); -REG_BPROP_BUILDER("ExtractImagePatches").SetBody(BODYFUNC(ib) { +REG_BPROP_BUILDER("ExtractImagePatches").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) { auto ksizes_row = GetValue>(ib->GetAttr("ksizes"))[2]; auto ksizes_col = GetValue>(ib->GetAttr("ksizes"))[3]; auto x = ib->GetInput(kIndex0); @@ -957,7 +957,7 @@ REG_BPROP_BUILDER("Softsign").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib) { return {dx}; }); -REG_BPROP_BUILDER("Tanh").SetBody(BODYFUNC(ib) { +REG_BPROP_BUILDER("Tanh").SetUnusedInputs({i0}).SetBody(BODYFUNC(ib) { auto x = ib->GetInput(kIndex0); auto out = ib->GetInput(kIndex1); auto dout = ib->GetInput(kIndex2); @@ -988,7 +988,7 @@ REG_BPROP_BUILDER("Gelu").SetBody(GeLUBpropExpander); REG_BPROP_BUILDER("FastGeLU").SetUnusedInputs({i1}).SetBody(FastGeLUBpropExpander); REG_BPROP_BUILDER("FastGelu").SetUnusedInputs({i1}).SetBody(FastGeLUBpropExpander); -REG_BPROP_BUILDER("InstanceNorm").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) { +REG_BPROP_BUILDER("InstanceNorm").SetUnusedInputs({i2, i3, i4}).SetBody(BODYFUNC(ib) { auto x = ib->GetInput(kIndex0); auto gamma = ib->GetInput(kIndex1); auto mean = ib->GetInput(kIndex3); @@ -1032,7 +1032,7 @@ REG_BPROP_BUILDER("BatchNorm").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) { return {dx, dscale, dbias, ib->ZerosLike(mean), ib->ZerosLike(variance)}; }); -REG_BPROP_BUILDER("BatchNormGrad").SetUnusedInputs({i6}).SetBody(BODYFUNC(ib) { +REG_BPROP_BUILDER("BatchNormGrad").SetUnusedInputs({i5, i6}).SetBody(BODYFUNC(ib) { auto dy = ib->GetInput(kIndex0); auto x = ib->GetInput(kIndex1); auto scale = ib->GetInput(kIndex2); @@ -1052,7 +1052,7 @@ REG_BPROP_BUILDER("BatchNormGrad").SetUnusedInputs({i6}).SetBody(BODYFUNC(ib) { return {ddy, dx, dscale, ib->ZerosLike(mean), ib->ZerosLike(variance), ib->ZerosLike(reserve)}; }); -REG_BPROP_BUILDER("Softmax").SetBody(BODYFUNC(ib) { +REG_BPROP_BUILDER("Softmax").SetUnusedInputs({i0}).SetBody(BODYFUNC(ib) { auto axis = GetValue>(ib->GetAttr("axis")); auto one_axis = axis[0]; auto x = ib->GetInput(kIndex0); @@ -1337,7 +1337,7 @@ REG_BPROP_BUILDER("MultiMarginLoss").SetUnusedInputs({i3}).SetBody(BODYFUNC(ib) return {dx, ib->ZerosLike(target), ib->ZerosLike(weight)}; }); -REG_BPROP_BUILDER("DropoutGenMask").SetUnusedInputs({i2, i3}).SetBody(BODYFUNC(ib) { +REG_BPROP_BUILDER("DropoutGenMask").SetUnusedInputs({i0, i1, i2, i3}).SetBody(BODYFUNC(ib) { auto shape = ib->GetInput(kIndex0); auto keep_prob = ib->GetInput(kIndex1); return {ib->ZerosLike(shape), ib->ZerosLike(keep_prob)}; @@ -1350,7 +1350,7 @@ REG_BPROP_BUILDER("DropoutDoMask").SetUnusedInputs({i0, i3}).SetBody(BODYFUNC(ib return {ib->Emit("DropoutDoMask", {dout, y, keep_prob}), ib->ZerosLike(y), ib->ZerosLike(keep_prob)}; }); -REG_BPROP_BUILDER("ReluGrad").SetUnusedInputs({i0, i2}).SetBody(BODYFUNC(ib) { +REG_BPROP_BUILDER("ReluGrad").SetUnusedInputs({i0, i1, i2}).SetBody(BODYFUNC(ib) { auto y = ib->GetInput(kIndex1); auto dout = ib->GetInput(kIndex3); auto dgrad = ib->Emit("ReluGrad", {dout, y}); @@ -1390,7 +1390,7 @@ REG_BPROP_BUILDER("GridSampler2D").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) { return {dx, dgrid}; }); -REG_BPROP_BUILDER("ResizeLinear1D").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) { +REG_BPROP_BUILDER("ResizeLinear1D").SetUnusedInputs({i1, i2}).SetBody(BODYFUNC(ib) { auto input_x = ib->GetInput(kIndex0); auto size = ib->GetInput(kIndex1); auto dout = ib->GetInput(kIndex3); @@ -1441,7 +1441,7 @@ REG_BPROP_BUILDER("MaxUnpool3D").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) { return {dx, dargmax}; }); -REG_BPROP_BUILDER("NthElement").SetBody(BODYFUNC(ib) { +REG_BPROP_BUILDER("NthElement").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib) { auto input_x = ib->GetInput(kIndex0); auto n = ib->GetInput(kIndex1); auto out = ib->GetInput(kIndex2); @@ -1452,7 +1452,7 @@ REG_BPROP_BUILDER("NthElement").SetBody(BODYFUNC(ib) { return {ib->Mul(ib->Div(indicators, num_select), dout), ib->ZerosLike(n)}; }); -REG_BPROP_BUILDER("AdaptiveAvgPool3D").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib) { +REG_BPROP_BUILDER("AdaptiveAvgPool3D").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) { auto x = ib->GetInput(kIndex0); auto dout = ib->GetInput(kIndex2); auto x_shape = ib->Tensor(ib->GetShape(x)); @@ -1460,7 +1460,7 @@ REG_BPROP_BUILDER("AdaptiveAvgPool3D").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib return {dx}; }); -REG_BPROP_BUILDER("AdaptiveAvgPool2DV1").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib) { +REG_BPROP_BUILDER("AdaptiveAvgPool2DV1").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) { auto x = ib->GetInput(kIndex0); auto dout = ib->GetInput(kIndex2); auto dx = ib->Emit("AdaptiveAvgPool2DGradV1", {dout}, {{"orig_input_shape", MakeValue(ib->GetShape(x))}}); @@ -1479,7 +1479,7 @@ REG_BPROP_BUILDER("FractionalMaxPool").SetBody(BODYFUNC(ib) { return {dx}; }); -REG_BPROP_BUILDER("FractionalMaxPool3DWithFixedKsize").SetBody(BODYFUNC(ib) { +REG_BPROP_BUILDER("FractionalMaxPool3DWithFixedKsize").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib) { auto x = ib->GetInput(kIndex0); auto random_samples = ib->GetInput(kIndex1); auto out = ib->GetInput(kIndex2); @@ -1489,7 +1489,7 @@ REG_BPROP_BUILDER("FractionalMaxPool3DWithFixedKsize").SetBody(BODYFUNC(ib) { return {dx, ib->ZerosLike(random_samples)}; }); -REG_BPROP_BUILDER("FractionalAvgPool").SetBody(BODYFUNC(ib) { +REG_BPROP_BUILDER("FractionalAvgPool").SetUnusedInputs({i0}).SetBody(BODYFUNC(ib) { auto x = ib->GetInput(kIndex0); auto out = ib->GetInput(kIndex1); auto dout = ib->GetInput(kIndex2); @@ -1503,7 +1503,7 @@ REG_BPROP_BUILDER("FractionalAvgPool").SetBody(BODYFUNC(ib) { return {dx}; }); -REG_BPROP_BUILDER("PSROIPooling").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) { +REG_BPROP_BUILDER("PSROIPooling").SetUnusedInputs({i0, i2}).SetBody(BODYFUNC(ib) { auto spatial_scale = ib->GetAttr("spatial_scale"); auto group_size = ib->GetAttr("group_size"); auto output_dim = ib->GetAttr("output_dim"); @@ -1525,7 +1525,7 @@ REG_BPROP_BUILDER("PSROIPooling").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) { return {dx, ib->ZerosLike(rois)}; }); -REG_BPROP_BUILDER("AvgPoolV1").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib) { +REG_BPROP_BUILDER("AvgPoolV1").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) { auto x = ib->GetInput(kIndex0); auto dout = ib->GetInput(kIndex2); auto orig_input_shape = ib->Value(ib->GetShape(x)); @@ -1589,7 +1589,7 @@ REG_BPROP_BUILDER("InstanceNormV2").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) { return {dx, dgamma, dbeta, ib->ZerosLike(mean), ib->ZerosLike(variance)}; }); -REG_BPROP_BUILDER("FractionalMaxPoolWithFixedKsize").SetBody(BODYFUNC(ib) { +REG_BPROP_BUILDER("FractionalMaxPoolWithFixedKsize").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib) { auto x = ib->GetInput(kIndex0); auto random_samples = ib->GetInput(kIndex1); auto out = ib->GetInput(kIndex2); @@ -1609,7 +1609,7 @@ REG_BPROP_BUILDER("AdaptiveAvgPool2D").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib return {dx}; }); -REG_BPROP_BUILDER("SparseSoftmaxCrossEntropyWithLogitsV2").SetBody(BODYFUNC(ib) { +REG_BPROP_BUILDER("SparseSoftmaxCrossEntropyWithLogitsV2").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib) { auto logits = ib->GetInput(kIndex0); auto labels = ib->GetInput(kIndex1); auto out = ib->GetInput(kIndex2); @@ -1658,8 +1658,7 @@ REG_BPROP_BUILDER("DepthwiseConv2dNative").SetUnusedInputs({i2}).SetBody(BODYFUN return {dx, dw}; }); -REG_BPROP_BUILDER("PadV3").SetUnusedInputs({i3}).SetBody(BODYFUNC(ib) { - auto x = ib->GetInput(kIndex0); +REG_BPROP_BUILDER("PadV3").SetUnusedInputs({i0, i1, i3}).SetBody(BODYFUNC(ib) { auto paddings = ib->GetInput(kIndex1); auto constant_values = ib->GetInput(kIndex2); auto dout = ib->GetInput(kIndex4); diff --git a/mindspore/ccsrc/pipeline/pynative/grad/bprop_expander/grad_ops/grad_other_ops.cc b/mindspore/ccsrc/pipeline/pynative/grad/bprop_expander/grad_ops/grad_other_ops.cc index 9e765a8b510..72e5f849c8e 100644 --- a/mindspore/ccsrc/pipeline/pynative/grad/bprop_expander/grad_ops/grad_other_ops.cc +++ b/mindspore/ccsrc/pipeline/pynative/grad/bprop_expander/grad_ops/grad_other_ops.cc @@ -20,24 +20,24 @@ namespace mindspore::expander::bprop { REG_BPROP_BUILDERS_BEGIN(GradOtherOps) -REG_BPROP_BUILDER("Assign").SetUnusedInputs({i0, i2}).SetBody(BODYFUNC(ib) { +REG_BPROP_BUILDER("Assign").SetUnusedInputs({i0, i1, i2}).SetBody(BODYFUNC(ib) { auto y = ib->GetInput(kIndex1); auto dout = ib->GetInput(kIndex3); return {dout, ib->ZerosLike(y)}; }); -REG_BPROP_BUILDER("InvertPermutation").SetUnusedInputs({i1, i2}).SetBody(BODYFUNC(ib) { +REG_BPROP_BUILDER("InvertPermutation").SetUnusedInputs({i0, i1, i2}).SetBody(BODYFUNC(ib) { auto x = ib->GetInput(kIndex0); return {ib->ZerosLike(x)}; }); -REG_BPROP_BUILDER("IOU").SetUnusedInputs({i2, i3}).SetBody(BODYFUNC(ib) { +REG_BPROP_BUILDER("IOU").SetUnusedInputs({i0, i1, i2, i3}).SetBody(BODYFUNC(ib) { auto x = ib->GetInput(kIndex0); auto y = ib->GetInput(kIndex1); return {ib->ZerosLike(x), ib->ZerosLike(y)}; }); -REG_BPROP_BUILDER("SyncBatchNorm").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) { +REG_BPROP_BUILDER("SyncBatchNorm").SetUnusedInputs({i2, i3, i4}).SetBody(BODYFUNC(ib) { auto x = ib->GetInput(kIndex0); auto scale = ib->GetInput(kIndex1); auto mean = ib->GetInput(kIndex3); diff --git a/mindspore/ccsrc/pipeline/pynative/grad/bprop_expander/grad_ops/grad_quant_ops.cc b/mindspore/ccsrc/pipeline/pynative/grad/bprop_expander/grad_ops/grad_quant_ops.cc index a912e15976b..d3c14d2fae5 100644 --- a/mindspore/ccsrc/pipeline/pynative/grad/bprop_expander/grad_ops/grad_quant_ops.cc +++ b/mindspore/ccsrc/pipeline/pynative/grad/bprop_expander/grad_ops/grad_quant_ops.cc @@ -20,26 +20,26 @@ namespace mindspore::expander::bprop { REG_BPROP_BUILDERS_BEGIN(GradQuantOps) -REG_BPROP_BUILDER("BNTrainingReduce").SetUnusedInputs({i1, i2}).SetBody(BODYFUNC(ib) { +REG_BPROP_BUILDER("BNTrainingReduce").SetUnusedInputs({i0, i1, i2}).SetBody(BODYFUNC(ib) { auto x = ib->GetInput(kIndex0); return {ib->ZerosLike(x)}; }); -REG_BPROP_BUILDER("MinMaxUpdatePerLayer").SetUnusedInputs({i3, i4}).SetBody(BODYFUNC(ib) { +REG_BPROP_BUILDER("MinMaxUpdatePerLayer").SetUnusedInputs({i0, i1, i2, i3, i4}).SetBody(BODYFUNC(ib) { auto x = ib->GetInput(kIndex0); auto x_min = ib->GetInput(kIndex1); auto x_max = ib->GetInput(kIndex2); return {ib->ZerosLike(x), ib->ZerosLike(x_min), ib->ZerosLike(x_max)}; }); -REG_BPROP_BUILDER("MinMaxUpdatePerChannel").SetUnusedInputs({i3, i4}).SetBody(BODYFUNC(ib) { +REG_BPROP_BUILDER("MinMaxUpdatePerChannel").SetUnusedInputs({i0, i1, i2, i3, i4}).SetBody(BODYFUNC(ib) { auto x = ib->GetInput(kIndex0); auto x_min = ib->GetInput(kIndex1); auto x_max = ib->GetInput(kIndex2); return {ib->ZerosLike(x), ib->ZerosLike(x_min), ib->ZerosLike(x_max)}; }); -REG_BPROP_BUILDER("WtsARQ").SetUnusedInputs({i0, i3}).SetBody(BODYFUNC(ib) { +REG_BPROP_BUILDER("WtsARQ").SetUnusedInputs({i0, i1, i2, i3}).SetBody(BODYFUNC(ib) { auto w_min = ib->GetInput(kIndex1); auto w_max = ib->GetInput(kIndex2); auto dout = ib->GetInput(kIndex4); @@ -96,7 +96,7 @@ REG_BPROP_BUILDER("FakeQuantPerChannel").SetUnusedInputs({i3}).SetBody(BODYFUNC( return {dx, ib->ZerosLike(x_min), ib->ZerosLike(x_max)}; }); -REG_BPROP_BUILDER("BatchNormFold").SetBody(BODYFUNC(ib) { +REG_BPROP_BUILDER("BatchNormFold").SetUnusedInputs({i1, i2, i3}).SetBody(BODYFUNC(ib) { auto x = ib->GetInput(kIndex0); auto mean = ib->GetInput(kIndex1); auto variance = ib->GetInput(kIndex2); @@ -155,7 +155,7 @@ REG_BPROP_BUILDER("BatchNormFold2").SetUnusedInputs({i1, i8}).SetBody(BODYFUNC(i ib->ZerosLike(global_step)}; }); -REG_BPROP_BUILDER("BatchNormFoldD").SetBody(BODYFUNC(ib) { +REG_BPROP_BUILDER("BatchNormFoldD").SetUnusedInputs({i1, i2, i3, i4}).SetBody(BODYFUNC(ib) { auto x = ib->GetInput(kIndex0); auto x_sum = ib->GetInput(kIndex1); auto x_square_sum = ib->GetInput(kIndex2);