!47477 Bprop Unused Inputs Removement

Merge pull request !47477 from jiaoy1224/unused
This commit is contained in:
i-robot 2023-01-04 06:59:40 +00:00 committed by Gitee
commit adafe66daa
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
3 changed files with 39 additions and 40 deletions

View File

@ -174,7 +174,7 @@ REG_BPROP_BUILDER("ReLU").SetUnusedInputs({i0}).SetBody(BODYFUNC(ib) {
return {dx}; return {dx};
}); });
REG_BPROP_BUILDER("TopK").SetBody(BODYFUNC(ib) { REG_BPROP_BUILDER("TopK").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) {
auto input_x = ib->GetInput(kIndex0); auto input_x = ib->GetInput(kIndex0);
auto out = ib->GetInput(kIndex2); auto out = ib->GetInput(kIndex2);
auto dout = ib->GetInput(kIndex3); auto dout = ib->GetInput(kIndex3);
@ -225,7 +225,7 @@ REG_BPROP_BUILDER("SigmoidCrossEntropyWithLogits").SetUnusedInputs({i2}).SetBody
return {dx, ib->ZerosLike(y)}; return {dx, ib->ZerosLike(y)};
}); });
REG_BPROP_BUILDER("Pad").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib) { REG_BPROP_BUILDER("Pad").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) {
auto paddings = ib->GetAttr<std::vector<std::vector<int64_t>>>("paddings"); auto paddings = ib->GetAttr<std::vector<std::vector<int64_t>>>("paddings");
auto x = ib->GetInput(kIndex0); auto x = ib->GetInput(kIndex0);
auto dout = ib->GetInput(kIndex2); auto dout = ib->GetInput(kIndex2);
@ -238,7 +238,7 @@ REG_BPROP_BUILDER("Pad").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib) {
return {dx}; return {dx};
}); });
REG_BPROP_BUILDER("ROIAlign").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) { REG_BPROP_BUILDER("ROIAlign").SetUnusedInputs({i0, i2}).SetBody(BODYFUNC(ib) {
auto inputs = ib->GetInput(kIndex0); auto inputs = ib->GetInput(kIndex0);
auto rois = ib->GetInput(kIndex1); auto rois = ib->GetInput(kIndex1);
auto dout = ib->GetInput(kIndex3); auto dout = ib->GetInput(kIndex3);
@ -450,7 +450,7 @@ REG_BPROP_BUILDER("L2Normalize").SetBody(BODYFUNC(ib) {
return {dx}; return {dx};
}); });
REG_BPROP_BUILDER("SoftmaxCrossEntropyWithLogits").SetUnusedInputs({i0}).SetBody(BODYFUNC(ib) { REG_BPROP_BUILDER("SoftmaxCrossEntropyWithLogits").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) {
auto labels = ib->GetInput(kIndex1); auto labels = ib->GetInput(kIndex1);
auto out = ib->GetInput(kIndex2); auto out = ib->GetInput(kIndex2);
auto dout = ib->GetInput(kIndex3); auto dout = ib->GetInput(kIndex3);
@ -481,7 +481,7 @@ REG_BPROP_BUILDER("ResizeBilinear").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib) {
return {dx}; return {dx};
}); });
REG_BPROP_BUILDER("OneHot").SetUnusedInputs({i4, i5}).SetBody(BODYFUNC(ib) { REG_BPROP_BUILDER("OneHot").SetUnusedInputs({i0, i1, i2, i3, i4, i5}).SetBody(BODYFUNC(ib) {
auto indices = ib->GetInput(kIndex0); auto indices = ib->GetInput(kIndex0);
auto depth = ib->GetInput(kIndex1); auto depth = ib->GetInput(kIndex1);
auto on_value = ib->GetInput(kIndex2); auto on_value = ib->GetInput(kIndex2);
@ -508,7 +508,7 @@ REG_BPROP_BUILDER("L2Loss").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib) {
return {dx}; return {dx};
}); });
REG_BPROP_BUILDER("RNNTLoss").SetUnusedInputs({i0, i5}).SetBody(BODYFUNC(ib) { REG_BPROP_BUILDER("RNNTLoss").SetUnusedInputs({i0, i1, i2, i3, i5}).SetBody(BODYFUNC(ib) {
auto labels = ib->GetInput(kIndex1); auto labels = ib->GetInput(kIndex1);
auto act_lens = ib->GetInput(kIndex2); auto act_lens = ib->GetInput(kIndex2);
auto label_lens = ib->GetInput(kIndex3); auto label_lens = ib->GetInput(kIndex3);
@ -678,7 +678,7 @@ REG_BPROP_BUILDER("MaxPoolGrad").SetUnusedInputs({i2, i3}).SetBody(BODYFUNC(ib)
return {dx1, dx2, dgrad}; return {dx1, dx2, dgrad};
}); });
REG_BPROP_BUILDER("UpsampleNearest3D").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib) { REG_BPROP_BUILDER("UpsampleNearest3D").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) {
auto input_x = ib->GetInput(kIndex0); auto input_x = ib->GetInput(kIndex0);
auto dout = ib->GetInput(kIndex2); auto dout = ib->GetInput(kIndex2);
auto dx = ib->Emit("UpsampleNearest3DGrad", {dout}, auto dx = ib->Emit("UpsampleNearest3DGrad", {dout},
@ -688,7 +688,7 @@ REG_BPROP_BUILDER("UpsampleNearest3D").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib
return {dx}; return {dx};
}); });
REG_BPROP_BUILDER("UpsampleTrilinear3D").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib) { REG_BPROP_BUILDER("UpsampleTrilinear3D").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) {
auto x = ib->GetInput(kIndex0); auto x = ib->GetInput(kIndex0);
auto dout = ib->GetInput(kIndex2); auto dout = ib->GetInput(kIndex2);
auto dx = ib->Emit("UpsampleTrilinear3DGrad", {dout}, auto dx = ib->Emit("UpsampleTrilinear3DGrad", {dout},
@ -702,7 +702,7 @@ REG_BPROP_BUILDER("UpsampleTrilinear3D").SetUnusedInputs({i1}).SetBody(BODYFUNC(
REG_BPROP_BUILDER("Dropout2D").SetBody(Dropout2DBpropExpander); REG_BPROP_BUILDER("Dropout2D").SetBody(Dropout2DBpropExpander);
REG_BPROP_BUILDER("Dropout3D").SetBody(Dropout2DBpropExpander); REG_BPROP_BUILDER("Dropout3D").SetBody(Dropout2DBpropExpander);
REG_BPROP_BUILDER("CTCLoss").SetUnusedInputs({i0}).SetBody(BODYFUNC(ib) { REG_BPROP_BUILDER("CTCLoss").SetUnusedInputs({i0, i1, i2, i3}).SetBody(BODYFUNC(ib) {
auto labels_indices = ib->GetInput(kIndex1); auto labels_indices = ib->GetInput(kIndex1);
auto labels_values = ib->GetInput(kIndex2); auto labels_values = ib->GetInput(kIndex2);
auto sequence_length = ib->GetInput(kIndex3); auto sequence_length = ib->GetInput(kIndex3);
@ -772,7 +772,7 @@ REG_BPROP_BUILDER("AvgPool").SetBody(BODYFUNC(ib) {
return {dx}; return {dx};
}); });
REG_BPROP_BUILDER("AvgPool3D").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib) { REG_BPROP_BUILDER("AvgPool3D").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) {
auto x = ib->GetInput(kIndex0); auto x = ib->GetInput(kIndex0);
auto dout = ib->GetInput(kIndex2); auto dout = ib->GetInput(kIndex2);
auto x_shape = ib->GetShape(x); auto x_shape = ib->GetShape(x);
@ -822,7 +822,7 @@ REG_BPROP_BUILDER("ReLUV2").SetUnusedInputs({i0}).SetBody(BODYFUNC(ib) {
return {dx}; return {dx};
}); });
REG_BPROP_BUILDER("BiasAddGrad").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib) { REG_BPROP_BUILDER("BiasAddGrad").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) {
auto data_format = GetValue<std::string>(ib->GetAttr("format")); auto data_format = GetValue<std::string>(ib->GetAttr("format"));
auto dy = ib->GetInput(kIndex0); auto dy = ib->GetInput(kIndex0);
auto dout = ib->GetInput(kIndex2); auto dout = ib->GetInput(kIndex2);
@ -850,7 +850,7 @@ REG_BPROP_BUILDER("BiasAddGrad").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib) {
return {tiled_grad}; return {tiled_grad};
}); });
REG_BPROP_BUILDER("ExtractImagePatches").SetBody(BODYFUNC(ib) { REG_BPROP_BUILDER("ExtractImagePatches").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) {
auto ksizes_row = GetValue<std::vector<int64_t>>(ib->GetAttr("ksizes"))[2]; auto ksizes_row = GetValue<std::vector<int64_t>>(ib->GetAttr("ksizes"))[2];
auto ksizes_col = GetValue<std::vector<int64_t>>(ib->GetAttr("ksizes"))[3]; auto ksizes_col = GetValue<std::vector<int64_t>>(ib->GetAttr("ksizes"))[3];
auto x = ib->GetInput(kIndex0); auto x = ib->GetInput(kIndex0);
@ -957,7 +957,7 @@ REG_BPROP_BUILDER("Softsign").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib) {
return {dx}; return {dx};
}); });
REG_BPROP_BUILDER("Tanh").SetBody(BODYFUNC(ib) { REG_BPROP_BUILDER("Tanh").SetUnusedInputs({i0}).SetBody(BODYFUNC(ib) {
auto x = ib->GetInput(kIndex0); auto x = ib->GetInput(kIndex0);
auto out = ib->GetInput(kIndex1); auto out = ib->GetInput(kIndex1);
auto dout = ib->GetInput(kIndex2); auto dout = ib->GetInput(kIndex2);
@ -988,7 +988,7 @@ REG_BPROP_BUILDER("Gelu").SetBody(GeLUBpropExpander);
REG_BPROP_BUILDER("FastGeLU").SetUnusedInputs({i1}).SetBody(FastGeLUBpropExpander); REG_BPROP_BUILDER("FastGeLU").SetUnusedInputs({i1}).SetBody(FastGeLUBpropExpander);
REG_BPROP_BUILDER("FastGelu").SetUnusedInputs({i1}).SetBody(FastGeLUBpropExpander); REG_BPROP_BUILDER("FastGelu").SetUnusedInputs({i1}).SetBody(FastGeLUBpropExpander);
REG_BPROP_BUILDER("InstanceNorm").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) { REG_BPROP_BUILDER("InstanceNorm").SetUnusedInputs({i2, i3, i4}).SetBody(BODYFUNC(ib) {
auto x = ib->GetInput(kIndex0); auto x = ib->GetInput(kIndex0);
auto gamma = ib->GetInput(kIndex1); auto gamma = ib->GetInput(kIndex1);
auto mean = ib->GetInput(kIndex3); auto mean = ib->GetInput(kIndex3);
@ -1032,7 +1032,7 @@ REG_BPROP_BUILDER("BatchNorm").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {
return {dx, dscale, dbias, ib->ZerosLike(mean), ib->ZerosLike(variance)}; return {dx, dscale, dbias, ib->ZerosLike(mean), ib->ZerosLike(variance)};
}); });
REG_BPROP_BUILDER("BatchNormGrad").SetUnusedInputs({i6}).SetBody(BODYFUNC(ib) { REG_BPROP_BUILDER("BatchNormGrad").SetUnusedInputs({i5, i6}).SetBody(BODYFUNC(ib) {
auto dy = ib->GetInput(kIndex0); auto dy = ib->GetInput(kIndex0);
auto x = ib->GetInput(kIndex1); auto x = ib->GetInput(kIndex1);
auto scale = ib->GetInput(kIndex2); auto scale = ib->GetInput(kIndex2);
@ -1052,7 +1052,7 @@ REG_BPROP_BUILDER("BatchNormGrad").SetUnusedInputs({i6}).SetBody(BODYFUNC(ib) {
return {ddy, dx, dscale, ib->ZerosLike(mean), ib->ZerosLike(variance), ib->ZerosLike(reserve)}; return {ddy, dx, dscale, ib->ZerosLike(mean), ib->ZerosLike(variance), ib->ZerosLike(reserve)};
}); });
REG_BPROP_BUILDER("Softmax").SetBody(BODYFUNC(ib) { REG_BPROP_BUILDER("Softmax").SetUnusedInputs({i0}).SetBody(BODYFUNC(ib) {
auto axis = GetValue<std::vector<int64_t>>(ib->GetAttr("axis")); auto axis = GetValue<std::vector<int64_t>>(ib->GetAttr("axis"));
auto one_axis = axis[0]; auto one_axis = axis[0];
auto x = ib->GetInput(kIndex0); auto x = ib->GetInput(kIndex0);
@ -1337,7 +1337,7 @@ REG_BPROP_BUILDER("MultiMarginLoss").SetUnusedInputs({i3}).SetBody(BODYFUNC(ib)
return {dx, ib->ZerosLike(target), ib->ZerosLike(weight)}; return {dx, ib->ZerosLike(target), ib->ZerosLike(weight)};
}); });
REG_BPROP_BUILDER("DropoutGenMask").SetUnusedInputs({i2, i3}).SetBody(BODYFUNC(ib) { REG_BPROP_BUILDER("DropoutGenMask").SetUnusedInputs({i0, i1, i2, i3}).SetBody(BODYFUNC(ib) {
auto shape = ib->GetInput(kIndex0); auto shape = ib->GetInput(kIndex0);
auto keep_prob = ib->GetInput(kIndex1); auto keep_prob = ib->GetInput(kIndex1);
return {ib->ZerosLike(shape), ib->ZerosLike(keep_prob)}; return {ib->ZerosLike(shape), ib->ZerosLike(keep_prob)};
@ -1350,7 +1350,7 @@ REG_BPROP_BUILDER("DropoutDoMask").SetUnusedInputs({i0, i3}).SetBody(BODYFUNC(ib
return {ib->Emit("DropoutDoMask", {dout, y, keep_prob}), ib->ZerosLike(y), ib->ZerosLike(keep_prob)}; return {ib->Emit("DropoutDoMask", {dout, y, keep_prob}), ib->ZerosLike(y), ib->ZerosLike(keep_prob)};
}); });
REG_BPROP_BUILDER("ReluGrad").SetUnusedInputs({i0, i2}).SetBody(BODYFUNC(ib) { REG_BPROP_BUILDER("ReluGrad").SetUnusedInputs({i0, i1, i2}).SetBody(BODYFUNC(ib) {
auto y = ib->GetInput(kIndex1); auto y = ib->GetInput(kIndex1);
auto dout = ib->GetInput(kIndex3); auto dout = ib->GetInput(kIndex3);
auto dgrad = ib->Emit("ReluGrad", {dout, y}); auto dgrad = ib->Emit("ReluGrad", {dout, y});
@ -1390,7 +1390,7 @@ REG_BPROP_BUILDER("GridSampler2D").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {
return {dx, dgrid}; return {dx, dgrid};
}); });
REG_BPROP_BUILDER("ResizeLinear1D").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) { REG_BPROP_BUILDER("ResizeLinear1D").SetUnusedInputs({i1, i2}).SetBody(BODYFUNC(ib) {
auto input_x = ib->GetInput(kIndex0); auto input_x = ib->GetInput(kIndex0);
auto size = ib->GetInput(kIndex1); auto size = ib->GetInput(kIndex1);
auto dout = ib->GetInput(kIndex3); auto dout = ib->GetInput(kIndex3);
@ -1441,7 +1441,7 @@ REG_BPROP_BUILDER("MaxUnpool3D").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {
return {dx, dargmax}; return {dx, dargmax};
}); });
REG_BPROP_BUILDER("NthElement").SetBody(BODYFUNC(ib) { REG_BPROP_BUILDER("NthElement").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib) {
auto input_x = ib->GetInput(kIndex0); auto input_x = ib->GetInput(kIndex0);
auto n = ib->GetInput(kIndex1); auto n = ib->GetInput(kIndex1);
auto out = ib->GetInput(kIndex2); auto out = ib->GetInput(kIndex2);
@ -1452,7 +1452,7 @@ REG_BPROP_BUILDER("NthElement").SetBody(BODYFUNC(ib) {
return {ib->Mul(ib->Div(indicators, num_select), dout), ib->ZerosLike(n)}; return {ib->Mul(ib->Div(indicators, num_select), dout), ib->ZerosLike(n)};
}); });
REG_BPROP_BUILDER("AdaptiveAvgPool3D").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib) { REG_BPROP_BUILDER("AdaptiveAvgPool3D").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) {
auto x = ib->GetInput(kIndex0); auto x = ib->GetInput(kIndex0);
auto dout = ib->GetInput(kIndex2); auto dout = ib->GetInput(kIndex2);
auto x_shape = ib->Tensor(ib->GetShape(x)); auto x_shape = ib->Tensor(ib->GetShape(x));
@ -1460,7 +1460,7 @@ REG_BPROP_BUILDER("AdaptiveAvgPool3D").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib
return {dx}; return {dx};
}); });
REG_BPROP_BUILDER("AdaptiveAvgPool2DV1").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib) { REG_BPROP_BUILDER("AdaptiveAvgPool2DV1").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) {
auto x = ib->GetInput(kIndex0); auto x = ib->GetInput(kIndex0);
auto dout = ib->GetInput(kIndex2); auto dout = ib->GetInput(kIndex2);
auto dx = ib->Emit("AdaptiveAvgPool2DGradV1", {dout}, {{"orig_input_shape", MakeValue(ib->GetShape(x))}}); auto dx = ib->Emit("AdaptiveAvgPool2DGradV1", {dout}, {{"orig_input_shape", MakeValue(ib->GetShape(x))}});
@ -1479,7 +1479,7 @@ REG_BPROP_BUILDER("FractionalMaxPool").SetBody(BODYFUNC(ib) {
return {dx}; return {dx};
}); });
REG_BPROP_BUILDER("FractionalMaxPool3DWithFixedKsize").SetBody(BODYFUNC(ib) { REG_BPROP_BUILDER("FractionalMaxPool3DWithFixedKsize").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib) {
auto x = ib->GetInput(kIndex0); auto x = ib->GetInput(kIndex0);
auto random_samples = ib->GetInput(kIndex1); auto random_samples = ib->GetInput(kIndex1);
auto out = ib->GetInput(kIndex2); auto out = ib->GetInput(kIndex2);
@ -1489,7 +1489,7 @@ REG_BPROP_BUILDER("FractionalMaxPool3DWithFixedKsize").SetBody(BODYFUNC(ib) {
return {dx, ib->ZerosLike(random_samples)}; return {dx, ib->ZerosLike(random_samples)};
}); });
REG_BPROP_BUILDER("FractionalAvgPool").SetBody(BODYFUNC(ib) { REG_BPROP_BUILDER("FractionalAvgPool").SetUnusedInputs({i0}).SetBody(BODYFUNC(ib) {
auto x = ib->GetInput(kIndex0); auto x = ib->GetInput(kIndex0);
auto out = ib->GetInput(kIndex1); auto out = ib->GetInput(kIndex1);
auto dout = ib->GetInput(kIndex2); auto dout = ib->GetInput(kIndex2);
@ -1503,7 +1503,7 @@ REG_BPROP_BUILDER("FractionalAvgPool").SetBody(BODYFUNC(ib) {
return {dx}; return {dx};
}); });
REG_BPROP_BUILDER("PSROIPooling").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) { REG_BPROP_BUILDER("PSROIPooling").SetUnusedInputs({i0, i2}).SetBody(BODYFUNC(ib) {
auto spatial_scale = ib->GetAttr("spatial_scale"); auto spatial_scale = ib->GetAttr("spatial_scale");
auto group_size = ib->GetAttr("group_size"); auto group_size = ib->GetAttr("group_size");
auto output_dim = ib->GetAttr("output_dim"); auto output_dim = ib->GetAttr("output_dim");
@ -1525,7 +1525,7 @@ REG_BPROP_BUILDER("PSROIPooling").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {
return {dx, ib->ZerosLike(rois)}; return {dx, ib->ZerosLike(rois)};
}); });
REG_BPROP_BUILDER("AvgPoolV1").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib) { REG_BPROP_BUILDER("AvgPoolV1").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) {
auto x = ib->GetInput(kIndex0); auto x = ib->GetInput(kIndex0);
auto dout = ib->GetInput(kIndex2); auto dout = ib->GetInput(kIndex2);
auto orig_input_shape = ib->Value<ShapeVector>(ib->GetShape(x)); auto orig_input_shape = ib->Value<ShapeVector>(ib->GetShape(x));
@ -1589,7 +1589,7 @@ REG_BPROP_BUILDER("InstanceNormV2").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {
return {dx, dgamma, dbeta, ib->ZerosLike(mean), ib->ZerosLike(variance)}; return {dx, dgamma, dbeta, ib->ZerosLike(mean), ib->ZerosLike(variance)};
}); });
REG_BPROP_BUILDER("FractionalMaxPoolWithFixedKsize").SetBody(BODYFUNC(ib) { REG_BPROP_BUILDER("FractionalMaxPoolWithFixedKsize").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib) {
auto x = ib->GetInput(kIndex0); auto x = ib->GetInput(kIndex0);
auto random_samples = ib->GetInput(kIndex1); auto random_samples = ib->GetInput(kIndex1);
auto out = ib->GetInput(kIndex2); auto out = ib->GetInput(kIndex2);
@ -1609,7 +1609,7 @@ REG_BPROP_BUILDER("AdaptiveAvgPool2D").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib
return {dx}; return {dx};
}); });
REG_BPROP_BUILDER("SparseSoftmaxCrossEntropyWithLogitsV2").SetBody(BODYFUNC(ib) { REG_BPROP_BUILDER("SparseSoftmaxCrossEntropyWithLogitsV2").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib) {
auto logits = ib->GetInput(kIndex0); auto logits = ib->GetInput(kIndex0);
auto labels = ib->GetInput(kIndex1); auto labels = ib->GetInput(kIndex1);
auto out = ib->GetInput(kIndex2); auto out = ib->GetInput(kIndex2);
@ -1658,8 +1658,7 @@ REG_BPROP_BUILDER("DepthwiseConv2dNative").SetUnusedInputs({i2}).SetBody(BODYFUN
return {dx, dw}; return {dx, dw};
}); });
REG_BPROP_BUILDER("PadV3").SetUnusedInputs({i3}).SetBody(BODYFUNC(ib) { REG_BPROP_BUILDER("PadV3").SetUnusedInputs({i0, i1, i3}).SetBody(BODYFUNC(ib) {
auto x = ib->GetInput(kIndex0);
auto paddings = ib->GetInput(kIndex1); auto paddings = ib->GetInput(kIndex1);
auto constant_values = ib->GetInput(kIndex2); auto constant_values = ib->GetInput(kIndex2);
auto dout = ib->GetInput(kIndex4); auto dout = ib->GetInput(kIndex4);

View File

@ -20,24 +20,24 @@
namespace mindspore::expander::bprop { namespace mindspore::expander::bprop {
REG_BPROP_BUILDERS_BEGIN(GradOtherOps) REG_BPROP_BUILDERS_BEGIN(GradOtherOps)
REG_BPROP_BUILDER("Assign").SetUnusedInputs({i0, i2}).SetBody(BODYFUNC(ib) { REG_BPROP_BUILDER("Assign").SetUnusedInputs({i0, i1, i2}).SetBody(BODYFUNC(ib) {
auto y = ib->GetInput(kIndex1); auto y = ib->GetInput(kIndex1);
auto dout = ib->GetInput(kIndex3); auto dout = ib->GetInput(kIndex3);
return {dout, ib->ZerosLike(y)}; return {dout, ib->ZerosLike(y)};
}); });
REG_BPROP_BUILDER("InvertPermutation").SetUnusedInputs({i1, i2}).SetBody(BODYFUNC(ib) { REG_BPROP_BUILDER("InvertPermutation").SetUnusedInputs({i0, i1, i2}).SetBody(BODYFUNC(ib) {
auto x = ib->GetInput(kIndex0); auto x = ib->GetInput(kIndex0);
return {ib->ZerosLike(x)}; return {ib->ZerosLike(x)};
}); });
REG_BPROP_BUILDER("IOU").SetUnusedInputs({i2, i3}).SetBody(BODYFUNC(ib) { REG_BPROP_BUILDER("IOU").SetUnusedInputs({i0, i1, i2, i3}).SetBody(BODYFUNC(ib) {
auto x = ib->GetInput(kIndex0); auto x = ib->GetInput(kIndex0);
auto y = ib->GetInput(kIndex1); auto y = ib->GetInput(kIndex1);
return {ib->ZerosLike(x), ib->ZerosLike(y)}; return {ib->ZerosLike(x), ib->ZerosLike(y)};
}); });
REG_BPROP_BUILDER("SyncBatchNorm").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) { REG_BPROP_BUILDER("SyncBatchNorm").SetUnusedInputs({i2, i3, i4}).SetBody(BODYFUNC(ib) {
auto x = ib->GetInput(kIndex0); auto x = ib->GetInput(kIndex0);
auto scale = ib->GetInput(kIndex1); auto scale = ib->GetInput(kIndex1);
auto mean = ib->GetInput(kIndex3); auto mean = ib->GetInput(kIndex3);

View File

@ -20,26 +20,26 @@
namespace mindspore::expander::bprop { namespace mindspore::expander::bprop {
REG_BPROP_BUILDERS_BEGIN(GradQuantOps) REG_BPROP_BUILDERS_BEGIN(GradQuantOps)
REG_BPROP_BUILDER("BNTrainingReduce").SetUnusedInputs({i1, i2}).SetBody(BODYFUNC(ib) { REG_BPROP_BUILDER("BNTrainingReduce").SetUnusedInputs({i0, i1, i2}).SetBody(BODYFUNC(ib) {
auto x = ib->GetInput(kIndex0); auto x = ib->GetInput(kIndex0);
return {ib->ZerosLike(x)}; return {ib->ZerosLike(x)};
}); });
REG_BPROP_BUILDER("MinMaxUpdatePerLayer").SetUnusedInputs({i3, i4}).SetBody(BODYFUNC(ib) { REG_BPROP_BUILDER("MinMaxUpdatePerLayer").SetUnusedInputs({i0, i1, i2, i3, i4}).SetBody(BODYFUNC(ib) {
auto x = ib->GetInput(kIndex0); auto x = ib->GetInput(kIndex0);
auto x_min = ib->GetInput(kIndex1); auto x_min = ib->GetInput(kIndex1);
auto x_max = ib->GetInput(kIndex2); auto x_max = ib->GetInput(kIndex2);
return {ib->ZerosLike(x), ib->ZerosLike(x_min), ib->ZerosLike(x_max)}; return {ib->ZerosLike(x), ib->ZerosLike(x_min), ib->ZerosLike(x_max)};
}); });
REG_BPROP_BUILDER("MinMaxUpdatePerChannel").SetUnusedInputs({i3, i4}).SetBody(BODYFUNC(ib) { REG_BPROP_BUILDER("MinMaxUpdatePerChannel").SetUnusedInputs({i0, i1, i2, i3, i4}).SetBody(BODYFUNC(ib) {
auto x = ib->GetInput(kIndex0); auto x = ib->GetInput(kIndex0);
auto x_min = ib->GetInput(kIndex1); auto x_min = ib->GetInput(kIndex1);
auto x_max = ib->GetInput(kIndex2); auto x_max = ib->GetInput(kIndex2);
return {ib->ZerosLike(x), ib->ZerosLike(x_min), ib->ZerosLike(x_max)}; return {ib->ZerosLike(x), ib->ZerosLike(x_min), ib->ZerosLike(x_max)};
}); });
REG_BPROP_BUILDER("WtsARQ").SetUnusedInputs({i0, i3}).SetBody(BODYFUNC(ib) { REG_BPROP_BUILDER("WtsARQ").SetUnusedInputs({i0, i1, i2, i3}).SetBody(BODYFUNC(ib) {
auto w_min = ib->GetInput(kIndex1); auto w_min = ib->GetInput(kIndex1);
auto w_max = ib->GetInput(kIndex2); auto w_max = ib->GetInput(kIndex2);
auto dout = ib->GetInput(kIndex4); auto dout = ib->GetInput(kIndex4);
@ -96,7 +96,7 @@ REG_BPROP_BUILDER("FakeQuantPerChannel").SetUnusedInputs({i3}).SetBody(BODYFUNC(
return {dx, ib->ZerosLike(x_min), ib->ZerosLike(x_max)}; return {dx, ib->ZerosLike(x_min), ib->ZerosLike(x_max)};
}); });
REG_BPROP_BUILDER("BatchNormFold").SetBody(BODYFUNC(ib) { REG_BPROP_BUILDER("BatchNormFold").SetUnusedInputs({i1, i2, i3}).SetBody(BODYFUNC(ib) {
auto x = ib->GetInput(kIndex0); auto x = ib->GetInput(kIndex0);
auto mean = ib->GetInput(kIndex1); auto mean = ib->GetInput(kIndex1);
auto variance = ib->GetInput(kIndex2); auto variance = ib->GetInput(kIndex2);
@ -155,7 +155,7 @@ REG_BPROP_BUILDER("BatchNormFold2").SetUnusedInputs({i1, i8}).SetBody(BODYFUNC(i
ib->ZerosLike(global_step)}; ib->ZerosLike(global_step)};
}); });
REG_BPROP_BUILDER("BatchNormFoldD").SetBody(BODYFUNC(ib) { REG_BPROP_BUILDER("BatchNormFoldD").SetUnusedInputs({i1, i2, i3, i4}).SetBody(BODYFUNC(ib) {
auto x = ib->GetInput(kIndex0); auto x = ib->GetInput(kIndex0);
auto x_sum = ib->GetInput(kIndex1); auto x_sum = ib->GetInput(kIndex1);
auto x_square_sum = ib->GetInput(kIndex2); auto x_square_sum = ib->GetInput(kIndex2);