!45166 refine Transpose emitter
Merge pull request !45166 from looop5/refine_bprop_transpose
This commit is contained in:
commit
2f586938b6
|
@ -348,10 +348,10 @@ NodePtrList BinopGatherCommon(const BpropIRBuilder *ib) {
|
||||||
}
|
}
|
||||||
out_shp = ib->GetShape(dout);
|
out_shp = ib->GetShape(dout);
|
||||||
auto perm_1 = GenerateShapeIndex(out_shp, ind_shp, axis_v);
|
auto perm_1 = GenerateShapeIndex(out_shp, ind_shp, axis_v);
|
||||||
auto values_transpose = ib->Emit("Transpose", {dout, ib->Value<ShapeVector>(perm_1)});
|
auto values_transpose = ib->Transpose(dout, perm_1);
|
||||||
auto tmp = ib->Emit("UnsortedSegmentSum", {values_transpose, indices, ib->Value<int64_t>(x_shp[axis_v])});
|
auto tmp = ib->Emit("UnsortedSegmentSum", {values_transpose, indices, ib->Value<int64_t>(x_shp[axis_v])});
|
||||||
auto perm_2 = GenerateInverseIndex(x_shp, axis_v);
|
auto perm_2 = GenerateInverseIndex(x_shp, axis_v);
|
||||||
auto params_grad = ib->Emit("Transpose", {tmp, ib->Value<ShapeVector>(perm_2)});
|
auto params_grad = ib->Transpose(tmp, perm_2);
|
||||||
return {params_grad, ib->ZerosLike(orig_indices), ib->ZerosLike(axis)};
|
return {params_grad, ib->ZerosLike(orig_indices), ib->ZerosLike(axis)};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -85,6 +85,19 @@ NodePtr Emitter::BatchMatMul(const NodePtr &a, const NodePtr &b, bool transpose_
|
||||||
{"transpose_b", MakeValue(transpose_b)}});
|
{"transpose_b", MakeValue(transpose_b)}});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Transposes `node` by `perm`, skipping the op when `perm` is the identity permutation.
// Each entry of `perm` is compared with its own index after negative entries are shifted by
// perm.size(). If every entry matches (an empty `perm` included), `node` is returned unchanged.
// Otherwise a kTransposeOpName op is emitted with the original, un-normalized `perm`.
NodePtr Emitter::Transpose(const NodePtr &node, const ShapeVector &perm) const {
|
||||||
|
// An identity perm such as [0, 1, 2, 3] does not need a transpose.
|
||||||
|
auto n = SizeToLong(perm.size());
|
||||||
|
for (size_t i = 0; i < perm.size(); ++i) {
|
||||||
|
// A perm value may be negative and counts from the end, e.g. [0, -3, 2, 3] is equal to [0, 1, 2, 3].
|
||||||
|
auto perm_i = perm[i] < 0 ? (perm[i] + n) : perm[i];
|
||||||
|
// The first entry that differs from its index means a real transpose is required.
if (perm_i != static_cast<int64_t>(i)) {
|
||||||
|
return Emit(kTransposeOpName, {node, Value(perm)});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Every entry matched its index: the permutation is the identity, so reuse the input node.
return node;
|
||||||
|
}
|
||||||
|
|
||||||
NodePtr Emitter::ZerosLike(const NodePtr &node) const {
|
NodePtr Emitter::ZerosLike(const NodePtr &node) const {
|
||||||
if (node->isa<ValueNode>()) {
|
if (node->isa<ValueNode>()) {
|
||||||
auto value_node = node->get<ValueNodePtr>();
|
auto value_node = node->get<ValueNodePtr>();
|
||||||
|
|
|
@ -56,9 +56,7 @@ class Emitter {
|
||||||
NodePtr Reciprocal(const NodePtr &node) const { return Emit(prim::kReciprocal, {node}); }
|
NodePtr Reciprocal(const NodePtr &node) const { return Emit(prim::kReciprocal, {node}); }
|
||||||
NodePtr Square(const NodePtr &node) const { return Emit(prim::kSquare, {node}); }
|
NodePtr Square(const NodePtr &node) const { return Emit(prim::kSquare, {node}); }
|
||||||
NodePtr Sign(const NodePtr &node) const { return Emit(prim::kPrimSign->name(), {node}); }
|
NodePtr Sign(const NodePtr &node) const { return Emit(prim::kPrimSign->name(), {node}); }
|
||||||
NodePtr Transpose(const NodePtr &node, const ShapeVector &perm) const {
|
NodePtr Transpose(const NodePtr &node, const ShapeVector &perm) const;
|
||||||
return Emit(kTransposeOpName, {node, Value(perm)});
|
|
||||||
}
|
|
||||||
NodePtr Tile(const NodePtr &node, const ShapeVector &multiples) const {
|
NodePtr Tile(const NodePtr &node, const ShapeVector &multiples) const {
|
||||||
return Emit(kTileOpName, {node, Value(multiples)});
|
return Emit(kTileOpName, {node, Value(multiples)});
|
||||||
}
|
}
|
||||||
|
|
|
@ -180,10 +180,10 @@ REG_BPROP_BUILDER("SparseGatherV2").SetBody([](const BpropIRBuilder *ib) -> Node
|
||||||
out_shp = ib->GetShape(dout);
|
out_shp = ib->GetShape(dout);
|
||||||
ind_shp = ib->GetShape(indices);
|
ind_shp = ib->GetShape(indices);
|
||||||
auto perm_1 = GenerateShapeIndex(out_shp, ind_shp, axis_int);
|
auto perm_1 = GenerateShapeIndex(out_shp, ind_shp, axis_int);
|
||||||
auto values_transpose = ib->Emit("Transpose", {dout, ib->Value<ShapeVector>(perm_1)});
|
auto values_transpose = ib->Transpose(dout, perm_1);
|
||||||
auto params_grad = ib->Emit("UnsortedSegmentSum", {values_transpose, indices, ib->Value<int64_t>(x_shp[axis_int])});
|
auto params_grad = ib->Emit("UnsortedSegmentSum", {values_transpose, indices, ib->Value<int64_t>(x_shp[axis_int])});
|
||||||
auto perm_2 = GenerateInverseIndex(x_shp, axis_int);
|
auto perm_2 = GenerateInverseIndex(x_shp, axis_int);
|
||||||
params_grad = ib->Emit("Transpose", {params_grad, ib->Value<ShapeVector>(perm_2)});
|
params_grad = ib->Transpose(params_grad, perm_2);
|
||||||
return {params_grad, ib->ZerosLike(indices), ib->ZerosLike(axis)};
|
return {params_grad, ib->ZerosLike(indices), ib->ZerosLike(axis)};
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -207,7 +207,7 @@ REG_BPROP_BUILDER("Sort").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||||
auto top_k_input = input_x;
|
auto top_k_input = input_x;
|
||||||
if ((static_cast<size_t>(axis + 1) != rank)) {
|
if ((static_cast<size_t>(axis + 1) != rank)) {
|
||||||
transposition = GetTransposition(axis, rank);
|
transposition = GetTransposition(axis, rank);
|
||||||
top_k_input = ib->Emit("Transpose", {input_x, ib->Value<ShapeVector>(transposition)});
|
top_k_input = ib->Transpose(input_x, transposition);
|
||||||
}
|
}
|
||||||
auto tmp = ib->Emit("TopK", {top_k_input, ib->Value<int64_t>(k)});
|
auto tmp = ib->Emit("TopK", {top_k_input, ib->Value<int64_t>(k)});
|
||||||
auto indices = ib->TupleGetItem(tmp, 1);
|
auto indices = ib->TupleGetItem(tmp, 1);
|
||||||
|
@ -228,12 +228,12 @@ REG_BPROP_BUILDER("Sort").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||||
auto x_shape_1d = ib->Value<ShapeVector>({x_size});
|
auto x_shape_1d = ib->Value<ShapeVector>({x_size});
|
||||||
NodePtr dx = nullptr;
|
NodePtr dx = nullptr;
|
||||||
if (!transposition.empty()) {
|
if (!transposition.empty()) {
|
||||||
auto invert_perm = ib->Value<ShapeVector>(InvertPermutation(transposition));
|
auto invert_perm = InvertPermutation(transposition);
|
||||||
dvalue = ib->Emit("Transpose", {dvalue, invert_perm});
|
dvalue = ib->Transpose(dvalue, invert_perm);
|
||||||
auto ind_expand = ib->Emit("ExpandDims", {ind, ib->Value<int64_t>(-1)});
|
auto ind_expand = ib->Emit("ExpandDims", {ind, ib->Value<int64_t>(-1)});
|
||||||
auto scatter = ib->Emit("ScatterNd", {ind_expand, ib->Reshape(dvalue, {-1}), x_shape_1d});
|
auto scatter = ib->Emit("ScatterNd", {ind_expand, ib->Reshape(dvalue, {-1}), x_shape_1d});
|
||||||
auto out_grad = ib->Reshape(scatter, top_k_input_shape);
|
auto out_grad = ib->Reshape(scatter, top_k_input_shape);
|
||||||
dx = ib->Emit("Transpose", {out_grad, invert_perm});
|
dx = ib->Transpose(out_grad, invert_perm);
|
||||||
} else {
|
} else {
|
||||||
auto ind_expand = ib->Emit("ExpandDims", {ind, ib->Value<int64_t>(-1)});
|
auto ind_expand = ib->Emit("ExpandDims", {ind, ib->Value<int64_t>(-1)});
|
||||||
auto scatter = ib->Emit("ScatterNd", {ind_expand, ib->Reshape(dvalue, {-1}), x_shape_1d});
|
auto scatter = ib->Emit("ScatterNd", {ind_expand, ib->Reshape(dvalue, {-1}), x_shape_1d});
|
||||||
|
@ -817,7 +817,7 @@ REG_BPROP_BUILDER("Transpose").SetBody([](const BpropIRBuilder *ib) -> NodePtrLi
|
||||||
(void)std::transform(tmp_perm.begin(), tmp_perm.end(), std::back_inserter(new_perm),
|
(void)std::transform(tmp_perm.begin(), tmp_perm.end(), std::back_inserter(new_perm),
|
||||||
[&tmp_perm](const int64_t v) { return v >= 0 ? v : v + tmp_perm.size(); });
|
[&tmp_perm](const int64_t v) { return v >= 0 ? v : v + tmp_perm.size(); });
|
||||||
auto res_perm = InvertPermutation(new_perm);
|
auto res_perm = InvertPermutation(new_perm);
|
||||||
return {ib->Emit("Transpose", {dout, ib->Value<ShapeVector>(res_perm)}), ib->ZerosLike(perm)};
|
return {ib->Transpose(dout, res_perm), ib->ZerosLike(perm)};
|
||||||
});
|
});
|
||||||
|
|
||||||
REG_BPROP_BUILDER("Slice").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
REG_BPROP_BUILDER("Slice").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||||
|
|
|
@ -565,8 +565,8 @@ REG_BPROP_BUILDER("Cdist").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||||
}
|
}
|
||||||
perm.push_back(dout_dim - 1);
|
perm.push_back(dout_dim - 1);
|
||||||
perm.push_back(dout_dim - 2);
|
perm.push_back(dout_dim - 2);
|
||||||
auto dout_transpose = ib->Emit("Transpose", {dout, ib->Tensor(perm)});
|
auto dout_transpose = ib->Transpose(dout, perm);
|
||||||
auto out_transpose = ib->Emit("Transpose", {out, ib->Tensor(perm)});
|
auto out_transpose = ib->Transpose(out, perm);
|
||||||
auto dx = ib->Emit("CdistGrad", {dout, input_x, input_y, out}, {{"p", ib->GetAttr("p")}});
|
auto dx = ib->Emit("CdistGrad", {dout, input_x, input_y, out}, {{"p", ib->GetAttr("p")}});
|
||||||
auto dy = ib->Emit("CdistGrad", {dout_transpose, input_y, input_x, out_transpose}, {{"p", ib->GetAttr("p")}});
|
auto dy = ib->Emit("CdistGrad", {dout_transpose, input_y, input_x, out_transpose}, {{"p", ib->GetAttr("p")}});
|
||||||
return {dx, dy};
|
return {dx, dy};
|
||||||
|
@ -926,7 +926,7 @@ REG_BPROP_BUILDER("ReduceProd").SetBody([](const BpropIRBuilder *ib) -> NodePtrL
|
||||||
auto tile_scaling = TupleDiv(input_shape, output_shape_kept_dims);
|
auto tile_scaling = TupleDiv(input_shape, output_shape_kept_dims);
|
||||||
auto grad = ib->Emit("Tile", {dout, ib->Value<ShapeVector>(tile_scaling)});
|
auto grad = ib->Emit("Tile", {dout, ib->Value<ShapeVector>(tile_scaling)});
|
||||||
auto [pack_shape, perm] = SplitShapeIndex(input_shape, GetAxisValue(axis));
|
auto [pack_shape, perm] = SplitShapeIndex(input_shape, GetAxisValue(axis));
|
||||||
auto permuted = ib->Emit("Transpose", {x, ib->Value<ShapeVector>(perm)});
|
auto permuted = ib->Transpose(x, perm);
|
||||||
auto permuted_shape = ib->GetShape(permuted);
|
auto permuted_shape = ib->GetShape(permuted);
|
||||||
auto reshaped = ib->Reshape(permuted, pack_shape);
|
auto reshaped = ib->Reshape(permuted, pack_shape);
|
||||||
auto left = ib->Emit("CumProd", {reshaped, ib->Tensor(0, ib->GetDtype(axis))},
|
auto left = ib->Emit("CumProd", {reshaped, ib->Tensor(0, ib->GetDtype(axis))},
|
||||||
|
@ -934,7 +934,7 @@ REG_BPROP_BUILDER("ReduceProd").SetBody([](const BpropIRBuilder *ib) -> NodePtrL
|
||||||
auto right = ib->Emit("CumProd", {reshaped, ib->Tensor(0, ib->GetDtype(axis))},
|
auto right = ib->Emit("CumProd", {reshaped, ib->Tensor(0, ib->GetDtype(axis))},
|
||||||
{{"exclusive", MakeValue(true)}, {"reverse", MakeValue(true)}});
|
{{"exclusive", MakeValue(true)}, {"reverse", MakeValue(true)}});
|
||||||
auto y = ib->Reshape(ib->Mul(left, right), permuted_shape);
|
auto y = ib->Reshape(ib->Mul(left, right), permuted_shape);
|
||||||
out = ib->Mul((ib->Emit("Transpose", {y, ib->Value<ShapeVector>(InvertPermutation(perm))})), grad);
|
out = ib->Mul(ib->Transpose(y, InvertPermutation(perm)), grad);
|
||||||
auto dx = ib->Reshape(out, input_shape);
|
auto dx = ib->Reshape(out, input_shape);
|
||||||
return {dx, ib->ZerosLike(axis)};
|
return {dx, ib->ZerosLike(axis)};
|
||||||
});
|
});
|
||||||
|
@ -1143,7 +1143,7 @@ REG_BPROP_BUILDER("MatrixExp").SetBody([](const BpropIRBuilder *ib) -> NodePtrLi
|
||||||
auto input_perm = Range(x_len);
|
auto input_perm = Range(x_len);
|
||||||
input_perm[x_len - 2] = x_len - 1;
|
input_perm[x_len - 2] = x_len - 1;
|
||||||
input_perm[x_len - 1] = x_len - 2;
|
input_perm[x_len - 1] = x_len - 2;
|
||||||
auto x_transpose = ib->Emit("Transpose", {x, ib->EmitValue(MakeValue(input_perm))});
|
auto x_transpose = ib->Transpose(x, input_perm);
|
||||||
auto zero_matrix = ib->ZerosLike(x);
|
auto zero_matrix = ib->ZerosLike(x);
|
||||||
zero_matrix = ib->Cast(zero_matrix, ib->GetDtype(dout));
|
zero_matrix = ib->Cast(zero_matrix, ib->GetDtype(dout));
|
||||||
auto meta_grad_up = ib->Emit("Concat", {ib->MakeTuple({x_transpose, dout})}, {{"axis", MakeValue<int64_t>(-1)}});
|
auto meta_grad_up = ib->Emit("Concat", {ib->MakeTuple({x_transpose, dout})}, {{"axis", MakeValue<int64_t>(-1)}});
|
||||||
|
@ -1186,14 +1186,14 @@ REG_BPROP_BUILDER("CholeskyInverse").SetBody([](const BpropIRBuilder *ib) -> Nod
|
||||||
input_x = ib->Cast(input_x, kFloat32);
|
input_x = ib->Cast(input_x, kFloat32);
|
||||||
out = ib->Cast(out, kFloat32);
|
out = ib->Cast(out, kFloat32);
|
||||||
dout = ib->Cast(dout, kFloat32);
|
dout = ib->Cast(dout, kFloat32);
|
||||||
auto common_term = ib->Add(dout, ib->Emit("Transpose", {dout, ib->EmitValue(MakeValue(input_perm))}));
|
auto common_term = ib->Add(dout, ib->Transpose(dout, input_perm));
|
||||||
common_term = ib->Cast(common_term, kFloat32);
|
common_term = ib->Cast(common_term, kFloat32);
|
||||||
common_term = ib->MatMul(out, ib->MatMul(common_term, out, false, false), false, false);
|
common_term = ib->MatMul(out, ib->MatMul(common_term, out, false, false), false, false);
|
||||||
DealWithUpper(common_term);
|
DealWithUpper(common_term);
|
||||||
dx = ib->Cast(dx, kFloat64);
|
dx = ib->Cast(dx, kFloat64);
|
||||||
return {dx};
|
return {dx};
|
||||||
}
|
}
|
||||||
auto common_term = ib->Add(dout, ib->Emit("Transpose", {dout, ib->EmitValue(MakeValue(input_perm))}));
|
auto common_term = ib->Add(dout, ib->Transpose(dout, input_perm));
|
||||||
common_term = ib->MatMul(out, ib->MatMul(common_term, out, false, false), false, false);
|
common_term = ib->MatMul(out, ib->MatMul(common_term, out, false, false), false, false);
|
||||||
DealWithUpper(common_term);
|
DealWithUpper(common_term);
|
||||||
return {dx};
|
return {dx};
|
||||||
|
@ -1386,7 +1386,7 @@ REG_BPROP_BUILDER("TridiagonalMatMul").SetBody([](const BpropIRBuilder *ib) -> N
|
||||||
}
|
}
|
||||||
perm.emplace_back(rank - 1);
|
perm.emplace_back(rank - 1);
|
||||||
perm.emplace_back(rank - 2);
|
perm.emplace_back(rank - 2);
|
||||||
return ib->Emit("Transpose", {x, ib->Value<ShapeVector>(perm)});
|
return ib->Transpose(x, perm);
|
||||||
};
|
};
|
||||||
auto superdiag = ib->GetInput(kIndex0);
|
auto superdiag = ib->GetInput(kIndex0);
|
||||||
auto maindiag = ib->GetInput(kIndex1);
|
auto maindiag = ib->GetInput(kIndex1);
|
||||||
|
|
|
@ -978,11 +978,11 @@ REG_BPROP_BUILDER("Softmax").SetBody([](const BpropIRBuilder *ib) -> NodePtrList
|
||||||
auto dout = ib->GetInput(kIndex2);
|
auto dout = ib->GetInput(kIndex2);
|
||||||
auto shp = ib->GetShape(x);
|
auto shp = ib->GetShape(x);
|
||||||
auto reverse_axis = GetTransposeAxis(shp, one_axis);
|
auto reverse_axis = GetTransposeAxis(shp, one_axis);
|
||||||
out = ib->Emit("Transpose", {out, ib->Value<ShapeVector>(reverse_axis)});
|
out = ib->Transpose(out, reverse_axis);
|
||||||
dout = ib->Emit("Transpose", {dout, ib->Value<ShapeVector>(reverse_axis)});
|
dout = ib->Transpose(dout, reverse_axis);
|
||||||
ShapeVector reduce_axis = {-1};
|
ShapeVector reduce_axis = {-1};
|
||||||
auto dx = ib->Mul(out, ib->Sub(dout, ib->ReduceSum(ib->Mul(out, dout), reduce_axis, true)));
|
auto dx = ib->Mul(out, ib->Sub(dout, ib->ReduceSum(ib->Mul(out, dout), reduce_axis, true)));
|
||||||
dx = ib->Emit("Transpose", {dx, ib->Value<ShapeVector>(reverse_axis)});
|
dx = ib->Transpose(dx, reverse_axis);
|
||||||
return {dx};
|
return {dx};
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|
|
@ -137,8 +137,7 @@ REG_BPROP_BUILDER("SparseTensorDenseMatmul").SetBody([](const BpropIRBuilder *ib
|
||||||
auto dout = ib->GetInput(kIndex5);
|
auto dout = ib->GetInput(kIndex5);
|
||||||
auto dense_grad = ib->Emit("SparseTensorDenseMatmul", {indices, values, dense_shape, dout},
|
auto dense_grad = ib->Emit("SparseTensorDenseMatmul", {indices, values, dense_shape, dout},
|
||||||
{{"adjoint_st", MakeValue(!adj_s)}, {"adjoint_dt", MakeValue(adj_d)}});
|
{{"adjoint_st", MakeValue(!adj_s)}, {"adjoint_dt", MakeValue(adj_d)}});
|
||||||
std::vector<int64_t> perm_value{1, 0};
|
auto perm = ib->Value<ShapeVector>({1, 0});
|
||||||
auto perm = ib->Tensor(perm_value);
|
|
||||||
if (adj_d) {
|
if (adj_d) {
|
||||||
dense_grad = ib->Emit("Transpose", {dense_grad, perm});
|
dense_grad = ib->Emit("Transpose", {dense_grad, perm});
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue