add some nn expander bprops

author r1chardf1d0 2023-03-05 11:37:45 +08:00
parent d1a4199d05
commit 2049e36c8e
3 changed files with 173 additions and 80 deletions

File 1 of 3

@@ -262,6 +262,11 @@ void RegNNBpropExpanderOps2() {
REGISTER_EXPANDER_BPROP_IMPL(AdaptiveAvgPool3D);
REGISTER_EXPANDER_BPROP_IMPL(FractionalAvgPool);
REGISTER_EXPANDER_BPROP_IMPL(PSROIPooling);
REGISTER_EXPANDER_BPROP_IMPL(BiasAddGrad);
REGISTER_EXPANDER_BPROP_IMPL(MaxPoolGrad);
REGISTER_EXPANDER_BPROP_IMPL(TopK);
REGISTER_EXPANDER_BPROP_IMPL(BCEWithLogitsLoss);
REGISTER_EXPANDER_BPROP_IMPL(KLDivLoss);
}
void RegArrayBpropExpanderOps1() {
@@ -321,7 +326,9 @@ void RegArrayBpropExpanderOps2() {}
void RegClipBpropExpanderOps() {}
void RegCommBpropExpanderOps() {}
void RegInnerBpropExpanderOps() {}
void RegOtherBpropExpanderOps() {}
void RegOtherBpropExpanderOps() { REGISTER_EXPANDER_BPROP_IMPL(Assign); }
void RegQuantBpropExpanderOps() {}
void RegSparseBpropExpanderOps() {}
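
The Reg*BpropExpanderOps functions above populate a lookup table mapping op names to expander-based bprop builders. Below is a minimal standalone sketch of how a registration macro of this kind is commonly wired up; BpropBuilderFunc, BpropRegistry, and REGISTER_BPROP_SKETCH are hypothetical stand-ins for illustration, not MindSpore's actual definitions.

#include <functional>
#include <map>
#include <string>

// Hypothetical stand-in types; the real REGISTER_EXPANDER_BPROP_IMPL
// machinery is not shown in this diff.
using BpropBuilderFunc = std::function<void()>;

inline std::map<std::string, BpropBuilderFunc> &BpropRegistry() {
  static std::map<std::string, BpropBuilderFunc> registry;  // op name -> bprop builder
  return registry;
}

#define REGISTER_BPROP_SKETCH(op) BpropRegistry()[#op] = [] { /* emit op's bprop graph */ }

void RegNNBpropSketch() {
  // Mirrors the Reg*() functions above: each call adds one op to the table.
  REGISTER_BPROP_SKETCH(BiasAddGrad);
  REGISTER_BPROP_SKETCH(TopK);
}

Grouping registrations into per-domain Reg*() functions, as this file does, keeps the initialization order explicit instead of relying on static initializers scattered across translation units.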

File 2 of 3

@@ -20,6 +20,7 @@
#include <vector>
#include <string>
#include <map>
#include <unordered_set>
#include <functional>
#include "include/common/expander/core/node.h"
@@ -56,12 +57,39 @@ class BpropIRBuilder : public Emitter {
// For node that has single output
ShapeVector GetShape(const NodePtr &node) const { return node->shape(); }
NodePtr Shape(const NodePtr &node, bool tensor = false) const {
auto shape = GetShape(node);
if (tensor) {
return IsDynamic(shape) ? Emit("TensorShape", {node}) : Tensor(shape);
} else {
return IsDynamic(shape) ? Emit("Shape", {node}) : Value<ShapeVector>(shape);
}
}
// For node that has multiple outputs
std::vector<ShapeVector> GetShapes(const NodePtr &node) const { return node->shapes(); }
TypePtr GetDtype(const NodePtr &node) const { return node->dtype(); }
TypeId GetDtypeId(const NodePtr &node) const { return GetDtype(node)->type_id(); }
ValuePtr GetAttr(const NodePtr &node, const std::string &attr) const;
int64_t GetSize(const NodePtr &node) const;
NodePtr DynSize(const NodePtr &node, const TypePtr &type) const { return Cast(DynSize(node), type); }
NodePtr DynSize(const NodePtr &node, TypeId type_id) const { return Cast(DynSize(node), type_id); }
NodePtr DynSize(const NodePtr &node) const {
auto shape_func = [](const ShapeArray &inputs) -> ShapeArray {
auto shape = inputs.at(0);
int64_t size = 1;
for (auto &i : shape) {
size *= i;
}
return {{size}};
};
auto infer_func = [](const ShapeArray &inputs, const std::unordered_set<size_t> &) -> ShapeVector { return {1}; };
return ShapeCalc({node}, shape_func, infer_func, {})[0];
}
NodePtr Range(const NodePtr &limit) const { return Range(Tensor(0, kInt64), limit, Tensor(1, kInt64)); }
NodePtr Range(const NodePtr &start, const NodePtr &limit, const NodePtr &delta, int64_t max_len = 1000000) const {
return Emit("Range", {start, limit, delta}, {{"maxlen", MakeValue(max_len)}});
}
std::string name() const { return name_; }
std::string GetTargetFromContext() const;
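
DynSize is the dynamic-shape counterpart of GetSize: rather than reading a constant element count, it routes the computation through ShapeCalc so the count is derived from the runtime shape. The fold it performs can be checked in isolation; a minimal sketch, assuming only that ShapeVector/ShapeArray are the vector aliases used above:

#include <cstdint>
#include <iostream>
#include <vector>

using ShapeVector = std::vector<int64_t>;  // aliases as used by the builder
using ShapeArray = std::vector<ShapeVector>;

int main() {
  // The same fold DynSize hands to ShapeCalc: reduce a shape to its element count.
  auto shape_func = [](const ShapeArray &inputs) -> ShapeArray {
    auto shape = inputs.at(0);
    int64_t size = 1;
    for (auto &i : shape) {
      size *= i;
    }
    return {{size}};
  };
  std::cout << shape_func({{2, 3, 4}})[0][0] << "\n";  // 24 elements in a {2,3,4} tensor
  return 0;
}

The Range(limit) overload defaulting to start 0 and delta 1 covers the common iota case, and the two Cast-ing DynSize overloads let callers pick the result dtype without repeating the ShapeCalc plumbing.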

File 3 of 3

@@ -71,8 +71,7 @@ NodePtrList Conv2DTransposeBpropExpander(const BpropIRBuilder *ib) {
auto w = ib->GetInput(kIndex1);
auto f_sizes = ib->GetInput(kIndex2);
auto dout = ib->GetInput(kIndex4);
auto shp = ib->GetShape(w);
auto w_shape = IsDynamic(shp) ? ib->Emit("Shape", {w}) : ib->Value<ShapeVector>(shp);
auto w_shape = ib->Shape(w);
auto dx = ib->Emit(kConv2DOpName, {dout, w},
{{"pad_mode", ib->GetAttr("pad_mode")},
{"pad", ib->GetAttr("pad")},
@@ -196,33 +195,51 @@ REG_BPROP_BUILDER("TopK").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) {
auto input_x = ib->GetInput(kIndex0);
auto out = ib->GetInput(kIndex2);
auto dout = ib->GetInput(kIndex3);
auto indices = ib->TupleGetItem(out, kIndex1);
auto dout0 = ib->TupleGetItem(dout, kIndex0);
auto in_shape = ib->GetShape(input_x);
auto in_lastdim = in_shape.back();
auto ind_shape = ib->GetShape(indices);
auto ind_lastdim = ind_shape.back();
auto ind_2d = ib->Reshape(indices, {-1, ind_lastdim});
auto outerdim = ib->GetShape(ind_2d)[0]; // k
// [0, outerdim, 2*outerdim, ..., (k-1)*outerdim]
auto indices_dtype = ib->GetDtype(indices);
std::vector<int64_t> range_flatten_index_vec(LongToSize(outerdim));
for (int64_t i = 0; i < outerdim; i++) {
range_flatten_index_vec[i] = i * in_lastdim;
if (IsDynamic(in_shape)) {
auto shape_func0 = [](const ShapeArray &inputs) -> ShapeArray { return {{-1, inputs.at(0).back()}}; };
auto infer_func0 = [](const ShapeArray &inputs, const std::unordered_set<size_t> &) -> ShapeVector { return {2}; };
auto re0 = ib->ShapeCalc({indices}, shape_func0, infer_func0, {})[0];
NodePtr ind_2d = ib->Reshape(indices, re0);
auto shape_func = [](const ShapeArray &inputs) -> ShapeArray {
auto in_shape = inputs.at(0);
auto in_lastdim = in_shape.back();
auto outerdim = inputs.at(1)[0]; // k
auto in_shape_1d_x =
ShapeVector(1, std::accumulate(in_shape.begin(), in_shape.end(), 1, std::multiplies<int64_t>()));
return {in_shape_1d_x, {outerdim * in_lastdim}, {in_lastdim}};
};
auto infer_func = [](const ShapeArray &inputs, const std::unordered_set<size_t> &) -> ShapeVector {
return {1, 1, 1};
};
auto res = ib->ShapeCalc({input_x, ind_2d}, shape_func, infer_func, {});
auto in_shape_1d = res[0];
auto range_flatten_index = ib->Range(ib->Tensor(0, kInt64), res[1], res[2]);
auto ind = ib->Reshape(ind_2d + ib->Reshape(range_flatten_index, {-1, 1}), {-1, 1});
auto out_grad = ib->Emit("ScatterNd", {ind, ib->Reshape(dout0, {-1}), in_shape_1d});
out_grad = ib->Reshape(out_grad, ib->Emit("TensorShape", {input_x}));
auto grad_k = ib->ZerosLike(ib->GetInput(kIndex1));
return {out_grad, grad_k};
} else {
auto ind_lastdim = ib->GetShape(indices).back();
auto ind_2d = ib->Reshape(indices, {-1, ind_lastdim});
auto in_lastdim = in_shape.back();
auto outerdim = ib->GetShape(ind_2d)[0]; // k
std::vector<int64_t> range_flatten_index_vec(LongToSize(outerdim));
for (int64_t i = 0; i < outerdim; i++) {
range_flatten_index_vec[i] = i * in_lastdim;
}
auto range_flatten_index = ib->Tensor(range_flatten_index_vec, ib->GetDtype(indices));
auto in_shape_1d =
ib->Value(ShapeVector(1, std::accumulate(in_shape.begin(), in_shape.end(), 1, std::multiplies<int64_t>())));
auto ind = ib->Reshape(ind_2d + ib->Reshape(range_flatten_index, {-1, 1}), {-1, 1});
auto out_grad = ib->Emit("ScatterNd", {ind, ib->Reshape(dout0, {-1}), in_shape_1d});
out_grad = ib->Reshape(out_grad, in_shape);
auto grad_k = ib->ZerosLike(ib->GetInput(kIndex1));
return {out_grad, grad_k};
}
auto range_flatten_index = ib->Tensor(range_flatten_index_vec, indices_dtype);
auto ind = ib->Reshape(ind_2d + ib->Reshape(range_flatten_index, {-1, 1}), {-1, 1});
auto in_shape_1d = ShapeVector(1, std::accumulate(in_shape.begin(), in_shape.end(), 1, std::multiplies<int64_t>()));
auto out_grad = ib->Emit("ScatterNd", {ind, ib->Reshape(dout0, {-1}), ib->Value(in_shape_1d)});
out_grad = ib->Reshape(out_grad, in_shape);
auto grad_k = ib->ZerosLike(ib->GetInput(kIndex1));
return {out_grad, grad_k};
});
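
The scatter-based TopK gradient is easiest to see on a concrete case. A minimal sketch of the index arithmetic for a hypothetical input_x of shape {2, 3} with k = 2; the index values below are made up, only the flattening rule comes from the code above:

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // TopK bprop index arithmetic by hand: input_x shape {2, 3}, k = 2.
  // ind_2d holds hypothetical per-row TopK indices.
  std::vector<std::vector<int64_t>> ind_2d = {{2, 0}, {1, 2}};
  int64_t in_lastdim = 3;  // last dim of input_x
  int64_t outerdim = 2;    // rows of ind_2d
  std::vector<int64_t> flat;
  for (int64_t row = 0; row < outerdim; ++row) {
    int64_t range_flatten_index = row * in_lastdim;  // {0, 3}, as in the static branch
    for (auto idx : ind_2d[row]) {
      flat.push_back(idx + range_flatten_index);  // row-major flat position
    }
  }
  for (auto v : flat) std::cout << v << " ";  // 2 0 4 5
  std::cout << "\n";
  return 0;
}

ScatterNd then writes the four entries of dout0 into a zero-filled buffer of six elements at exactly those flat positions, and the final Reshape restores input_x's shape.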
REG_BPROP_BUILDER("PReLU").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {
@@ -251,8 +268,7 @@ REG_BPROP_BUILDER("Pad").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) {
for (const auto &item : paddings) {
begin.push_back(item.at(0));
}
auto shp = ib->GetShape(x);
auto x_shape = IsDynamic(shp) ? ib->Emit("Shape", {x}) : ib->Value<ShapeVector>(shp);
auto x_shape = ib->Shape(x);
auto dx = ib->Emit("Slice", {dout, ib->EmitValue(MakeValue(begin)), x_shape});
return {dx};
});
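
Pad's bprop slices the incoming gradient back to the unpadded region: begin collects each dimension's leading pad, and the slice size is x's own shape, now obtained via ib->Shape so it also works dynamically. A worked sketch with hypothetical paddings:

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Pad bprop by hand: x of shape {2, 3} with paddings {{1, 1}, {2, 2}}
  // produces a dout of shape {4, 7}; the gradient is the unpadded window.
  std::vector<std::vector<int64_t>> paddings = {{1, 1}, {2, 2}};
  std::vector<int64_t> begin;
  for (const auto &item : paddings) {
    begin.push_back(item.at(0));  // leading pad per dim -> begin = {1, 2}
  }
  // Slice(dout, begin = {1, 2}, size = x_shape = {2, 3}) reads dout[1:3, 2:5].
  for (auto b : begin) std::cout << b << " ";
  std::cout << "\n";
  return 0;
}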
@@ -262,7 +278,7 @@ REG_BPROP_BUILDER("ROIAlign").SetUnusedInputs({i0, i2}).SetBody(BODYFUNC(ib) {
auto rois = ib->GetInput(kIndex1);
auto dout = ib->GetInput(kIndex3);
auto shp = ib->GetShape(inputs);
auto inputs_shape = IsDynamic(shp) ? ib->Emit("Shape", {inputs}) : ib->Value<ShapeVector>(shp);
auto inputs_shape = ib->Shape(inputs);
auto dx = ib->Emit("ROIAlignGrad", {dout, rois, inputs_shape},
{{"pooled_height", ib->GetAttr("pooled_height")},
{"pooled_width", ib->GetAttr("pooled_width")},
@@ -543,10 +559,8 @@ REG_BPROP_BUILDER("Conv3D").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {
auto x = ib->GetInput(kIndex0);
auto w = ib->GetInput(kIndex1);
auto dout = ib->GetInput(kIndex3);
auto x_sh = ib->GetShape(x);
auto w_sh = ib->GetShape(w);
auto x_shape = IsDynamic(x_sh) ? ib->Emit("Shape", {x}) : ib->Value<ShapeVector>(x_sh);
auto w_shape = IsDynamic(w_sh) ? ib->Emit("Shape", {w}) : ib->Value<ShapeVector>(w_sh);
auto x_shape = ib->Shape(x);
auto w_shape = ib->Shape(w);
auto dx = ib->Emit("Conv3DBackpropInput", {w, dout, x_shape},
{{"pad_mode", ib->GetAttr("pad_mode")},
{"pad", ib->GetAttr("pad")},
@@ -559,7 +573,7 @@ REG_BPROP_BUILDER("Conv3D").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {
{"format", ib->GetAttr("format")},
{"out_channel", ib->GetAttr("out_channel")},
{"kernel_size", ib->GetAttr("kernel_size")},
{"input_size", MakeValue(x_sh)},
{"input_size", MakeValue(ib->GetShape(x))},
{"mode", ib->GetAttr("mode")}});
auto dw = ib->Emit("Conv3DBackpropFilter", {x, dout, w_shape},
{{"pad_mode", ib->GetAttr("pad_mode")},
@@ -573,7 +587,7 @@ REG_BPROP_BUILDER("Conv3D").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {
{"format", ib->GetAttr("format")},
{"out_channel", ib->GetAttr("out_channel")},
{"kernel_size", ib->GetAttr("kernel_size")},
{"filter_size", MakeValue(w_sh)},
{"filter_size", MakeValue(ib->GetShape(w))},
{"mode", ib->GetAttr("mode")}});
dw = ib->Cast(dw, ib->GetDtype(x));
return {dx, dw};
@@ -587,8 +601,7 @@ REG_BPROP_BUILDER("Conv3DTranspose").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib)
auto x = ib->GetInput(kIndex0);
auto w = ib->GetInput(kIndex1);
auto dout = ib->GetInput(kIndex3);
auto w_sh = ib->GetShape(w);
auto w_shape = IsDynamic(w_sh) ? ib->Emit("Shape", {w}) : ib->Value<ShapeVector>(w_sh);
auto w_shape = ib->Shape(w);
auto dx = ib->Emit("Conv3D", {dout, w},
{{"out_channel", ib->GetAttr("in_channel")},
{"kernel_size", ib->GetAttr("kernel_size")},
@@ -607,7 +620,7 @@ REG_BPROP_BUILDER("Conv3DTranspose").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib)
auto dw = ib->Emit("Conv3DBackpropFilter", {dout, x, w_shape},
{{"out_channel", ib->GetAttr("in_channel")},
{"kernel_size", ib->GetAttr("kernel_size")},
{"filter_size", MakeValue(w_sh)},
{"filter_size", MakeValue(ib->GetShape(w))},
{"mode", ib->GetAttr("mode")},
{"pad_mode", MakeValue("pad")},
{"pad", ib->GetAttr("pad_list")},
@@ -683,11 +696,6 @@ REG_BPROP_BUILDER("MaxPoolGrad").SetUnusedInputs({i2, i3}).SetBody(BODYFUNC(ib)
{"format", MakeValue("NCHW")},
{"pad_mode", ib->GetAttr("pad_mode")}});
} else {
auto x2_shape = ib->GetShape(x2);
auto b = x2_shape.at(0);
auto c = x2_shape.at(1);
auto h = x2_shape.at(2);
auto w = x2_shape.at(3);
auto tmp = ib->Emit("MaxPoolWithArgmax", {x1},
{{"kernel_size", MakeValue(kernel_size)},
{"strides", MakeValue(strides)},
@@ -695,11 +703,39 @@ REG_BPROP_BUILDER("MaxPoolGrad").SetUnusedInputs({i2, i3}).SetBody(BODYFUNC(ib)
{"data_format", MakeValue("NCHW")},
{"format", MakeValue("NCHW")}});
auto ind = ib->TupleGetItem(tmp, 1);
auto batch = ib->Tensor(Range(b), TypeIdToType(TypeId::kNumberTypeInt32));
batch = ib->Tile(ib->Reshape(batch, {-1, 1}), {1, (c * h) * w});
auto gather_ind =
ib->Emit("Stack", {ib->MakeTuple({batch, ib->Reshape(ind, {b, -1})})}, {{"axis", MakeValue<int64_t>(-1)}});
dgrad = ib->Reshape(ib->Emit("GatherNd", {ib->Reshape(dout, {b, -1}), gather_ind}), {b, c, h, w});
auto x2_shape = ib->GetShape(x2);
if (IsDynamic(x2_shape)) {
auto shape = ib->Emit("Shape", {x2});
auto shape_func = [](const ShapeArray &inputs) -> ShapeArray {
auto x2_shape = inputs.at(0);
auto b = x2_shape.at(0);
auto c = x2_shape.at(1);
auto h = x2_shape.at(2);
auto w = x2_shape.at(3);
return {{b}, {b, -1}, {1, c * h * w}};
};
auto infer_func = [](const ShapeArray &inputs, const std::unordered_set<size_t> &) -> ShapeVector {
return {1, 2, 2};
};
auto res = ib->ShapeCalc({x2}, shape_func, infer_func, {});
auto batch = ib->Cast(ib->Range(res[0]), kInt32);
batch = ib->Tile(ib->Reshape(batch, {-1, 1}), res[2]);
auto gather_ind =
ib->Emit("Stack", {ib->MakeTuple({batch, ib->Reshape(ind, res[1])})}, {{"axis", MakeValue<int64_t>(-1)}});
dgrad = ib->Reshape(ib->Emit("GatherNd", {ib->Reshape(dout, res[1]), gather_ind}), shape);
} else {
auto b = x2_shape.at(0);
auto c = x2_shape.at(1);
auto h = x2_shape.at(2);
auto w = x2_shape.at(3);
auto batch = ib->Tensor(Range(b), TypeIdToType(TypeId::kNumberTypeInt32));
batch = ib->Tile(ib->Reshape(batch, {-1, 1}), {1, (c * h) * w});
auto gather_ind =
ib->Emit("Stack", {ib->MakeTuple({batch, ib->Reshape(ind, {b, -1})})}, {{"axis", MakeValue<int64_t>(-1)}});
dgrad = ib->Reshape(ib->Emit("GatherNd", {ib->Reshape(dout, {b, -1}), gather_ind}), {b, c, h, w});
}
}
return {dx1, dx2, dgrad};
});
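
In the dynamic branch, the three shapes returned by shape_func drive the batch/argmax gather. Evaluating the same lambda standalone for a hypothetical x2 of shape {2, 1, 2, 2}, with the vector aliases assumed as before:

#include <cstdint>
#include <iostream>
#include <vector>

using ShapeVector = std::vector<int64_t>;
using ShapeArray = std::vector<ShapeVector>;

int main() {
  // The MaxPoolGrad-grad shape_func, evaluated for x2 of shape {2, 1, 2, 2}
  // (b=2, c=1, h=2, w=2).
  auto shape_func = [](const ShapeArray &inputs) -> ShapeArray {
    auto x2_shape = inputs.at(0);
    auto b = x2_shape.at(0);
    auto c = x2_shape.at(1);
    auto h = x2_shape.at(2);
    auto w = x2_shape.at(3);
    return {{b}, {b, -1}, {1, c * h * w}};
  };
  auto res = shape_func({{2, 1, 2, 2}});
  // res[0] = {2}     -> Range(2): batch ids {0, 1}
  // res[1] = {2, -1} -> flatten ind and dout per batch row
  // res[2] = {1, 4}  -> tile the batch column across the c*h*w positions
  for (const auto &s : res) {
    for (auto v : s) std::cout << v << " ";
    std::cout << "| ";
  }
  std::cout << "\n";
  return 0;
}

GatherNd then reads dout at the stacked (batch, argmax) pairs, routing the second-order gradient through exactly the positions MaxPoolWithArgmax selected.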
@@ -801,11 +837,10 @@ REG_BPROP_BUILDER("AvgPool").SetBody(BODYFUNC(ib) {
REG_BPROP_BUILDER("AvgPool3D").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) {
auto x = ib->GetInput(kIndex0);
auto dout = ib->GetInput(kIndex2);
auto x_sh = ib->GetShape(x);
auto x_shape = IsDynamic(x_sh) ? ib->Emit("Shape", {x}) : ib->Value<ShapeVector>(x_sh);
auto x_shape = ib->Shape(x);
auto dx = ib->Emit("AvgPool3DGrad", {x_shape, dout},
{{"kernel_size", ib->GetAttr("kernel_size")},
{"origin_input_shape", MakeValue(x_sh)},
{"origin_input_shape", MakeValue(ib->GetShape(x))},
{"strides", ib->GetAttr("strides")},
{"pad_list", ib->GetAttr("pad_list")},
{"ceil_mode", ib->GetAttr("ceil_mode")},
@@ -853,25 +888,40 @@ REG_BPROP_BUILDER("BiasAddGrad").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib)
auto data_format = GetValue<std::string>(ib->GetAttr("format"));
auto dy = ib->GetInput(kIndex0);
auto dout = ib->GetInput(kIndex2);
auto dy_shape = ib->GetShape(dy);
auto dout_shape = ib->GetShape(dout);
ShapeVector expanded_shape;
ShapeVector tile_mults;
ShapeVector one_vec{1};
if (data_format == "NCHW") {
// expanded_shape = np.concatenate([np.ones_like(shape[:1]), bias_shape, np.ones_like(shape[2:])], axis=0)
expanded_shape = one_vec + dout_shape;
expanded_shape = dy_shape.size() > 2 ? expanded_shape + ShapeVector(dy_shape.size() - 2, 1) : expanded_shape;
// tile_mults = np.concatenate([shape[:1], [1], shape[2:]], axis=0)
ShapeVector tmp{dy_shape[0], 1};
tile_mults = tmp;
tile_mults = dy_shape.size() > 2 ? tile_mults + ShapeVector(dy_shape.begin() + 2, dy_shape.end()) : tile_mults;
} else {
// expanded_shape = np.concatenate([np.ones_like(shape[:-1]), bias_shape], axis=0)
expanded_shape = ShapeVector(dy_shape.size() - 1, 1) + dout_shape;
// tile_mults = np.concatenate([shape[:-1], [1]], axis=0)
tile_mults = ShapeVector(dy_shape.begin(), dy_shape.end() - 1) + one_vec;
}
auto shape_func = [data_format](const ShapeArray &inputs) -> ShapeArray {
ShapeVector expanded_shape;
ShapeVector tile_mults;
ShapeVector one_vec{1};
auto dy_shape = inputs.at(0);
auto dout_shape = inputs.at(1);
if (data_format == "NCHW") {
// expanded_shape = np.concatenate([np.ones_like(shape[:1]), bias_shape, np.ones_like(shape[2:])], axis=0)
expanded_shape = one_vec + dout_shape;
expanded_shape = dy_shape.size() > 2 ? expanded_shape + ShapeVector(dy_shape.size() - 2, 1) : expanded_shape;
// tile_mults = np.concatenate([shape[:1], [1], shape[2:]], axis=0)
ShapeVector tmp{dy_shape[0], 1};
tile_mults = tmp;
tile_mults = dy_shape.size() > 2 ? tile_mults + ShapeVector(dy_shape.begin() + 2, dy_shape.end()) : tile_mults;
} else {
// expanded_shape = np.concatenate([np.ones_like(shape[:-1]), bias_shape], axis=0)
expanded_shape = ShapeVector(dy_shape.size() - 1, 1) + dout_shape;
// tile_mults = np.concatenate([shape[:-1], [1]], axis=0)
tile_mults = ShapeVector(dy_shape.begin(), dy_shape.end() - 1) + one_vec;
}
return {expanded_shape, tile_mults};
};
auto infer_func = [](const ShapeArray &inputs, const std::unordered_set<size_t> &) -> ShapeVector {
int64_t x_rank =
IsDynamicRank(inputs.at(0)) ? -1 : static_cast<int64_t>(inputs.at(0).size() + inputs.at(1).size() - 1);
int64_t y_rank = IsDynamicRank(inputs.at(1)) ? -1 : static_cast<int64_t>(inputs.at(0).size());
return {x_rank, y_rank};
};
auto res = ib->ShapeCalc({dy, dout}, shape_func, infer_func, {});
NodePtr expanded_shape = res[0];
NodePtr tile_mults = res[1];
auto expanded_grad = ib->Reshape(dout, expanded_shape);
auto tiled_grad = ib->Tile(expanded_grad, tile_mults);
return {tiled_grad};
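
The reshape-then-tile trick recovers BiasAdd's forward broadcast. Evaluating the same shape arithmetic standalone for a hypothetical NCHW dy of shape {2, 3, 4, 5} and bias gradient dout of shape {3}; the operator+ below stands in for the ShapeVector concatenation the builder relies on:

#include <cstdint>
#include <iostream>
#include <vector>

using ShapeVector = std::vector<int64_t>;

// Stand-in for the ShapeVector concatenation used by the builder code.
ShapeVector operator+(ShapeVector a, const ShapeVector &b) {
  a.insert(a.end(), b.begin(), b.end());
  return a;
}

int main() {
  ShapeVector dy_shape = {2, 3, 4, 5};
  ShapeVector dout_shape = {3};
  ShapeVector one_vec{1};
  auto expanded_shape = one_vec + dout_shape + ShapeVector(dy_shape.size() - 2, 1);
  auto tile_mults = ShapeVector{dy_shape[0], 1} + ShapeVector(dy_shape.begin() + 2, dy_shape.end());
  // expanded_shape = {1, 3, 1, 1}: reshape dout so it broadcasts over N, H, W
  // tile_mults     = {2, 1, 4, 5}: tile the reshaped dout back up to dy's shape
  for (auto v : expanded_shape) std::cout << v << " ";
  std::cout << "| ";
  for (auto v : tile_mults) std::cout << v << " ";
  std::cout << "\n";
  return 0;
}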
@@ -1103,7 +1153,6 @@ REG_BPROP_BUILDER("Softmax").SetUnusedInputs({i0}).SetBody(BODYFUNC(ib) {
NodePtr reverse_axis =
IsDynamicRank(shp) ? ib->ShapeCalc({x}, shape_func, infer_func)[0] : ib->Value(GetTransposeAxis(shp, one_axis));
out = ib->Transpose(out, reverse_axis);
dout = ib->Transpose(dout, reverse_axis);
auto dx = ib->Mul(out, ib->Sub(dout, ib->ReduceSum(ib->Mul(out, dout), ShapeVector{-1}, true)));
@@ -1227,8 +1276,7 @@ REG_BPROP_BUILDER("Conv2DBackpropFilter").SetUnusedInputs({i2, i3}).SetBody(BODY
auto x = ib->GetInput(kIndex1);
auto filter_size = ib->GetInput(kIndex2);
auto dout = ib->GetInput(kIndex4);
auto shp = ib->GetShape(x);
auto x_shape = IsDynamic(shp) ? ib->Emit("Shape", {x}) : ib->Value<ShapeVector>(shp);
auto x_shape = ib->Shape(x);
auto dw_dx = ib->Emit(kConv2DBackpropInputOpName, {dy, dout, x_shape},
{{"mode", ib->GetAttr("mode")},
{"dilation", ib->GetAttr("dilation")},
@@ -1277,8 +1325,15 @@ REG_BPROP_BUILDER("BCEWithLogitsLoss").SetUnusedInputs({i4}).SetBody(BODYFUNC(ib
grad_target = ib->Mul(grad_target, weight);
if (reduction == "mean") {
dx = ib->RealDiv(dx, ib->Tensor(ib->GetSize(dx), ib->GetDtype(dx)));
grad_target = ib->RealDiv(grad_target, ib->Tensor(ib->GetSize(target), ib->GetDtype(grad_target)));
if (IsDynamic(ib->GetShape(dx))) {
auto res = ib->DynSize(dx, ib->GetDtype(dx));
dx = ib->RealDiv(dx, res);
auto res2 = ib->DynSize(target, ib->GetDtype(grad_target));
grad_target = ib->RealDiv(grad_target, res2);
} else {
dx = ib->RealDiv(dx, ib->Tensor(ib->GetSize(dx), ib->GetDtype(dx)));
grad_target = ib->RealDiv(grad_target, ib->Tensor(ib->GetSize(target), ib->GetDtype(grad_target)));
}
}
return {dx, grad_target, ib->ZerosLike(weight), ib->ZerosLike(pos_weight)};
});
@@ -1291,7 +1346,12 @@ REG_BPROP_BUILDER("KLDivLoss").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {
NodePtr dx;
if (reduction == "mean") {
dx = ib->Emit("KLDivLossGrad", {dout, x, y}, {{"reduction", MakeValue("sum")}});
dx = ib->RealDiv(dx, ib->Tensor(ib->GetSize(x), ib->GetDtype(dx)));
if (IsDynamic(ib->GetShape(x))) {
auto res = ib->DynSize(dx, ib->GetDtype(dx));
dx = ib->RealDiv(dx, res);
} else {
dx = ib->RealDiv(dx, ib->Tensor(ib->GetSize(x), ib->GetDtype(dx)));
}
} else {
dx = ib->Emit("KLDivLossGrad", {dout, x, y}, {{"reduction", MakeValue(reduction)}});
}
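
Both this hunk and the BCEWithLogitsLoss one above apply the same identity: a "mean" reduction is a "sum" reduction scaled by the element count, so its gradient is the sum-reduction gradient divided by N, with DynSize supplying N when the shape is only known at runtime:

\[
\nabla_x\!\left(\frac{1}{N}\sum_{i=1}^{N}\ell_i(x)\right)
  = \frac{1}{N}\,\nabla_x\!\left(\sum_{i=1}^{N}\ell_i(x)\right),
\qquad N=\prod_{j=1}^{r} d_j \ \text{for } x \text{ of shape } (d_1,\ldots,d_r).
\]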
@@ -1499,9 +1559,7 @@ REG_BPROP_BUILDER("NthElement").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib) {
REG_BPROP_BUILDER("AdaptiveAvgPool3D").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) {
auto x = ib->GetInput(kIndex0);
auto dout = ib->GetInput(kIndex2);
auto shp = ib->GetShape(x);
auto x_shape =
IsDynamic(shp) ? ib->Emit("TupleToTensor", {ib->Emit("Shape", {x}), ib->Value(kInt64)}) : ib->Tensor(shp);
auto x_shape = ib->Shape(x, true);
auto dx = ib->Emit("AdaptiveAvgPool3DGrad", {dout, ib->Cast(x_shape, kInt32)});
return {dx};
});
@@ -1574,7 +1632,7 @@ REG_BPROP_BUILDER("PSROIPooling").SetUnusedInputs({i0, i2}).SetBody(BODYFUNC(ib)
REG_BPROP_BUILDER("AvgPoolV1").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) {
auto x = ib->GetInput(kIndex0);
auto dout = ib->GetInput(kIndex2);
auto orig_input_shape = ib->Value<ShapeVector>(ib->GetShape(x));
auto orig_input_shape = ib->Shape(x);
auto dx = ib->Emit("AvgPoolGradV1", {orig_input_shape, dout},
{
{"kernel_size", ib->GetAttr("kernel_size")},