From 2049e36c8e554dcbfe5887bd6bc8fd180b7a6999 Mon Sep 17 00:00:00 2001
From: r1chardf1d0
Date: Sun, 5 Mar 2023 11:37:45 +0800
Subject: [PATCH] add some nn expander bprop

---
 .../bprop_expander_meta_func_graph.cc         |   9 +-
 .../grad/bprop_expander/bprop_irbuilder.h     |  28 +++
 .../bprop_expander/grad_ops/grad_nn_ops.cc    | 216 +++++++++++-------
 3 files changed, 173 insertions(+), 80 deletions(-)

diff --git a/mindspore/ccsrc/frontend/operator/graph_bprop/bprop_expander_meta_func_graph.cc b/mindspore/ccsrc/frontend/operator/graph_bprop/bprop_expander_meta_func_graph.cc
index 5360614f538..5cd9b9b8c53 100644
--- a/mindspore/ccsrc/frontend/operator/graph_bprop/bprop_expander_meta_func_graph.cc
+++ b/mindspore/ccsrc/frontend/operator/graph_bprop/bprop_expander_meta_func_graph.cc
@@ -262,6 +262,11 @@ void RegNNBpropExpanderOps2() {
   REGISTER_EXPANDER_BPROP_IMPL(AdaptiveAvgPool3D);
   REGISTER_EXPANDER_BPROP_IMPL(FractionalAvgPool);
   REGISTER_EXPANDER_BPROP_IMPL(PSROIPooling);
+  REGISTER_EXPANDER_BPROP_IMPL(BiasAddGrad);
+  REGISTER_EXPANDER_BPROP_IMPL(MaxPoolGrad);
+  REGISTER_EXPANDER_BPROP_IMPL(TopK);
+  REGISTER_EXPANDER_BPROP_IMPL(BCEWithLogitsLoss);
+  REGISTER_EXPANDER_BPROP_IMPL(KLDivLoss);
 }
 
 void RegArrayBpropExpanderOps1() {
@@ -321,7 +326,9 @@ void RegArrayBpropExpanderOps2() {}
 void RegClipBpropExpanderOps() {}
 void RegCommBpropExpanderOps() {}
 void RegInnerBpropExpanderOps() {}
-void RegOtherBpropExpanderOps() {}
+
+void RegOtherBpropExpanderOps() { REGISTER_EXPANDER_BPROP_IMPL(Assign); }
+
 void RegQuantBpropExpanderOps() {}
 void RegSparseBpropExpanderOps() {}
 
diff --git a/mindspore/ccsrc/pipeline/pynative/grad/bprop_expander/bprop_irbuilder.h b/mindspore/ccsrc/pipeline/pynative/grad/bprop_expander/bprop_irbuilder.h
index fd15a129ab0..ed2c25a6ded 100644
--- a/mindspore/ccsrc/pipeline/pynative/grad/bprop_expander/bprop_irbuilder.h
+++ b/mindspore/ccsrc/pipeline/pynative/grad/bprop_expander/bprop_irbuilder.h
@@ -20,6 +20,7 @@
 #include
 #include
 #include
+#include <unordered_set>
 #include
 
 #include "include/common/expander/core/node.h"
@@ -56,12 +57,39 @@ class BpropIRBuilder : public Emitter {
   // For node that has single output
   ShapeVector GetShape(const NodePtr &node) const { return node->shape(); }
+  NodePtr Shape(const NodePtr &node, bool tensor = false) const {
+    auto shape = GetShape(node);
+    if (tensor) {
+      return IsDynamic(shape) ? Emit("TensorShape", {node}) : Tensor(shape);
+    } else {
+      return IsDynamic(shape) ? Emit("Shape", {node}) : Value(shape);
+    }
+  }
+
Emit("Shape", {node}) : Value(shape); + } + } + // For node that has multiple outputs std::vector GetShapes(const NodePtr &node) const { return node->shapes(); } TypePtr GetDtype(const NodePtr &node) const { return node->dtype(); } TypeId GetDtypeId(const NodePtr &node) const { return GetDtype(node)->type_id(); } ValuePtr GetAttr(const NodePtr &node, const std::string &attr) const; int64_t GetSize(const NodePtr &node) const; + NodePtr DynSize(const NodePtr &node, const TypePtr &type) const { return Cast(DynSize(node), type); } + NodePtr DynSize(const NodePtr &node, TypeId type_id) const { return Cast(DynSize(node), type_id); } + NodePtr DynSize(const NodePtr &node) const { + auto shape_func = [](const ShapeArray &inputs) -> ShapeArray { + auto shape = inputs.at(0); + int64_t size = 1; + for (auto &i : shape) { + size *= i; + } + return {{size}}; + }; + auto infer_func = [](const ShapeArray &inputs, const std::unordered_set &) -> ShapeVector { return {1}; }; + return ShapeCalc({node}, shape_func, infer_func, {})[0]; + } + NodePtr Range(const NodePtr &limit) const { return Range(Tensor(0, kInt64), limit, Tensor(1, kInt64)); } + NodePtr Range(const NodePtr &start, const NodePtr &limit, const NodePtr &delta, int64_t max_len = 1000000) const { + return Emit("Range", {start, limit, delta}, {{"maxlen", MakeValue(max_len)}}); + } std::string name() const { return name_; } std::string GetTargetFromContext() const; diff --git a/mindspore/ccsrc/pipeline/pynative/grad/bprop_expander/grad_ops/grad_nn_ops.cc b/mindspore/ccsrc/pipeline/pynative/grad/bprop_expander/grad_ops/grad_nn_ops.cc index 97b5be79925..a129523e307 100644 --- a/mindspore/ccsrc/pipeline/pynative/grad/bprop_expander/grad_ops/grad_nn_ops.cc +++ b/mindspore/ccsrc/pipeline/pynative/grad/bprop_expander/grad_ops/grad_nn_ops.cc @@ -71,8 +71,7 @@ NodePtrList Conv2DTransposeBpropExpander(const BpropIRBuilder *ib) { auto w = ib->GetInput(kIndex1); auto f_sizes = ib->GetInput(kIndex2); auto dout = ib->GetInput(kIndex4); - auto shp = ib->GetShape(w); - auto w_shape = IsDynamic(shp) ? 
ib->Emit("Shape", {w}) : ib->Value(shp); + auto w_shape = ib->Shape(w); auto dx = ib->Emit(kConv2DOpName, {dout, w}, {{"pad_mode", ib->GetAttr("pad_mode")}, {"pad", ib->GetAttr("pad")}, @@ -196,33 +195,51 @@ REG_BPROP_BUILDER("TopK").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) { auto input_x = ib->GetInput(kIndex0); auto out = ib->GetInput(kIndex2); auto dout = ib->GetInput(kIndex3); - auto indices = ib->TupleGetItem(out, kIndex1); auto dout0 = ib->TupleGetItem(dout, kIndex0); - auto in_shape = ib->GetShape(input_x); - auto in_lastdim = in_shape.back(); - - auto ind_shape = ib->GetShape(indices); - auto ind_lastdim = ind_shape.back(); - - auto ind_2d = ib->Reshape(indices, {-1, ind_lastdim}); - auto outerdim = ib->GetShape(ind_2d)[0]; // k - - // [0, outerdim, 2*outerdim, ..., (k-1)*outerdim] - auto indices_dtype = ib->GetDtype(indices); - std::vector range_flatten_index_vec(LongToSize(outerdim)); - for (int64_t i = 0; i < outerdim; i++) { - range_flatten_index_vec[i] = i * in_lastdim; + if (IsDynamic(in_shape)) { + auto shape_func0 = [](const ShapeArray &inputs) -> ShapeArray { return {{-1, inputs.at(0).back()}}; }; + auto infer_func0 = [](const ShapeArray &inputs, const std::unordered_set &) -> ShapeVector { return {2}; }; + auto re0 = ib->ShapeCalc({indices}, shape_func0, infer_func0, {})[0]; + NodePtr ind_2d = ib->Reshape(indices, re0); + auto shape_func = [](const ShapeArray &inputs) -> ShapeArray { + auto in_shape = inputs.at(0); + auto in_lastdim = in_shape.back(); + auto outerdim = inputs.at(1)[0]; // k + auto in_shape_1d_x = + ShapeVector(1, std::accumulate(in_shape.begin(), in_shape.end(), 1, std::multiplies())); + return {in_shape_1d_x, {outerdim * in_lastdim}, {in_lastdim}}; + }; + auto infer_func = [](const ShapeArray &inputs, const std::unordered_set &) -> ShapeVector { + return {1, 1, 1}; + }; + auto res = ib->ShapeCalc({input_x, ind_2d}, shape_func, infer_func, {}); + auto in_shape_1d = res[0]; + auto range_flatten_index = ib->Range(ib->Tensor(0, kInt64), res[1], res[2]); + auto ind = ib->Reshape(ind_2d + ib->Reshape(range_flatten_index, {-1, 1}), {-1, 1}); + auto out_grad = ib->Emit("ScatterNd", {ind, ib->Reshape(dout0, {-1}), in_shape_1d}); + out_grad = ib->Reshape(out_grad, ib->Emit("TensorShape", {input_x})); + auto grad_k = ib->ZerosLike(ib->GetInput(kIndex1)); + return {out_grad, grad_k}; + } else { + auto ind_lastdim = ib->GetShape(indices).back(); + auto ind_2d = ib->Reshape(indices, {-1, ind_lastdim}); + auto in_lastdim = in_shape.back(); + auto outerdim = ib->GetShape(ind_2d)[0]; // k + std::vector range_flatten_index_vec(LongToSize(outerdim)); + for (int64_t i = 0; i < outerdim; i++) { + range_flatten_index_vec[i] = i * in_lastdim; + } + auto range_flatten_index = ib->Tensor(range_flatten_index_vec, ib->GetDtype(indices)); + auto in_shape_1d = + ib->Value(ShapeVector(1, std::accumulate(in_shape.begin(), in_shape.end(), 1, std::multiplies()))); + auto ind = ib->Reshape(ind_2d + ib->Reshape(range_flatten_index, {-1, 1}), {-1, 1}); + auto out_grad = ib->Emit("ScatterNd", {ind, ib->Reshape(dout0, {-1}), in_shape_1d}); + out_grad = ib->Reshape(out_grad, in_shape); + auto grad_k = ib->ZerosLike(ib->GetInput(kIndex1)); + return {out_grad, grad_k}; } - auto range_flatten_index = ib->Tensor(range_flatten_index_vec, indices_dtype); - auto ind = ib->Reshape(ind_2d + ib->Reshape(range_flatten_index, {-1, 1}), {-1, 1}); - auto in_shape_1d = ShapeVector(1, std::accumulate(in_shape.begin(), in_shape.end(), 1, std::multiplies())); - auto out_grad = ib->Emit("ScatterNd", {ind, 
ib->Reshape(dout0, {-1}), ib->Value(in_shape_1d)}); - out_grad = ib->Reshape(out_grad, in_shape); - - auto grad_k = ib->ZerosLike(ib->GetInput(kIndex1)); - return {out_grad, grad_k}; }); REG_BPROP_BUILDER("PReLU").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) { @@ -251,8 +268,7 @@ REG_BPROP_BUILDER("Pad").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) { for (const auto &item : paddings) { begin.push_back(item.at(0)); } - auto shp = ib->GetShape(x); - auto x_shape = IsDynamic(shp) ? ib->Emit("Shape", {x}) : ib->Value(shp); + auto x_shape = ib->Shape(x); auto dx = ib->Emit("Slice", {dout, ib->EmitValue(MakeValue(begin)), x_shape}); return {dx}; }); @@ -262,7 +278,7 @@ REG_BPROP_BUILDER("ROIAlign").SetUnusedInputs({i0, i2}).SetBody(BODYFUNC(ib) { auto rois = ib->GetInput(kIndex1); auto dout = ib->GetInput(kIndex3); auto shp = ib->GetShape(inputs); - auto inputs_shape = IsDynamic(shp) ? ib->Emit("Shape", {inputs}) : ib->Value(shp); + auto inputs_shape = ib->Shape(inputs); auto dx = ib->Emit("ROIAlignGrad", {dout, rois, inputs_shape}, {{"pooled_height", ib->GetAttr("pooled_height")}, {"pooled_width", ib->GetAttr("pooled_width")}, @@ -543,10 +559,8 @@ REG_BPROP_BUILDER("Conv3D").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) { auto x = ib->GetInput(kIndex0); auto w = ib->GetInput(kIndex1); auto dout = ib->GetInput(kIndex3); - auto x_sh = ib->GetShape(x); - auto w_sh = ib->GetShape(w); - auto x_shape = IsDynamic(x_sh) ? ib->Emit("Shape", {x}) : ib->Value(x_sh); - auto w_shape = IsDynamic(w_sh) ? ib->Emit("Shape", {w}) : ib->Value(w_sh); + auto x_shape = ib->Shape(x); + auto w_shape = ib->Shape(w); auto dx = ib->Emit("Conv3DBackpropInput", {w, dout, x_shape}, {{"pad_mode", ib->GetAttr("pad_mode")}, {"pad", ib->GetAttr("pad")}, @@ -559,7 +573,7 @@ REG_BPROP_BUILDER("Conv3D").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) { {"format", ib->GetAttr("format")}, {"out_channel", ib->GetAttr("out_channel")}, {"kernel_size", ib->GetAttr("kernel_size")}, - {"input_size", MakeValue(x_sh)}, + {"input_size", MakeValue(ib->GetShape(x))}, {"mode", ib->GetAttr("mode")}}); auto dw = ib->Emit("Conv3DBackpropFilter", {x, dout, w_shape}, {{"pad_mode", ib->GetAttr("pad_mode")}, @@ -573,7 +587,7 @@ REG_BPROP_BUILDER("Conv3D").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) { {"format", ib->GetAttr("format")}, {"out_channel", ib->GetAttr("out_channel")}, {"kernel_size", ib->GetAttr("kernel_size")}, - {"filter_size", MakeValue(w_sh)}, + {"filter_size", MakeValue(ib->GetShape(w))}, {"mode", ib->GetAttr("mode")}}); dw = ib->Cast(dw, ib->GetDtype(x)); return {dx, dw}; @@ -587,8 +601,7 @@ REG_BPROP_BUILDER("Conv3DTranspose").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) auto x = ib->GetInput(kIndex0); auto w = ib->GetInput(kIndex1); auto dout = ib->GetInput(kIndex3); - auto w_sh = ib->GetShape(w); - auto w_shape = IsDynamic(w_sh) ? 
ib->Emit("Shape", {w}) : ib->Value(w_sh); + auto w_shape = ib->Shape(w); auto dx = ib->Emit("Conv3D", {dout, w}, {{"out_channel", ib->GetAttr("in_channel")}, {"kernel_size", ib->GetAttr("kernel_size")}, @@ -607,7 +620,7 @@ REG_BPROP_BUILDER("Conv3DTranspose").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) auto dw = ib->Emit("Conv3DBackpropFilter", {dout, x, w_shape}, {{"out_channel", ib->GetAttr("in_channel")}, {"kernel_size", ib->GetAttr("kernel_size")}, - {"filter_size", MakeValue(w_sh)}, + {"filter_size", MakeValue(ib->GetShape(w))}, {"mode", ib->GetAttr("mode")}, {"pad_mode", MakeValue("pad")}, {"pad", ib->GetAttr("pad_list")}, @@ -683,11 +696,6 @@ REG_BPROP_BUILDER("MaxPoolGrad").SetUnusedInputs({i2, i3}).SetBody(BODYFUNC(ib) {"format", MakeValue("NCHW")}, {"pad_mode", ib->GetAttr("pad_mode")}}); } else { - auto x2_shape = ib->GetShape(x2); - auto b = x2_shape.at(0); - auto c = x2_shape.at(1); - auto h = x2_shape.at(2); - auto w = x2_shape.at(3); auto tmp = ib->Emit("MaxPoolWithArgmax", {x1}, {{"kernel_size", MakeValue(kernel_size)}, {"strides", MakeValue(strides)}, @@ -695,11 +703,39 @@ REG_BPROP_BUILDER("MaxPoolGrad").SetUnusedInputs({i2, i3}).SetBody(BODYFUNC(ib) {"data_format", MakeValue("NCHW")}, {"format", MakeValue("NCHW")}}); auto ind = ib->TupleGetItem(tmp, 1); - auto batch = ib->Tensor(Range(b), TypeIdToType(TypeId::kNumberTypeInt32)); - batch = ib->Tile(ib->Reshape(batch, {-1, 1}), {1, (c * h) * w}); - auto gather_ind = - ib->Emit("Stack", {ib->MakeTuple({batch, ib->Reshape(ind, {b, -1})})}, {{"axis", MakeValue(-1)}}); - dgrad = ib->Reshape(ib->Emit("GatherNd", {ib->Reshape(dout, {b, -1}), gather_ind}), {b, c, h, w}); + auto x2_shape = ib->GetShape(x2); + if (IsDynamic(x2_shape)) { + auto shape = ib->Emit("Shape", {x2}); + auto shape_func = [](const ShapeArray &inputs) -> ShapeArray { + auto x2_shape = inputs.at(0); + auto b = x2_shape.at(0); + auto c = x2_shape.at(1); + auto h = x2_shape.at(2); + auto w = x2_shape.at(3); + return {{b}, {b, -1}, {1, c * h * w}}; + }; + + auto infer_func = [](const ShapeArray &inputs, const std::unordered_set &) -> ShapeVector { + return {1, 2, 2}; + }; + + auto res = ib->ShapeCalc({x2}, shape_func, infer_func, {}); + auto batch = ib->Cast(ib->Range(res[0]), kInt32); + batch = ib->Tile(ib->Reshape(batch, {-1, 1}), res[2]); + auto gather_ind = + ib->Emit("Stack", {ib->MakeTuple({batch, ib->Reshape(ind, res[1])})}, {{"axis", MakeValue(-1)}}); + dgrad = ib->Reshape(ib->Emit("GatherNd", {ib->Reshape(dout, res[1]), gather_ind}), shape); + } else { + auto b = x2_shape.at(0); + auto c = x2_shape.at(1); + auto h = x2_shape.at(2); + auto w = x2_shape.at(3); + auto batch = ib->Tensor(Range(b), TypeIdToType(TypeId::kNumberTypeInt32)); + batch = ib->Tile(ib->Reshape(batch, {-1, 1}), {1, (c * h) * w}); + auto gather_ind = + ib->Emit("Stack", {ib->MakeTuple({batch, ib->Reshape(ind, {b, -1})})}, {{"axis", MakeValue(-1)}}); + dgrad = ib->Reshape(ib->Emit("GatherNd", {ib->Reshape(dout, {b, -1}), gather_ind}), {b, c, h, w}); + } } return {dx1, dx2, dgrad}; }); @@ -801,11 +837,10 @@ REG_BPROP_BUILDER("AvgPool").SetBody(BODYFUNC(ib) { REG_BPROP_BUILDER("AvgPool3D").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) { auto x = ib->GetInput(kIndex0); auto dout = ib->GetInput(kIndex2); - auto x_sh = ib->GetShape(x); - auto x_shape = IsDynamic(x_sh) ? 
ib->Emit("Shape", {x}) : ib->Value(x_sh); + auto x_shape = ib->Shape(x); auto dx = ib->Emit("AvgPool3DGrad", {x_shape, dout}, {{"kernel_size", ib->GetAttr("kernel_size")}, - {"origin_input_shape", MakeValue(x_sh)}, + {"origin_input_shape", MakeValue(ib->GetShape(x))}, {"strides", ib->GetAttr("strides")}, {"pad_list", ib->GetAttr("pad_list")}, {"ceil_mode", ib->GetAttr("ceil_mode")}, @@ -853,25 +888,40 @@ REG_BPROP_BUILDER("BiasAddGrad").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) auto data_format = GetValue(ib->GetAttr("format")); auto dy = ib->GetInput(kIndex0); auto dout = ib->GetInput(kIndex2); - auto dy_shape = ib->GetShape(dy); - auto dout_shape = ib->GetShape(dout); - ShapeVector expanded_shape; - ShapeVector tile_mults; - ShapeVector one_vec{1}; - if (data_format == "NCHW") { - // expanded_shape = np.concatenate([np.ones_like(shape[:1]), bias_shape, np.ones_like(shape[2:])], axis=0) - expanded_shape = one_vec + dout_shape; - expanded_shape = dy_shape.size() > 2 ? expanded_shape + ShapeVector(1, dy_shape.size() - 2) : expanded_shape; - // tile_mults = np.concatenate([shape[:1], [1], shape[2:]], axis=0) - ShapeVector tmp{dy_shape[0], 1}; - tile_mults = tmp; - tile_mults = dy_shape.size() > 2 ? tile_mults + ShapeVector(dy_shape.begin() + 2, dy_shape.end()) : tile_mults; - } else { - // expanded_shape = np.concatenate([np.ones_like(shape[:-1]), bias_shape], axis=0) - expanded_shape = ShapeVector(1, dy_shape.size() - 1) + dout_shape; - // tile_mults = np.concatenate([shape[:-1], [1]], axis=0) - tile_mults = ShapeVector(dy_shape.begin(), dy_shape.end() - 1) + one_vec; - } + auto shape_func = [data_format](const ShapeArray &inputs) -> ShapeArray { + ShapeVector expanded_shape; + ShapeVector tile_mults; + ShapeVector one_vec{1}; + auto dy_shape = inputs.at(0); + auto dout_shape = inputs.at(1); + if (data_format == "NCHW") { + // expanded_shape = np.concatenate([np.ones_like(shape[:1]), bias_shape, np.ones_like(shape[2:])], axis=0) + expanded_shape = one_vec + dout_shape; + expanded_shape = dy_shape.size() > 2 ? expanded_shape + ShapeVector(1, dy_shape.size() - 2) : expanded_shape; + // tile_mults = np.concatenate([shape[:1], [1], shape[2:]], axis=0) + ShapeVector tmp{dy_shape[0], 1}; + tile_mults = tmp; + tile_mults = dy_shape.size() > 2 ? tile_mults + ShapeVector(dy_shape.begin() + 2, dy_shape.end()) : tile_mults; + } else { + // expanded_shape = np.concatenate([np.ones_like(shape[:-1]), bias_shape], axis=0) + expanded_shape = ShapeVector(1, dy_shape.size() - 1) + dout_shape; + // tile_mults = np.concatenate([shape[:-1], [1]], axis=0) + tile_mults = ShapeVector(dy_shape.begin(), dy_shape.end() - 1) + one_vec; + } + return {expanded_shape, tile_mults}; + }; + + auto infer_func = [](const ShapeArray &inputs, const std::unordered_set &) -> ShapeVector { + int64_t x_rank = + IsDynamicRank(inputs.at(0)) ? -1 : static_cast(inputs.at(0).size() + inputs.at(1).size() - 1); + int64_t y_rank = IsDynamicRank(inputs.at(1)) ? -1 : static_cast(inputs.at(0).size()); + return {x_rank, y_rank}; + }; + + auto res = ib->ShapeCalc({dy, dout}, shape_func, infer_func, {}); + NodePtr expanded_shape = res[0]; + NodePtr tile_mults = res[1]; + auto expanded_grad = ib->Reshape(dout, expanded_shape); auto tiled_grad = ib->Tile(expanded_grad, tile_mults); return {tiled_grad}; @@ -1103,7 +1153,6 @@ REG_BPROP_BUILDER("Softmax").SetUnusedInputs({i0}).SetBody(BODYFUNC(ib) { NodePtr reverse_axis = IsDynamicRank(shp) ? 
ib->ShapeCalc({x}, shape_func, infer_func)[0] : ib->Value(GetTransposeAxis(shp, one_axis)); - out = ib->Transpose(out, reverse_axis); dout = ib->Transpose(dout, reverse_axis); auto dx = ib->Mul(out, ib->Sub(dout, ib->ReduceSum(ib->Mul(out, dout), ShapeVector{-1}, true))); @@ -1227,8 +1276,7 @@ REG_BPROP_BUILDER("Conv2DBackpropFilter").SetUnusedInputs({i2, i3}).SetBody(BODY auto x = ib->GetInput(kIndex1); auto filter_size = ib->GetInput(kIndex2); auto dout = ib->GetInput(kIndex4); - auto shp = ib->GetShape(x); - auto x_shape = IsDynamic(shp) ? ib->Emit("Shape", {x}) : ib->Value(shp); + auto x_shape = ib->Shape(x); auto dw_dx = ib->Emit(kConv2DBackpropInputOpName, {dy, dout, x_shape}, {{"mode", ib->GetAttr("mode")}, {"dilation", ib->GetAttr("dilation")}, @@ -1277,8 +1325,15 @@ REG_BPROP_BUILDER("BCEWithLogitsLoss").SetUnusedInputs({i4}).SetBody(BODYFUNC(ib grad_target = ib->Mul(grad_target, weight); if (reduction == "mean") { - dx = ib->RealDiv(dx, ib->Tensor(ib->GetSize(dx), ib->GetDtype(dx))); - grad_target = ib->RealDiv(grad_target, ib->Tensor(ib->GetSize(target), ib->GetDtype(grad_target))); + if (IsDynamic(ib->GetShape(dx))) { + auto res = ib->DynSize(dx, ib->GetDtype(dx)); + dx = ib->RealDiv(dx, res); + auto res2 = ib->DynSize(target, ib->GetDtype(grad_target)); + grad_target = ib->RealDiv(grad_target, res2); + } else { + dx = ib->RealDiv(dx, ib->Tensor(ib->GetSize(dx), ib->GetDtype(dx))); + grad_target = ib->RealDiv(grad_target, ib->Tensor(ib->GetSize(target), ib->GetDtype(grad_target))); + } } return {dx, grad_target, ib->ZerosLike(weight), ib->ZerosLike(pos_weight)}; }); @@ -1291,7 +1346,12 @@ REG_BPROP_BUILDER("KLDivLoss").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) { NodePtr dx; if (reduction == "mean") { dx = ib->Emit("KLDivLossGrad", {dout, x, y}, {{"reduction", MakeValue("sum")}}); - dx = ib->RealDiv(dx, ib->Tensor(ib->GetSize(x), ib->GetDtype(dx))); + if (IsDynamic(ib->GetShape(x))) { + auto res = ib->DynSize(dx, ib->GetDtype(dx)); + dx = ib->RealDiv(dx, res); + } else { + dx = ib->RealDiv(dx, ib->Tensor(ib->GetSize(x), ib->GetDtype(dx))); + } } else { dx = ib->Emit("KLDivLossGrad", {dout, x, y}, {{"reduction", MakeValue(reduction)}}); } @@ -1499,9 +1559,7 @@ REG_BPROP_BUILDER("NthElement").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib) { REG_BPROP_BUILDER("AdaptiveAvgPool3D").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) { auto x = ib->GetInput(kIndex0); auto dout = ib->GetInput(kIndex2); - auto shp = ib->GetShape(x); - auto x_shape = - IsDynamic(shp) ? ib->Emit("TupleToTensor", {ib->Emit("Shape", {x}), ib->Value(kInt64)}) : ib->Tensor(shp); + auto x_shape = ib->Shape(x, true); auto dx = ib->Emit("AdaptiveAvgPool3DGrad", {dout, ib->Cast(x_shape, kInt32)}); return {dx}; }); @@ -1574,7 +1632,7 @@ REG_BPROP_BUILDER("PSROIPooling").SetUnusedInputs({i0, i2}).SetBody(BODYFUNC(ib) REG_BPROP_BUILDER("AvgPoolV1").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) { auto x = ib->GetInput(kIndex0); auto dout = ib->GetInput(kIndex2); - auto orig_input_shape = ib->Value(ib->GetShape(x)); + auto orig_input_shape = ib->Shape(x); auto dx = ib->Emit("AvgPoolGradV1", {orig_input_shape, dout}, { {"kernel_size", ib->GetAttr("kernel_size")},