forked from mindspore-Ecosystem/mindspore
add some nn expander bprop
This commit is contained in:
parent
d1a4199d05
commit
2049e36c8e
|
@ -262,6 +262,11 @@ void RegNNBpropExpanderOps2() {
|
|||
REGISTER_EXPANDER_BPROP_IMPL(AdaptiveAvgPool3D);
|
||||
REGISTER_EXPANDER_BPROP_IMPL(FractionalAvgPool);
|
||||
REGISTER_EXPANDER_BPROP_IMPL(PSROIPooling);
|
||||
REGISTER_EXPANDER_BPROP_IMPL(BiasAddGrad);
|
||||
REGISTER_EXPANDER_BPROP_IMPL(MaxPoolGrad);
|
||||
REGISTER_EXPANDER_BPROP_IMPL(TopK);
|
||||
REGISTER_EXPANDER_BPROP_IMPL(BCEWithLogitsLoss);
|
||||
REGISTER_EXPANDER_BPROP_IMPL(KLDivLoss);
|
||||
}
|
||||
|
||||
void RegArrayBpropExpanderOps1() {
|
||||
|
@ -321,7 +326,9 @@ void RegArrayBpropExpanderOps2() {}
|
|||
void RegClipBpropExpanderOps() {}
|
||||
void RegCommBpropExpanderOps() {}
|
||||
void RegInnerBpropExpanderOps() {}
|
||||
void RegOtherBpropExpanderOps() {}
|
||||
|
||||
void RegOtherBpropExpanderOps() { REGISTER_EXPANDER_BPROP_IMPL(Assign); }
|
||||
|
||||
void RegQuantBpropExpanderOps() {}
|
||||
void RegSparseBpropExpanderOps() {}
|
||||
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
#include <vector>
|
||||
#include <string>
|
||||
#include <map>
|
||||
#include <unordered_set>
|
||||
#include <functional>
|
||||
|
||||
#include "include/common/expander/core/node.h"
|
||||
|
@ -56,12 +57,39 @@ class BpropIRBuilder : public Emitter {
|
|||
|
||||
// For node that has single output
|
||||
ShapeVector GetShape(const NodePtr &node) const { return node->shape(); }
|
||||
NodePtr Shape(const NodePtr &node, bool tensor = false) const {
|
||||
auto shape = GetShape(node);
|
||||
if (tensor) {
|
||||
return IsDynamic(shape) ? Emit("TensorShape", {node}) : Tensor(shape);
|
||||
} else {
|
||||
return IsDynamic(shape) ? Emit("Shape", {node}) : Value<ShapeVector>(shape);
|
||||
}
|
||||
}
|
||||
|
||||
// For node that has multiple outputs
|
||||
std::vector<ShapeVector> GetShapes(const NodePtr &node) const { return node->shapes(); }
|
||||
TypePtr GetDtype(const NodePtr &node) const { return node->dtype(); }
|
||||
TypeId GetDtypeId(const NodePtr &node) const { return GetDtype(node)->type_id(); }
|
||||
ValuePtr GetAttr(const NodePtr &node, const std::string &attr) const;
|
||||
int64_t GetSize(const NodePtr &node) const;
|
||||
NodePtr DynSize(const NodePtr &node, const TypePtr &type) const { return Cast(DynSize(node), type); }
|
||||
NodePtr DynSize(const NodePtr &node, TypeId type_id) const { return Cast(DynSize(node), type_id); }
|
||||
NodePtr DynSize(const NodePtr &node) const {
|
||||
auto shape_func = [](const ShapeArray &inputs) -> ShapeArray {
|
||||
auto shape = inputs.at(0);
|
||||
int64_t size = 1;
|
||||
for (auto &i : shape) {
|
||||
size *= i;
|
||||
}
|
||||
return {{size}};
|
||||
};
|
||||
auto infer_func = [](const ShapeArray &inputs, const std::unordered_set<size_t> &) -> ShapeVector { return {1}; };
|
||||
return ShapeCalc({node}, shape_func, infer_func, {})[0];
|
||||
}
|
||||
NodePtr Range(const NodePtr &limit) const { return Range(Tensor(0, kInt64), limit, Tensor(1, kInt64)); }
|
||||
NodePtr Range(const NodePtr &start, const NodePtr &limit, const NodePtr &delta, int64_t max_len = 1000000) const {
|
||||
return Emit("Range", {start, limit, delta}, {{"maxlen", MakeValue(max_len)}});
|
||||
}
|
||||
|
||||
std::string name() const { return name_; }
|
||||
std::string GetTargetFromContext() const;
|
||||
|
|
|
@ -71,8 +71,7 @@ NodePtrList Conv2DTransposeBpropExpander(const BpropIRBuilder *ib) {
|
|||
auto w = ib->GetInput(kIndex1);
|
||||
auto f_sizes = ib->GetInput(kIndex2);
|
||||
auto dout = ib->GetInput(kIndex4);
|
||||
auto shp = ib->GetShape(w);
|
||||
auto w_shape = IsDynamic(shp) ? ib->Emit("Shape", {w}) : ib->Value<ShapeVector>(shp);
|
||||
auto w_shape = ib->Shape(w);
|
||||
auto dx = ib->Emit(kConv2DOpName, {dout, w},
|
||||
{{"pad_mode", ib->GetAttr("pad_mode")},
|
||||
{"pad", ib->GetAttr("pad")},
|
||||
|
@ -196,33 +195,51 @@ REG_BPROP_BUILDER("TopK").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) {
|
|||
auto input_x = ib->GetInput(kIndex0);
|
||||
auto out = ib->GetInput(kIndex2);
|
||||
auto dout = ib->GetInput(kIndex3);
|
||||
|
||||
auto indices = ib->TupleGetItem(out, kIndex1);
|
||||
auto dout0 = ib->TupleGetItem(dout, kIndex0);
|
||||
|
||||
auto in_shape = ib->GetShape(input_x);
|
||||
auto in_lastdim = in_shape.back();
|
||||
|
||||
auto ind_shape = ib->GetShape(indices);
|
||||
auto ind_lastdim = ind_shape.back();
|
||||
|
||||
auto ind_2d = ib->Reshape(indices, {-1, ind_lastdim});
|
||||
auto outerdim = ib->GetShape(ind_2d)[0]; // k
|
||||
|
||||
// [0, outerdim, 2*outerdim, ..., (k-1)*outerdim]
|
||||
auto indices_dtype = ib->GetDtype(indices);
|
||||
std::vector<int64_t> range_flatten_index_vec(LongToSize(outerdim));
|
||||
for (int64_t i = 0; i < outerdim; i++) {
|
||||
range_flatten_index_vec[i] = i * in_lastdim;
|
||||
if (IsDynamic(in_shape)) {
|
||||
auto shape_func0 = [](const ShapeArray &inputs) -> ShapeArray { return {{-1, inputs.at(0).back()}}; };
|
||||
auto infer_func0 = [](const ShapeArray &inputs, const std::unordered_set<size_t> &) -> ShapeVector { return {2}; };
|
||||
auto re0 = ib->ShapeCalc({indices}, shape_func0, infer_func0, {})[0];
|
||||
NodePtr ind_2d = ib->Reshape(indices, re0);
|
||||
auto shape_func = [](const ShapeArray &inputs) -> ShapeArray {
|
||||
auto in_shape = inputs.at(0);
|
||||
auto in_lastdim = in_shape.back();
|
||||
auto outerdim = inputs.at(1)[0]; // k
|
||||
auto in_shape_1d_x =
|
||||
ShapeVector(1, std::accumulate(in_shape.begin(), in_shape.end(), 1, std::multiplies<int64_t>()));
|
||||
return {in_shape_1d_x, {outerdim * in_lastdim}, {in_lastdim}};
|
||||
};
|
||||
auto infer_func = [](const ShapeArray &inputs, const std::unordered_set<size_t> &) -> ShapeVector {
|
||||
return {1, 1, 1};
|
||||
};
|
||||
auto res = ib->ShapeCalc({input_x, ind_2d}, shape_func, infer_func, {});
|
||||
auto in_shape_1d = res[0];
|
||||
auto range_flatten_index = ib->Range(ib->Tensor(0, kInt64), res[1], res[2]);
|
||||
auto ind = ib->Reshape(ind_2d + ib->Reshape(range_flatten_index, {-1, 1}), {-1, 1});
|
||||
auto out_grad = ib->Emit("ScatterNd", {ind, ib->Reshape(dout0, {-1}), in_shape_1d});
|
||||
out_grad = ib->Reshape(out_grad, ib->Emit("TensorShape", {input_x}));
|
||||
auto grad_k = ib->ZerosLike(ib->GetInput(kIndex1));
|
||||
return {out_grad, grad_k};
|
||||
} else {
|
||||
auto ind_lastdim = ib->GetShape(indices).back();
|
||||
auto ind_2d = ib->Reshape(indices, {-1, ind_lastdim});
|
||||
auto in_lastdim = in_shape.back();
|
||||
auto outerdim = ib->GetShape(ind_2d)[0]; // k
|
||||
std::vector<int64_t> range_flatten_index_vec(LongToSize(outerdim));
|
||||
for (int64_t i = 0; i < outerdim; i++) {
|
||||
range_flatten_index_vec[i] = i * in_lastdim;
|
||||
}
|
||||
auto range_flatten_index = ib->Tensor(range_flatten_index_vec, ib->GetDtype(indices));
|
||||
auto in_shape_1d =
|
||||
ib->Value(ShapeVector(1, std::accumulate(in_shape.begin(), in_shape.end(), 1, std::multiplies<int64_t>())));
|
||||
auto ind = ib->Reshape(ind_2d + ib->Reshape(range_flatten_index, {-1, 1}), {-1, 1});
|
||||
auto out_grad = ib->Emit("ScatterNd", {ind, ib->Reshape(dout0, {-1}), in_shape_1d});
|
||||
out_grad = ib->Reshape(out_grad, in_shape);
|
||||
auto grad_k = ib->ZerosLike(ib->GetInput(kIndex1));
|
||||
return {out_grad, grad_k};
|
||||
}
|
||||
auto range_flatten_index = ib->Tensor(range_flatten_index_vec, indices_dtype);
|
||||
auto ind = ib->Reshape(ind_2d + ib->Reshape(range_flatten_index, {-1, 1}), {-1, 1});
|
||||
auto in_shape_1d = ShapeVector(1, std::accumulate(in_shape.begin(), in_shape.end(), 1, std::multiplies<int64_t>()));
|
||||
auto out_grad = ib->Emit("ScatterNd", {ind, ib->Reshape(dout0, {-1}), ib->Value(in_shape_1d)});
|
||||
out_grad = ib->Reshape(out_grad, in_shape);
|
||||
|
||||
auto grad_k = ib->ZerosLike(ib->GetInput(kIndex1));
|
||||
return {out_grad, grad_k};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("PReLU").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {
|
||||
|
@ -251,8 +268,7 @@ REG_BPROP_BUILDER("Pad").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) {
|
|||
for (const auto &item : paddings) {
|
||||
begin.push_back(item.at(0));
|
||||
}
|
||||
auto shp = ib->GetShape(x);
|
||||
auto x_shape = IsDynamic(shp) ? ib->Emit("Shape", {x}) : ib->Value<ShapeVector>(shp);
|
||||
auto x_shape = ib->Shape(x);
|
||||
auto dx = ib->Emit("Slice", {dout, ib->EmitValue(MakeValue(begin)), x_shape});
|
||||
return {dx};
|
||||
});
|
||||
|
@ -262,7 +278,7 @@ REG_BPROP_BUILDER("ROIAlign").SetUnusedInputs({i0, i2}).SetBody(BODYFUNC(ib) {
|
|||
auto rois = ib->GetInput(kIndex1);
|
||||
auto dout = ib->GetInput(kIndex3);
|
||||
auto shp = ib->GetShape(inputs);
|
||||
auto inputs_shape = IsDynamic(shp) ? ib->Emit("Shape", {inputs}) : ib->Value<ShapeVector>(shp);
|
||||
auto inputs_shape = ib->Shape(inputs);
|
||||
auto dx = ib->Emit("ROIAlignGrad", {dout, rois, inputs_shape},
|
||||
{{"pooled_height", ib->GetAttr("pooled_height")},
|
||||
{"pooled_width", ib->GetAttr("pooled_width")},
|
||||
|
@ -543,10 +559,8 @@ REG_BPROP_BUILDER("Conv3D").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {
|
|||
auto x = ib->GetInput(kIndex0);
|
||||
auto w = ib->GetInput(kIndex1);
|
||||
auto dout = ib->GetInput(kIndex3);
|
||||
auto x_sh = ib->GetShape(x);
|
||||
auto w_sh = ib->GetShape(w);
|
||||
auto x_shape = IsDynamic(x_sh) ? ib->Emit("Shape", {x}) : ib->Value<ShapeVector>(x_sh);
|
||||
auto w_shape = IsDynamic(w_sh) ? ib->Emit("Shape", {w}) : ib->Value<ShapeVector>(w_sh);
|
||||
auto x_shape = ib->Shape(x);
|
||||
auto w_shape = ib->Shape(w);
|
||||
auto dx = ib->Emit("Conv3DBackpropInput", {w, dout, x_shape},
|
||||
{{"pad_mode", ib->GetAttr("pad_mode")},
|
||||
{"pad", ib->GetAttr("pad")},
|
||||
|
@ -559,7 +573,7 @@ REG_BPROP_BUILDER("Conv3D").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {
|
|||
{"format", ib->GetAttr("format")},
|
||||
{"out_channel", ib->GetAttr("out_channel")},
|
||||
{"kernel_size", ib->GetAttr("kernel_size")},
|
||||
{"input_size", MakeValue(x_sh)},
|
||||
{"input_size", MakeValue(ib->GetShape(x))},
|
||||
{"mode", ib->GetAttr("mode")}});
|
||||
auto dw = ib->Emit("Conv3DBackpropFilter", {x, dout, w_shape},
|
||||
{{"pad_mode", ib->GetAttr("pad_mode")},
|
||||
|
@ -573,7 +587,7 @@ REG_BPROP_BUILDER("Conv3D").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {
|
|||
{"format", ib->GetAttr("format")},
|
||||
{"out_channel", ib->GetAttr("out_channel")},
|
||||
{"kernel_size", ib->GetAttr("kernel_size")},
|
||||
{"filter_size", MakeValue(w_sh)},
|
||||
{"filter_size", MakeValue(ib->GetShape(w))},
|
||||
{"mode", ib->GetAttr("mode")}});
|
||||
dw = ib->Cast(dw, ib->GetDtype(x));
|
||||
return {dx, dw};
|
||||
|
@ -587,8 +601,7 @@ REG_BPROP_BUILDER("Conv3DTranspose").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib)
|
|||
auto x = ib->GetInput(kIndex0);
|
||||
auto w = ib->GetInput(kIndex1);
|
||||
auto dout = ib->GetInput(kIndex3);
|
||||
auto w_sh = ib->GetShape(w);
|
||||
auto w_shape = IsDynamic(w_sh) ? ib->Emit("Shape", {w}) : ib->Value<ShapeVector>(w_sh);
|
||||
auto w_shape = ib->Shape(w);
|
||||
auto dx = ib->Emit("Conv3D", {dout, w},
|
||||
{{"out_channel", ib->GetAttr("in_channel")},
|
||||
{"kernel_size", ib->GetAttr("kernel_size")},
|
||||
|
@ -607,7 +620,7 @@ REG_BPROP_BUILDER("Conv3DTranspose").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib)
|
|||
auto dw = ib->Emit("Conv3DBackpropFilter", {dout, x, w_shape},
|
||||
{{"out_channel", ib->GetAttr("in_channel")},
|
||||
{"kernel_size", ib->GetAttr("kernel_size")},
|
||||
{"filter_size", MakeValue(w_sh)},
|
||||
{"filter_size", MakeValue(ib->GetShape(w))},
|
||||
{"mode", ib->GetAttr("mode")},
|
||||
{"pad_mode", MakeValue("pad")},
|
||||
{"pad", ib->GetAttr("pad_list")},
|
||||
|
@ -683,11 +696,6 @@ REG_BPROP_BUILDER("MaxPoolGrad").SetUnusedInputs({i2, i3}).SetBody(BODYFUNC(ib)
|
|||
{"format", MakeValue("NCHW")},
|
||||
{"pad_mode", ib->GetAttr("pad_mode")}});
|
||||
} else {
|
||||
auto x2_shape = ib->GetShape(x2);
|
||||
auto b = x2_shape.at(0);
|
||||
auto c = x2_shape.at(1);
|
||||
auto h = x2_shape.at(2);
|
||||
auto w = x2_shape.at(3);
|
||||
auto tmp = ib->Emit("MaxPoolWithArgmax", {x1},
|
||||
{{"kernel_size", MakeValue(kernel_size)},
|
||||
{"strides", MakeValue(strides)},
|
||||
|
@ -695,11 +703,39 @@ REG_BPROP_BUILDER("MaxPoolGrad").SetUnusedInputs({i2, i3}).SetBody(BODYFUNC(ib)
|
|||
{"data_format", MakeValue("NCHW")},
|
||||
{"format", MakeValue("NCHW")}});
|
||||
auto ind = ib->TupleGetItem(tmp, 1);
|
||||
auto batch = ib->Tensor(Range(b), TypeIdToType(TypeId::kNumberTypeInt32));
|
||||
batch = ib->Tile(ib->Reshape(batch, {-1, 1}), {1, (c * h) * w});
|
||||
auto gather_ind =
|
||||
ib->Emit("Stack", {ib->MakeTuple({batch, ib->Reshape(ind, {b, -1})})}, {{"axis", MakeValue<int64_t>(-1)}});
|
||||
dgrad = ib->Reshape(ib->Emit("GatherNd", {ib->Reshape(dout, {b, -1}), gather_ind}), {b, c, h, w});
|
||||
auto x2_shape = ib->GetShape(x2);
|
||||
if (IsDynamic(x2_shape)) {
|
||||
auto shape = ib->Emit("Shape", {x2});
|
||||
auto shape_func = [](const ShapeArray &inputs) -> ShapeArray {
|
||||
auto x2_shape = inputs.at(0);
|
||||
auto b = x2_shape.at(0);
|
||||
auto c = x2_shape.at(1);
|
||||
auto h = x2_shape.at(2);
|
||||
auto w = x2_shape.at(3);
|
||||
return {{b}, {b, -1}, {1, c * h * w}};
|
||||
};
|
||||
|
||||
auto infer_func = [](const ShapeArray &inputs, const std::unordered_set<size_t> &) -> ShapeVector {
|
||||
return {1, 2, 2};
|
||||
};
|
||||
|
||||
auto res = ib->ShapeCalc({x2}, shape_func, infer_func, {});
|
||||
auto batch = ib->Cast(ib->Range(res[0]), kInt32);
|
||||
batch = ib->Tile(ib->Reshape(batch, {-1, 1}), res[2]);
|
||||
auto gather_ind =
|
||||
ib->Emit("Stack", {ib->MakeTuple({batch, ib->Reshape(ind, res[1])})}, {{"axis", MakeValue<int64_t>(-1)}});
|
||||
dgrad = ib->Reshape(ib->Emit("GatherNd", {ib->Reshape(dout, res[1]), gather_ind}), shape);
|
||||
} else {
|
||||
auto b = x2_shape.at(0);
|
||||
auto c = x2_shape.at(1);
|
||||
auto h = x2_shape.at(2);
|
||||
auto w = x2_shape.at(3);
|
||||
auto batch = ib->Tensor(Range(b), TypeIdToType(TypeId::kNumberTypeInt32));
|
||||
batch = ib->Tile(ib->Reshape(batch, {-1, 1}), {1, (c * h) * w});
|
||||
auto gather_ind =
|
||||
ib->Emit("Stack", {ib->MakeTuple({batch, ib->Reshape(ind, {b, -1})})}, {{"axis", MakeValue<int64_t>(-1)}});
|
||||
dgrad = ib->Reshape(ib->Emit("GatherNd", {ib->Reshape(dout, {b, -1}), gather_ind}), {b, c, h, w});
|
||||
}
|
||||
}
|
||||
return {dx1, dx2, dgrad};
|
||||
});
|
||||
|
@ -801,11 +837,10 @@ REG_BPROP_BUILDER("AvgPool").SetBody(BODYFUNC(ib) {
|
|||
REG_BPROP_BUILDER("AvgPool3D").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto dout = ib->GetInput(kIndex2);
|
||||
auto x_sh = ib->GetShape(x);
|
||||
auto x_shape = IsDynamic(x_sh) ? ib->Emit("Shape", {x}) : ib->Value<ShapeVector>(x_sh);
|
||||
auto x_shape = ib->Shape(x);
|
||||
auto dx = ib->Emit("AvgPool3DGrad", {x_shape, dout},
|
||||
{{"kernel_size", ib->GetAttr("kernel_size")},
|
||||
{"origin_input_shape", MakeValue(x_sh)},
|
||||
{"origin_input_shape", MakeValue(ib->GetShape(x))},
|
||||
{"strides", ib->GetAttr("strides")},
|
||||
{"pad_list", ib->GetAttr("pad_list")},
|
||||
{"ceil_mode", ib->GetAttr("ceil_mode")},
|
||||
|
@ -853,25 +888,40 @@ REG_BPROP_BUILDER("BiasAddGrad").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib)
|
|||
auto data_format = GetValue<std::string>(ib->GetAttr("format"));
|
||||
auto dy = ib->GetInput(kIndex0);
|
||||
auto dout = ib->GetInput(kIndex2);
|
||||
auto dy_shape = ib->GetShape(dy);
|
||||
auto dout_shape = ib->GetShape(dout);
|
||||
ShapeVector expanded_shape;
|
||||
ShapeVector tile_mults;
|
||||
ShapeVector one_vec{1};
|
||||
if (data_format == "NCHW") {
|
||||
// expanded_shape = np.concatenate([np.ones_like(shape[:1]), bias_shape, np.ones_like(shape[2:])], axis=0)
|
||||
expanded_shape = one_vec + dout_shape;
|
||||
expanded_shape = dy_shape.size() > 2 ? expanded_shape + ShapeVector(1, dy_shape.size() - 2) : expanded_shape;
|
||||
// tile_mults = np.concatenate([shape[:1], [1], shape[2:]], axis=0)
|
||||
ShapeVector tmp{dy_shape[0], 1};
|
||||
tile_mults = tmp;
|
||||
tile_mults = dy_shape.size() > 2 ? tile_mults + ShapeVector(dy_shape.begin() + 2, dy_shape.end()) : tile_mults;
|
||||
} else {
|
||||
// expanded_shape = np.concatenate([np.ones_like(shape[:-1]), bias_shape], axis=0)
|
||||
expanded_shape = ShapeVector(1, dy_shape.size() - 1) + dout_shape;
|
||||
// tile_mults = np.concatenate([shape[:-1], [1]], axis=0)
|
||||
tile_mults = ShapeVector(dy_shape.begin(), dy_shape.end() - 1) + one_vec;
|
||||
}
|
||||
auto shape_func = [data_format](const ShapeArray &inputs) -> ShapeArray {
|
||||
ShapeVector expanded_shape;
|
||||
ShapeVector tile_mults;
|
||||
ShapeVector one_vec{1};
|
||||
auto dy_shape = inputs.at(0);
|
||||
auto dout_shape = inputs.at(1);
|
||||
if (data_format == "NCHW") {
|
||||
// expanded_shape = np.concatenate([np.ones_like(shape[:1]), bias_shape, np.ones_like(shape[2:])], axis=0)
|
||||
expanded_shape = one_vec + dout_shape;
|
||||
expanded_shape = dy_shape.size() > 2 ? expanded_shape + ShapeVector(1, dy_shape.size() - 2) : expanded_shape;
|
||||
// tile_mults = np.concatenate([shape[:1], [1], shape[2:]], axis=0)
|
||||
ShapeVector tmp{dy_shape[0], 1};
|
||||
tile_mults = tmp;
|
||||
tile_mults = dy_shape.size() > 2 ? tile_mults + ShapeVector(dy_shape.begin() + 2, dy_shape.end()) : tile_mults;
|
||||
} else {
|
||||
// expanded_shape = np.concatenate([np.ones_like(shape[:-1]), bias_shape], axis=0)
|
||||
expanded_shape = ShapeVector(1, dy_shape.size() - 1) + dout_shape;
|
||||
// tile_mults = np.concatenate([shape[:-1], [1]], axis=0)
|
||||
tile_mults = ShapeVector(dy_shape.begin(), dy_shape.end() - 1) + one_vec;
|
||||
}
|
||||
return {expanded_shape, tile_mults};
|
||||
};
|
||||
|
||||
auto infer_func = [](const ShapeArray &inputs, const std::unordered_set<size_t> &) -> ShapeVector {
|
||||
int64_t x_rank =
|
||||
IsDynamicRank(inputs.at(0)) ? -1 : static_cast<int64_t>(inputs.at(0).size() + inputs.at(1).size() - 1);
|
||||
int64_t y_rank = IsDynamicRank(inputs.at(1)) ? -1 : static_cast<int64_t>(inputs.at(0).size());
|
||||
return {x_rank, y_rank};
|
||||
};
|
||||
|
||||
auto res = ib->ShapeCalc({dy, dout}, shape_func, infer_func, {});
|
||||
NodePtr expanded_shape = res[0];
|
||||
NodePtr tile_mults = res[1];
|
||||
|
||||
auto expanded_grad = ib->Reshape(dout, expanded_shape);
|
||||
auto tiled_grad = ib->Tile(expanded_grad, tile_mults);
|
||||
return {tiled_grad};
|
||||
|
@ -1103,7 +1153,6 @@ REG_BPROP_BUILDER("Softmax").SetUnusedInputs({i0}).SetBody(BODYFUNC(ib) {
|
|||
|
||||
NodePtr reverse_axis =
|
||||
IsDynamicRank(shp) ? ib->ShapeCalc({x}, shape_func, infer_func)[0] : ib->Value(GetTransposeAxis(shp, one_axis));
|
||||
|
||||
out = ib->Transpose(out, reverse_axis);
|
||||
dout = ib->Transpose(dout, reverse_axis);
|
||||
auto dx = ib->Mul(out, ib->Sub(dout, ib->ReduceSum(ib->Mul(out, dout), ShapeVector{-1}, true)));
|
||||
|
@ -1227,8 +1276,7 @@ REG_BPROP_BUILDER("Conv2DBackpropFilter").SetUnusedInputs({i2, i3}).SetBody(BODY
|
|||
auto x = ib->GetInput(kIndex1);
|
||||
auto filter_size = ib->GetInput(kIndex2);
|
||||
auto dout = ib->GetInput(kIndex4);
|
||||
auto shp = ib->GetShape(x);
|
||||
auto x_shape = IsDynamic(shp) ? ib->Emit("Shape", {x}) : ib->Value<ShapeVector>(shp);
|
||||
auto x_shape = ib->Shape(x);
|
||||
auto dw_dx = ib->Emit(kConv2DBackpropInputOpName, {dy, dout, x_shape},
|
||||
{{"mode", ib->GetAttr("mode")},
|
||||
{"dilation", ib->GetAttr("dilation")},
|
||||
|
@ -1277,8 +1325,15 @@ REG_BPROP_BUILDER("BCEWithLogitsLoss").SetUnusedInputs({i4}).SetBody(BODYFUNC(ib
|
|||
grad_target = ib->Mul(grad_target, weight);
|
||||
|
||||
if (reduction == "mean") {
|
||||
dx = ib->RealDiv(dx, ib->Tensor(ib->GetSize(dx), ib->GetDtype(dx)));
|
||||
grad_target = ib->RealDiv(grad_target, ib->Tensor(ib->GetSize(target), ib->GetDtype(grad_target)));
|
||||
if (IsDynamic(ib->GetShape(dx))) {
|
||||
auto res = ib->DynSize(dx, ib->GetDtype(dx));
|
||||
dx = ib->RealDiv(dx, res);
|
||||
auto res2 = ib->DynSize(target, ib->GetDtype(grad_target));
|
||||
grad_target = ib->RealDiv(grad_target, res2);
|
||||
} else {
|
||||
dx = ib->RealDiv(dx, ib->Tensor(ib->GetSize(dx), ib->GetDtype(dx)));
|
||||
grad_target = ib->RealDiv(grad_target, ib->Tensor(ib->GetSize(target), ib->GetDtype(grad_target)));
|
||||
}
|
||||
}
|
||||
return {dx, grad_target, ib->ZerosLike(weight), ib->ZerosLike(pos_weight)};
|
||||
});
|
||||
|
@ -1291,7 +1346,12 @@ REG_BPROP_BUILDER("KLDivLoss").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {
|
|||
NodePtr dx;
|
||||
if (reduction == "mean") {
|
||||
dx = ib->Emit("KLDivLossGrad", {dout, x, y}, {{"reduction", MakeValue("sum")}});
|
||||
dx = ib->RealDiv(dx, ib->Tensor(ib->GetSize(x), ib->GetDtype(dx)));
|
||||
if (IsDynamic(ib->GetShape(x))) {
|
||||
auto res = ib->DynSize(dx, ib->GetDtype(dx));
|
||||
dx = ib->RealDiv(dx, res);
|
||||
} else {
|
||||
dx = ib->RealDiv(dx, ib->Tensor(ib->GetSize(x), ib->GetDtype(dx)));
|
||||
}
|
||||
} else {
|
||||
dx = ib->Emit("KLDivLossGrad", {dout, x, y}, {{"reduction", MakeValue(reduction)}});
|
||||
}
|
||||
|
@ -1499,9 +1559,7 @@ REG_BPROP_BUILDER("NthElement").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib) {
|
|||
REG_BPROP_BUILDER("AdaptiveAvgPool3D").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto dout = ib->GetInput(kIndex2);
|
||||
auto shp = ib->GetShape(x);
|
||||
auto x_shape =
|
||||
IsDynamic(shp) ? ib->Emit("TupleToTensor", {ib->Emit("Shape", {x}), ib->Value(kInt64)}) : ib->Tensor(shp);
|
||||
auto x_shape = ib->Shape(x, true);
|
||||
auto dx = ib->Emit("AdaptiveAvgPool3DGrad", {dout, ib->Cast(x_shape, kInt32)});
|
||||
return {dx};
|
||||
});
|
||||
|
@ -1574,7 +1632,7 @@ REG_BPROP_BUILDER("PSROIPooling").SetUnusedInputs({i0, i2}).SetBody(BODYFUNC(ib)
|
|||
REG_BPROP_BUILDER("AvgPoolV1").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto dout = ib->GetInput(kIndex2);
|
||||
auto orig_input_shape = ib->Value<ShapeVector>(ib->GetShape(x));
|
||||
auto orig_input_shape = ib->Shape(x);
|
||||
auto dx = ib->Emit("AvgPoolGradV1", {orig_input_shape, dout},
|
||||
{
|
||||
{"kernel_size", ib->GetAttr("kernel_size")},
|
||||
|
|
Loading…
Reference in New Issue