!45166 refine Transpose emitter
Merge pull request !45166 from looop5/refine_bprop_transpose
This commit is contained in:
commit
2f586938b6
|
@ -348,10 +348,10 @@ NodePtrList BinopGatherCommon(const BpropIRBuilder *ib) {
|
|||
}
|
||||
out_shp = ib->GetShape(dout);
|
||||
auto perm_1 = GenerateShapeIndex(out_shp, ind_shp, axis_v);
|
||||
auto values_transpose = ib->Emit("Transpose", {dout, ib->Value<ShapeVector>(perm_1)});
|
||||
auto values_transpose = ib->Transpose(dout, perm_1);
|
||||
auto tmp = ib->Emit("UnsortedSegmentSum", {values_transpose, indices, ib->Value<int64_t>(x_shp[axis_v])});
|
||||
auto perm_2 = GenerateInverseIndex(x_shp, axis_v);
|
||||
auto params_grad = ib->Emit("Transpose", {tmp, ib->Value<ShapeVector>(perm_2)});
|
||||
auto params_grad = ib->Transpose(tmp, perm_2);
|
||||
return {params_grad, ib->ZerosLike(orig_indices), ib->ZerosLike(axis)};
|
||||
}
|
||||
|
||||
|
|
|
@ -85,6 +85,19 @@ NodePtr Emitter::BatchMatMul(const NodePtr &a, const NodePtr &b, bool transpose_
|
|||
{"transpose_b", MakeValue(transpose_b)}});
|
||||
}
|
||||
|
||||
/// Transposes `node` according to `perm`, skipping the op when it would be a no-op.
///
/// \param node  Input node to permute.
/// \param perm  Axis permutation; entries may be negative, in which case they are
///              wrapped by adding the number of entries in `perm`.
/// \return `node` itself when `perm` (after wrapping negatives) is the identity
///         permutation, otherwise the node produced by emitting kTransposeOpName.
NodePtr Emitter::Transpose(const NodePtr &node, const ShapeVector &perm) const {
  const auto rank = SizeToLong(perm.size());
  bool is_identity = true;
  size_t axis = 0;
  for (const auto dim : perm) {
    // Negative entries count from the back: with four axes, [0, -3, 2, 3] denotes [0, 1, 2, 3].
    const auto wrapped = dim < 0 ? (dim + rank) : dim;
    if (wrapped != static_cast<int64_t>(axis)) {
      is_identity = false;
      break;
    }
    ++axis;
  }
  // An identity permutation such as [0, 1, 2, 3] needs no Transpose op at all.
  if (is_identity) {
    return node;
  }
  // The original (un-wrapped) perm is what gets passed to the op.
  return Emit(kTransposeOpName, {node, Value(perm)});
}
|
||||
|
||||
NodePtr Emitter::ZerosLike(const NodePtr &node) const {
|
||||
if (node->isa<ValueNode>()) {
|
||||
auto value_node = node->get<ValueNodePtr>();
|
||||
|
|
|
@ -56,9 +56,7 @@ class Emitter {
|
|||
/// Emits the op named by prim::kReciprocal with `node` as its only input and returns the emitted node.
NodePtr Reciprocal(const NodePtr &node) const { return Emit(prim::kReciprocal, {node}); }
|
||||
/// Emits the op named by prim::kSquare with `node` as its only input and returns the emitted node.
NodePtr Square(const NodePtr &node) const { return Emit(prim::kSquare, {node}); }
|
||||
/// Emits the op whose name is taken from prim::kPrimSign->name(), with `node` as its only input,
/// and returns the emitted node.
NodePtr Sign(const NodePtr &node) const { return Emit(prim::kPrimSign->name(), {node}); }
|
||||
/// Inline form: always emits kTransposeOpName on `node`, passing `perm` wrapped via Value() as the
/// second input, and returns the emitted node. No check is made for whether `perm` is a no-op.
NodePtr Transpose(const NodePtr &node, const ShapeVector &perm) const {
|
||||
return Emit(kTransposeOpName, {node, Value(perm)});
|
||||
}
|
||||
/// Transposes `node` by the axis permutation `perm`.
/// The out-of-line definition returns `node` unchanged when `perm` is the identity permutation
/// (negative entries are first wrapped by adding perm.size()); otherwise it emits kTransposeOpName.
NodePtr Transpose(const NodePtr &node, const ShapeVector &perm) const;
|
||||
/// Emits kTileOpName on `node`, passing `multiples` wrapped via Value() as the second input,
/// and returns the emitted node.
NodePtr Tile(const NodePtr &node, const ShapeVector &multiples) const {
|
||||
return Emit(kTileOpName, {node, Value(multiples)});
|
||||
}
|
||||
|
|
|
@ -180,10 +180,10 @@ REG_BPROP_BUILDER("SparseGatherV2").SetBody([](const BpropIRBuilder *ib) -> Node
|
|||
out_shp = ib->GetShape(dout);
|
||||
ind_shp = ib->GetShape(indices);
|
||||
auto perm_1 = GenerateShapeIndex(out_shp, ind_shp, axis_int);
|
||||
auto values_transpose = ib->Emit("Transpose", {dout, ib->Value<ShapeVector>(perm_1)});
|
||||
auto values_transpose = ib->Transpose(dout, perm_1);
|
||||
auto params_grad = ib->Emit("UnsortedSegmentSum", {values_transpose, indices, ib->Value<int64_t>(x_shp[axis_int])});
|
||||
auto perm_2 = GenerateInverseIndex(x_shp, axis_int);
|
||||
params_grad = ib->Emit("Transpose", {params_grad, ib->Value<ShapeVector>(perm_2)});
|
||||
params_grad = ib->Transpose(params_grad, perm_2);
|
||||
return {params_grad, ib->ZerosLike(indices), ib->ZerosLike(axis)};
|
||||
});
|
||||
|
||||
|
@ -207,7 +207,7 @@ REG_BPROP_BUILDER("Sort").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
|||
auto top_k_input = input_x;
|
||||
if ((static_cast<size_t>(axis + 1) != rank)) {
|
||||
transposition = GetTransposition(axis, rank);
|
||||
top_k_input = ib->Emit("Transpose", {input_x, ib->Value<ShapeVector>(transposition)});
|
||||
top_k_input = ib->Transpose(input_x, transposition);
|
||||
}
|
||||
auto tmp = ib->Emit("TopK", {top_k_input, ib->Value<int64_t>(k)});
|
||||
auto indices = ib->TupleGetItem(tmp, 1);
|
||||
|
@ -228,12 +228,12 @@ REG_BPROP_BUILDER("Sort").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
|||
auto x_shape_1d = ib->Value<ShapeVector>({x_size});
|
||||
NodePtr dx = nullptr;
|
||||
if (!transposition.empty()) {
|
||||
auto invert_perm = ib->Value<ShapeVector>(InvertPermutation(transposition));
|
||||
dvalue = ib->Emit("Transpose", {dvalue, invert_perm});
|
||||
auto invert_perm = InvertPermutation(transposition);
|
||||
dvalue = ib->Transpose(dvalue, invert_perm);
|
||||
auto ind_expand = ib->Emit("ExpandDims", {ind, ib->Value<int64_t>(-1)});
|
||||
auto scatter = ib->Emit("ScatterNd", {ind_expand, ib->Reshape(dvalue, {-1}), x_shape_1d});
|
||||
auto out_grad = ib->Reshape(scatter, top_k_input_shape);
|
||||
dx = ib->Emit("Transpose", {out_grad, invert_perm});
|
||||
dx = ib->Transpose(out_grad, invert_perm);
|
||||
} else {
|
||||
auto ind_expand = ib->Emit("ExpandDims", {ind, ib->Value<int64_t>(-1)});
|
||||
auto scatter = ib->Emit("ScatterNd", {ind_expand, ib->Reshape(dvalue, {-1}), x_shape_1d});
|
||||
|
@ -817,7 +817,7 @@ REG_BPROP_BUILDER("Transpose").SetBody([](const BpropIRBuilder *ib) -> NodePtrLi
|
|||
(void)std::transform(tmp_perm.begin(), tmp_perm.end(), std::back_inserter(new_perm),
|
||||
[&tmp_perm](const int64_t v) { return v >= 0 ? v : v + tmp_perm.size(); });
|
||||
auto res_perm = InvertPermutation(new_perm);
|
||||
return {ib->Emit("Transpose", {dout, ib->Value<ShapeVector>(res_perm)}), ib->ZerosLike(perm)};
|
||||
return {ib->Transpose(dout, res_perm), ib->ZerosLike(perm)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("Slice").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
|
|
|
@ -565,8 +565,8 @@ REG_BPROP_BUILDER("Cdist").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
|||
}
|
||||
perm.push_back(dout_dim - 1);
|
||||
perm.push_back(dout_dim - 2);
|
||||
auto dout_transpose = ib->Emit("Transpose", {dout, ib->Tensor(perm)});
|
||||
auto out_transpose = ib->Emit("Transpose", {out, ib->Tensor(perm)});
|
||||
auto dout_transpose = ib->Transpose(dout, perm);
|
||||
auto out_transpose = ib->Transpose(out, perm);
|
||||
auto dx = ib->Emit("CdistGrad", {dout, input_x, input_y, out}, {{"p", ib->GetAttr("p")}});
|
||||
auto dy = ib->Emit("CdistGrad", {dout_transpose, input_y, input_x, out_transpose}, {{"p", ib->GetAttr("p")}});
|
||||
return {dx, dy};
|
||||
|
@ -926,7 +926,7 @@ REG_BPROP_BUILDER("ReduceProd").SetBody([](const BpropIRBuilder *ib) -> NodePtrL
|
|||
auto tile_scaling = TupleDiv(input_shape, output_shape_kept_dims);
|
||||
auto grad = ib->Emit("Tile", {dout, ib->Value<ShapeVector>(tile_scaling)});
|
||||
auto [pack_shape, perm] = SplitShapeIndex(input_shape, GetAxisValue(axis));
|
||||
auto permuted = ib->Emit("Transpose", {x, ib->Value<ShapeVector>(perm)});
|
||||
auto permuted = ib->Transpose(x, perm);
|
||||
auto permuted_shape = ib->GetShape(permuted);
|
||||
auto reshaped = ib->Reshape(permuted, pack_shape);
|
||||
auto left = ib->Emit("CumProd", {reshaped, ib->Tensor(0, ib->GetDtype(axis))},
|
||||
|
@ -934,7 +934,7 @@ REG_BPROP_BUILDER("ReduceProd").SetBody([](const BpropIRBuilder *ib) -> NodePtrL
|
|||
auto right = ib->Emit("CumProd", {reshaped, ib->Tensor(0, ib->GetDtype(axis))},
|
||||
{{"exclusive", MakeValue(true)}, {"reverse", MakeValue(true)}});
|
||||
auto y = ib->Reshape(ib->Mul(left, right), permuted_shape);
|
||||
out = ib->Mul((ib->Emit("Transpose", {y, ib->Value<ShapeVector>(InvertPermutation(perm))})), grad);
|
||||
out = ib->Mul(ib->Transpose(y, InvertPermutation(perm)), grad);
|
||||
auto dx = ib->Reshape(out, input_shape);
|
||||
return {dx, ib->ZerosLike(axis)};
|
||||
});
|
||||
|
@ -1143,7 +1143,7 @@ REG_BPROP_BUILDER("MatrixExp").SetBody([](const BpropIRBuilder *ib) -> NodePtrLi
|
|||
auto input_perm = Range(x_len);
|
||||
input_perm[x_len - 2] = x_len - 1;
|
||||
input_perm[x_len - 1] = x_len - 2;
|
||||
auto x_transpose = ib->Emit("Transpose", {x, ib->EmitValue(MakeValue(input_perm))});
|
||||
auto x_transpose = ib->Transpose(x, input_perm);
|
||||
auto zero_matrix = ib->ZerosLike(x);
|
||||
zero_matrix = ib->Cast(zero_matrix, ib->GetDtype(dout));
|
||||
auto meta_grad_up = ib->Emit("Concat", {ib->MakeTuple({x_transpose, dout})}, {{"axis", MakeValue<int64_t>(-1)}});
|
||||
|
@ -1186,14 +1186,14 @@ REG_BPROP_BUILDER("CholeskyInverse").SetBody([](const BpropIRBuilder *ib) -> Nod
|
|||
input_x = ib->Cast(input_x, kFloat32);
|
||||
out = ib->Cast(out, kFloat32);
|
||||
dout = ib->Cast(dout, kFloat32);
|
||||
auto common_term = ib->Add(dout, ib->Emit("Transpose", {dout, ib->EmitValue(MakeValue(input_perm))}));
|
||||
auto common_term = ib->Add(dout, ib->Transpose(dout, input_perm));
|
||||
common_term = ib->Cast(common_term, kFloat32);
|
||||
common_term = ib->MatMul(out, ib->MatMul(common_term, out, false, false), false, false);
|
||||
DealWithUpper(common_term);
|
||||
dx = ib->Cast(dx, kFloat64);
|
||||
return {dx};
|
||||
}
|
||||
auto common_term = ib->Add(dout, ib->Emit("Transpose", {dout, ib->EmitValue(MakeValue(input_perm))}));
|
||||
auto common_term = ib->Add(dout, ib->Transpose(dout, input_perm));
|
||||
common_term = ib->MatMul(out, ib->MatMul(common_term, out, false, false), false, false);
|
||||
DealWithUpper(common_term);
|
||||
return {dx};
|
||||
|
@ -1386,7 +1386,7 @@ REG_BPROP_BUILDER("TridiagonalMatMul").SetBody([](const BpropIRBuilder *ib) -> N
|
|||
}
|
||||
perm.emplace_back(rank - 1);
|
||||
perm.emplace_back(rank - 2);
|
||||
return ib->Emit("Transpose", {x, ib->Value<ShapeVector>(perm)});
|
||||
return ib->Transpose(x, perm);
|
||||
};
|
||||
auto superdiag = ib->GetInput(kIndex0);
|
||||
auto maindiag = ib->GetInput(kIndex1);
|
||||
|
|
|
@ -978,11 +978,11 @@ REG_BPROP_BUILDER("Softmax").SetBody([](const BpropIRBuilder *ib) -> NodePtrList
|
|||
auto dout = ib->GetInput(kIndex2);
|
||||
auto shp = ib->GetShape(x);
|
||||
auto reverse_axis = GetTransposeAxis(shp, one_axis);
|
||||
out = ib->Emit("Transpose", {out, ib->Value<ShapeVector>(reverse_axis)});
|
||||
dout = ib->Emit("Transpose", {dout, ib->Value<ShapeVector>(reverse_axis)});
|
||||
out = ib->Transpose(out, reverse_axis);
|
||||
dout = ib->Transpose(dout, reverse_axis);
|
||||
ShapeVector reduce_axis = {-1};
|
||||
auto dx = ib->Mul(out, ib->Sub(dout, ib->ReduceSum(ib->Mul(out, dout), reduce_axis, true)));
|
||||
dx = ib->Emit("Transpose", {dx, ib->Value<ShapeVector>(reverse_axis)});
|
||||
dx = ib->Transpose(dx, reverse_axis);
|
||||
return {dx};
|
||||
});
|
||||
|
||||
|
|
|
@ -137,8 +137,7 @@ REG_BPROP_BUILDER("SparseTensorDenseMatmul").SetBody([](const BpropIRBuilder *ib
|
|||
auto dout = ib->GetInput(kIndex5);
|
||||
auto dense_grad = ib->Emit("SparseTensorDenseMatmul", {indices, values, dense_shape, dout},
|
||||
{{"adjoint_st", MakeValue(!adj_s)}, {"adjoint_dt", MakeValue(adj_d)}});
|
||||
std::vector<int64_t> perm_value{1, 0};
|
||||
auto perm = ib->Tensor(perm_value);
|
||||
auto perm = ib->Value<ShapeVector>({1, 0});
|
||||
if (adj_d) {
|
||||
dense_grad = ib->Emit("Transpose", {dense_grad, perm});
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue