forked from mindspore-Ecosystem/mindspore
!47477 Bprop Unused Inputs Removement
Merge pull request !47477 from jiaoy1224/unused
This commit is contained in:
commit
adafe66daa
|
@ -174,7 +174,7 @@ REG_BPROP_BUILDER("ReLU").SetUnusedInputs({i0}).SetBody(BODYFUNC(ib) {
|
|||
return {dx};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("TopK").SetBody(BODYFUNC(ib) {
|
||||
REG_BPROP_BUILDER("TopK").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) {
|
||||
auto input_x = ib->GetInput(kIndex0);
|
||||
auto out = ib->GetInput(kIndex2);
|
||||
auto dout = ib->GetInput(kIndex3);
|
||||
|
@ -225,7 +225,7 @@ REG_BPROP_BUILDER("SigmoidCrossEntropyWithLogits").SetUnusedInputs({i2}).SetBody
|
|||
return {dx, ib->ZerosLike(y)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("Pad").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib) {
|
||||
REG_BPROP_BUILDER("Pad").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) {
|
||||
auto paddings = ib->GetAttr<std::vector<std::vector<int64_t>>>("paddings");
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto dout = ib->GetInput(kIndex2);
|
||||
|
@ -238,7 +238,7 @@ REG_BPROP_BUILDER("Pad").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib) {
|
|||
return {dx};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("ROIAlign").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {
|
||||
REG_BPROP_BUILDER("ROIAlign").SetUnusedInputs({i0, i2}).SetBody(BODYFUNC(ib) {
|
||||
auto inputs = ib->GetInput(kIndex0);
|
||||
auto rois = ib->GetInput(kIndex1);
|
||||
auto dout = ib->GetInput(kIndex3);
|
||||
|
@ -450,7 +450,7 @@ REG_BPROP_BUILDER("L2Normalize").SetBody(BODYFUNC(ib) {
|
|||
return {dx};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("SoftmaxCrossEntropyWithLogits").SetUnusedInputs({i0}).SetBody(BODYFUNC(ib) {
|
||||
REG_BPROP_BUILDER("SoftmaxCrossEntropyWithLogits").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) {
|
||||
auto labels = ib->GetInput(kIndex1);
|
||||
auto out = ib->GetInput(kIndex2);
|
||||
auto dout = ib->GetInput(kIndex3);
|
||||
|
@ -481,7 +481,7 @@ REG_BPROP_BUILDER("ResizeBilinear").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib) {
|
|||
return {dx};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("OneHot").SetUnusedInputs({i4, i5}).SetBody(BODYFUNC(ib) {
|
||||
REG_BPROP_BUILDER("OneHot").SetUnusedInputs({i0, i1, i2, i3, i4, i5}).SetBody(BODYFUNC(ib) {
|
||||
auto indices = ib->GetInput(kIndex0);
|
||||
auto depth = ib->GetInput(kIndex1);
|
||||
auto on_value = ib->GetInput(kIndex2);
|
||||
|
@ -508,7 +508,7 @@ REG_BPROP_BUILDER("L2Loss").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib) {
|
|||
return {dx};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("RNNTLoss").SetUnusedInputs({i0, i5}).SetBody(BODYFUNC(ib) {
|
||||
REG_BPROP_BUILDER("RNNTLoss").SetUnusedInputs({i0, i1, i2, i3, i5}).SetBody(BODYFUNC(ib) {
|
||||
auto labels = ib->GetInput(kIndex1);
|
||||
auto act_lens = ib->GetInput(kIndex2);
|
||||
auto label_lens = ib->GetInput(kIndex3);
|
||||
|
@ -678,7 +678,7 @@ REG_BPROP_BUILDER("MaxPoolGrad").SetUnusedInputs({i2, i3}).SetBody(BODYFUNC(ib)
|
|||
return {dx1, dx2, dgrad};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("UpsampleNearest3D").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib) {
|
||||
REG_BPROP_BUILDER("UpsampleNearest3D").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) {
|
||||
auto input_x = ib->GetInput(kIndex0);
|
||||
auto dout = ib->GetInput(kIndex2);
|
||||
auto dx = ib->Emit("UpsampleNearest3DGrad", {dout},
|
||||
|
@ -688,7 +688,7 @@ REG_BPROP_BUILDER("UpsampleNearest3D").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib
|
|||
return {dx};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("UpsampleTrilinear3D").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib) {
|
||||
REG_BPROP_BUILDER("UpsampleTrilinear3D").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto dout = ib->GetInput(kIndex2);
|
||||
auto dx = ib->Emit("UpsampleTrilinear3DGrad", {dout},
|
||||
|
@ -702,7 +702,7 @@ REG_BPROP_BUILDER("UpsampleTrilinear3D").SetUnusedInputs({i1}).SetBody(BODYFUNC(
|
|||
REG_BPROP_BUILDER("Dropout2D").SetBody(Dropout2DBpropExpander);
|
||||
REG_BPROP_BUILDER("Dropout3D").SetBody(Dropout2DBpropExpander);
|
||||
|
||||
REG_BPROP_BUILDER("CTCLoss").SetUnusedInputs({i0}).SetBody(BODYFUNC(ib) {
|
||||
REG_BPROP_BUILDER("CTCLoss").SetUnusedInputs({i0, i1, i2, i3}).SetBody(BODYFUNC(ib) {
|
||||
auto labels_indices = ib->GetInput(kIndex1);
|
||||
auto labels_values = ib->GetInput(kIndex2);
|
||||
auto sequence_length = ib->GetInput(kIndex3);
|
||||
|
@ -772,7 +772,7 @@ REG_BPROP_BUILDER("AvgPool").SetBody(BODYFUNC(ib) {
|
|||
return {dx};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("AvgPool3D").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib) {
|
||||
REG_BPROP_BUILDER("AvgPool3D").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto dout = ib->GetInput(kIndex2);
|
||||
auto x_shape = ib->GetShape(x);
|
||||
|
@ -822,7 +822,7 @@ REG_BPROP_BUILDER("ReLUV2").SetUnusedInputs({i0}).SetBody(BODYFUNC(ib) {
|
|||
return {dx};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("BiasAddGrad").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib) {
|
||||
REG_BPROP_BUILDER("BiasAddGrad").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) {
|
||||
auto data_format = GetValue<std::string>(ib->GetAttr("format"));
|
||||
auto dy = ib->GetInput(kIndex0);
|
||||
auto dout = ib->GetInput(kIndex2);
|
||||
|
@ -850,7 +850,7 @@ REG_BPROP_BUILDER("BiasAddGrad").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib) {
|
|||
return {tiled_grad};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("ExtractImagePatches").SetBody(BODYFUNC(ib) {
|
||||
REG_BPROP_BUILDER("ExtractImagePatches").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) {
|
||||
auto ksizes_row = GetValue<std::vector<int64_t>>(ib->GetAttr("ksizes"))[2];
|
||||
auto ksizes_col = GetValue<std::vector<int64_t>>(ib->GetAttr("ksizes"))[3];
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
|
@ -957,7 +957,7 @@ REG_BPROP_BUILDER("Softsign").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib) {
|
|||
return {dx};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("Tanh").SetBody(BODYFUNC(ib) {
|
||||
REG_BPROP_BUILDER("Tanh").SetUnusedInputs({i0}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto out = ib->GetInput(kIndex1);
|
||||
auto dout = ib->GetInput(kIndex2);
|
||||
|
@ -988,7 +988,7 @@ REG_BPROP_BUILDER("Gelu").SetBody(GeLUBpropExpander);
|
|||
REG_BPROP_BUILDER("FastGeLU").SetUnusedInputs({i1}).SetBody(FastGeLUBpropExpander);
|
||||
REG_BPROP_BUILDER("FastGelu").SetUnusedInputs({i1}).SetBody(FastGeLUBpropExpander);
|
||||
|
||||
REG_BPROP_BUILDER("InstanceNorm").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {
|
||||
REG_BPROP_BUILDER("InstanceNorm").SetUnusedInputs({i2, i3, i4}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto gamma = ib->GetInput(kIndex1);
|
||||
auto mean = ib->GetInput(kIndex3);
|
||||
|
@ -1032,7 +1032,7 @@ REG_BPROP_BUILDER("BatchNorm").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {
|
|||
return {dx, dscale, dbias, ib->ZerosLike(mean), ib->ZerosLike(variance)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("BatchNormGrad").SetUnusedInputs({i6}).SetBody(BODYFUNC(ib) {
|
||||
REG_BPROP_BUILDER("BatchNormGrad").SetUnusedInputs({i5, i6}).SetBody(BODYFUNC(ib) {
|
||||
auto dy = ib->GetInput(kIndex0);
|
||||
auto x = ib->GetInput(kIndex1);
|
||||
auto scale = ib->GetInput(kIndex2);
|
||||
|
@ -1052,7 +1052,7 @@ REG_BPROP_BUILDER("BatchNormGrad").SetUnusedInputs({i6}).SetBody(BODYFUNC(ib) {
|
|||
return {ddy, dx, dscale, ib->ZerosLike(mean), ib->ZerosLike(variance), ib->ZerosLike(reserve)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("Softmax").SetBody(BODYFUNC(ib) {
|
||||
REG_BPROP_BUILDER("Softmax").SetUnusedInputs({i0}).SetBody(BODYFUNC(ib) {
|
||||
auto axis = GetValue<std::vector<int64_t>>(ib->GetAttr("axis"));
|
||||
auto one_axis = axis[0];
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
|
@ -1337,7 +1337,7 @@ REG_BPROP_BUILDER("MultiMarginLoss").SetUnusedInputs({i3}).SetBody(BODYFUNC(ib)
|
|||
return {dx, ib->ZerosLike(target), ib->ZerosLike(weight)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("DropoutGenMask").SetUnusedInputs({i2, i3}).SetBody(BODYFUNC(ib) {
|
||||
REG_BPROP_BUILDER("DropoutGenMask").SetUnusedInputs({i0, i1, i2, i3}).SetBody(BODYFUNC(ib) {
|
||||
auto shape = ib->GetInput(kIndex0);
|
||||
auto keep_prob = ib->GetInput(kIndex1);
|
||||
return {ib->ZerosLike(shape), ib->ZerosLike(keep_prob)};
|
||||
|
@ -1350,7 +1350,7 @@ REG_BPROP_BUILDER("DropoutDoMask").SetUnusedInputs({i0, i3}).SetBody(BODYFUNC(ib
|
|||
return {ib->Emit("DropoutDoMask", {dout, y, keep_prob}), ib->ZerosLike(y), ib->ZerosLike(keep_prob)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("ReluGrad").SetUnusedInputs({i0, i2}).SetBody(BODYFUNC(ib) {
|
||||
REG_BPROP_BUILDER("ReluGrad").SetUnusedInputs({i0, i1, i2}).SetBody(BODYFUNC(ib) {
|
||||
auto y = ib->GetInput(kIndex1);
|
||||
auto dout = ib->GetInput(kIndex3);
|
||||
auto dgrad = ib->Emit("ReluGrad", {dout, y});
|
||||
|
@ -1390,7 +1390,7 @@ REG_BPROP_BUILDER("GridSampler2D").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {
|
|||
return {dx, dgrid};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("ResizeLinear1D").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {
|
||||
REG_BPROP_BUILDER("ResizeLinear1D").SetUnusedInputs({i1, i2}).SetBody(BODYFUNC(ib) {
|
||||
auto input_x = ib->GetInput(kIndex0);
|
||||
auto size = ib->GetInput(kIndex1);
|
||||
auto dout = ib->GetInput(kIndex3);
|
||||
|
@ -1441,7 +1441,7 @@ REG_BPROP_BUILDER("MaxUnpool3D").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {
|
|||
return {dx, dargmax};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("NthElement").SetBody(BODYFUNC(ib) {
|
||||
REG_BPROP_BUILDER("NthElement").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib) {
|
||||
auto input_x = ib->GetInput(kIndex0);
|
||||
auto n = ib->GetInput(kIndex1);
|
||||
auto out = ib->GetInput(kIndex2);
|
||||
|
@ -1452,7 +1452,7 @@ REG_BPROP_BUILDER("NthElement").SetBody(BODYFUNC(ib) {
|
|||
return {ib->Mul(ib->Div(indicators, num_select), dout), ib->ZerosLike(n)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("AdaptiveAvgPool3D").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib) {
|
||||
REG_BPROP_BUILDER("AdaptiveAvgPool3D").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto dout = ib->GetInput(kIndex2);
|
||||
auto x_shape = ib->Tensor(ib->GetShape(x));
|
||||
|
@ -1460,7 +1460,7 @@ REG_BPROP_BUILDER("AdaptiveAvgPool3D").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib
|
|||
return {dx};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("AdaptiveAvgPool2DV1").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib) {
|
||||
REG_BPROP_BUILDER("AdaptiveAvgPool2DV1").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto dout = ib->GetInput(kIndex2);
|
||||
auto dx = ib->Emit("AdaptiveAvgPool2DGradV1", {dout}, {{"orig_input_shape", MakeValue(ib->GetShape(x))}});
|
||||
|
@ -1479,7 +1479,7 @@ REG_BPROP_BUILDER("FractionalMaxPool").SetBody(BODYFUNC(ib) {
|
|||
return {dx};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("FractionalMaxPool3DWithFixedKsize").SetBody(BODYFUNC(ib) {
|
||||
REG_BPROP_BUILDER("FractionalMaxPool3DWithFixedKsize").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto random_samples = ib->GetInput(kIndex1);
|
||||
auto out = ib->GetInput(kIndex2);
|
||||
|
@ -1489,7 +1489,7 @@ REG_BPROP_BUILDER("FractionalMaxPool3DWithFixedKsize").SetBody(BODYFUNC(ib) {
|
|||
return {dx, ib->ZerosLike(random_samples)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("FractionalAvgPool").SetBody(BODYFUNC(ib) {
|
||||
REG_BPROP_BUILDER("FractionalAvgPool").SetUnusedInputs({i0}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto out = ib->GetInput(kIndex1);
|
||||
auto dout = ib->GetInput(kIndex2);
|
||||
|
@ -1503,7 +1503,7 @@ REG_BPROP_BUILDER("FractionalAvgPool").SetBody(BODYFUNC(ib) {
|
|||
return {dx};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("PSROIPooling").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {
|
||||
REG_BPROP_BUILDER("PSROIPooling").SetUnusedInputs({i0, i2}).SetBody(BODYFUNC(ib) {
|
||||
auto spatial_scale = ib->GetAttr("spatial_scale");
|
||||
auto group_size = ib->GetAttr("group_size");
|
||||
auto output_dim = ib->GetAttr("output_dim");
|
||||
|
@ -1525,7 +1525,7 @@ REG_BPROP_BUILDER("PSROIPooling").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {
|
|||
return {dx, ib->ZerosLike(rois)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("AvgPoolV1").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib) {
|
||||
REG_BPROP_BUILDER("AvgPoolV1").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto dout = ib->GetInput(kIndex2);
|
||||
auto orig_input_shape = ib->Value<ShapeVector>(ib->GetShape(x));
|
||||
|
@ -1589,7 +1589,7 @@ REG_BPROP_BUILDER("InstanceNormV2").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {
|
|||
return {dx, dgamma, dbeta, ib->ZerosLike(mean), ib->ZerosLike(variance)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("FractionalMaxPoolWithFixedKsize").SetBody(BODYFUNC(ib) {
|
||||
REG_BPROP_BUILDER("FractionalMaxPoolWithFixedKsize").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto random_samples = ib->GetInput(kIndex1);
|
||||
auto out = ib->GetInput(kIndex2);
|
||||
|
@ -1609,7 +1609,7 @@ REG_BPROP_BUILDER("AdaptiveAvgPool2D").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib
|
|||
return {dx};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("SparseSoftmaxCrossEntropyWithLogitsV2").SetBody(BODYFUNC(ib) {
|
||||
REG_BPROP_BUILDER("SparseSoftmaxCrossEntropyWithLogitsV2").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib) {
|
||||
auto logits = ib->GetInput(kIndex0);
|
||||
auto labels = ib->GetInput(kIndex1);
|
||||
auto out = ib->GetInput(kIndex2);
|
||||
|
@ -1658,8 +1658,7 @@ REG_BPROP_BUILDER("DepthwiseConv2dNative").SetUnusedInputs({i2}).SetBody(BODYFUN
|
|||
return {dx, dw};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("PadV3").SetUnusedInputs({i3}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
REG_BPROP_BUILDER("PadV3").SetUnusedInputs({i0, i1, i3}).SetBody(BODYFUNC(ib) {
|
||||
auto paddings = ib->GetInput(kIndex1);
|
||||
auto constant_values = ib->GetInput(kIndex2);
|
||||
auto dout = ib->GetInput(kIndex4);
|
||||
|
|
|
@ -20,24 +20,24 @@
|
|||
|
||||
namespace mindspore::expander::bprop {
|
||||
REG_BPROP_BUILDERS_BEGIN(GradOtherOps)
|
||||
REG_BPROP_BUILDER("Assign").SetUnusedInputs({i0, i2}).SetBody(BODYFUNC(ib) {
|
||||
REG_BPROP_BUILDER("Assign").SetUnusedInputs({i0, i1, i2}).SetBody(BODYFUNC(ib) {
|
||||
auto y = ib->GetInput(kIndex1);
|
||||
auto dout = ib->GetInput(kIndex3);
|
||||
return {dout, ib->ZerosLike(y)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("InvertPermutation").SetUnusedInputs({i1, i2}).SetBody(BODYFUNC(ib) {
|
||||
REG_BPROP_BUILDER("InvertPermutation").SetUnusedInputs({i0, i1, i2}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
return {ib->ZerosLike(x)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("IOU").SetUnusedInputs({i2, i3}).SetBody(BODYFUNC(ib) {
|
||||
REG_BPROP_BUILDER("IOU").SetUnusedInputs({i0, i1, i2, i3}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto y = ib->GetInput(kIndex1);
|
||||
return {ib->ZerosLike(x), ib->ZerosLike(y)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("SyncBatchNorm").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {
|
||||
REG_BPROP_BUILDER("SyncBatchNorm").SetUnusedInputs({i2, i3, i4}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto scale = ib->GetInput(kIndex1);
|
||||
auto mean = ib->GetInput(kIndex3);
|
||||
|
|
|
@ -20,26 +20,26 @@
|
|||
|
||||
namespace mindspore::expander::bprop {
|
||||
REG_BPROP_BUILDERS_BEGIN(GradQuantOps)
|
||||
REG_BPROP_BUILDER("BNTrainingReduce").SetUnusedInputs({i1, i2}).SetBody(BODYFUNC(ib) {
|
||||
REG_BPROP_BUILDER("BNTrainingReduce").SetUnusedInputs({i0, i1, i2}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
return {ib->ZerosLike(x)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("MinMaxUpdatePerLayer").SetUnusedInputs({i3, i4}).SetBody(BODYFUNC(ib) {
|
||||
REG_BPROP_BUILDER("MinMaxUpdatePerLayer").SetUnusedInputs({i0, i1, i2, i3, i4}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto x_min = ib->GetInput(kIndex1);
|
||||
auto x_max = ib->GetInput(kIndex2);
|
||||
return {ib->ZerosLike(x), ib->ZerosLike(x_min), ib->ZerosLike(x_max)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("MinMaxUpdatePerChannel").SetUnusedInputs({i3, i4}).SetBody(BODYFUNC(ib) {
|
||||
REG_BPROP_BUILDER("MinMaxUpdatePerChannel").SetUnusedInputs({i0, i1, i2, i3, i4}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto x_min = ib->GetInput(kIndex1);
|
||||
auto x_max = ib->GetInput(kIndex2);
|
||||
return {ib->ZerosLike(x), ib->ZerosLike(x_min), ib->ZerosLike(x_max)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("WtsARQ").SetUnusedInputs({i0, i3}).SetBody(BODYFUNC(ib) {
|
||||
REG_BPROP_BUILDER("WtsARQ").SetUnusedInputs({i0, i1, i2, i3}).SetBody(BODYFUNC(ib) {
|
||||
auto w_min = ib->GetInput(kIndex1);
|
||||
auto w_max = ib->GetInput(kIndex2);
|
||||
auto dout = ib->GetInput(kIndex4);
|
||||
|
@ -96,7 +96,7 @@ REG_BPROP_BUILDER("FakeQuantPerChannel").SetUnusedInputs({i3}).SetBody(BODYFUNC(
|
|||
return {dx, ib->ZerosLike(x_min), ib->ZerosLike(x_max)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("BatchNormFold").SetBody(BODYFUNC(ib) {
|
||||
REG_BPROP_BUILDER("BatchNormFold").SetUnusedInputs({i1, i2, i3}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto mean = ib->GetInput(kIndex1);
|
||||
auto variance = ib->GetInput(kIndex2);
|
||||
|
@ -155,7 +155,7 @@ REG_BPROP_BUILDER("BatchNormFold2").SetUnusedInputs({i1, i8}).SetBody(BODYFUNC(i
|
|||
ib->ZerosLike(global_step)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("BatchNormFoldD").SetBody(BODYFUNC(ib) {
|
||||
REG_BPROP_BUILDER("BatchNormFoldD").SetUnusedInputs({i1, i2, i3, i4}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto x_sum = ib->GetInput(kIndex1);
|
||||
auto x_square_sum = ib->GetInput(kIndex2);
|
||||
|
|
Loading…
Reference in New Issue