!69041 修复Addcdiv算子fp32+fp16场景下反向精度问题

Merge pull request !69041 from hedongdong/addcdiv_bug_master
This commit is contained in:
i-robot 2024-05-06 06:30:36 +00:00 committed by Gitee
commit ad7098c9a8
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed file with 20 additions and 30 deletions

View File

@ -139,7 +139,7 @@ ShapeVector MatrixDeterminantInferFunc(const ShapeArray &inputs, const HashSet<s
return {IsDynamicRank(new_shape) ? -1 : SizeToLong(new_shape.size()) + 2};
}
NodePtrList BpropAddcCommon(BpropBuilder *ib, const std::string &op_name, const std::unordered_set<TypeId> &type_list) {
NodePtrList BpropAddcCommon(BpropBuilder *ib, const std::string &op_name) {
auto input_data = ib->GetInput(kIndex0);
auto x1 = ib->GetInput(kIndex1);
auto x2 = ib->GetInput(kIndex2);
@ -147,27 +147,26 @@ NodePtrList BpropAddcCommon(BpropBuilder *ib, const std::string &op_name, const
auto dout = ib->GetInput(kIndex5);
auto dinput_data = dout;
auto dout_typeptr = ib->GetDtype(dout);
bool need_cast = type_list.count(dout_typeptr->type_id()) > 0;
if (need_cast) {
input_data = ib->Cast(input_data, kFloat32);
x1 = ib->Cast(x1, kFloat32);
x2 = ib->Cast(x2, kFloat32);
value = ib->Cast(value, kFloat32);
if (op_name == "Addcdiv") {
dinput_data = ib->Cast(dinput_data, kFloat32);
}
input_data = ib->Cast(input_data, kFloat32);
x1 = ib->Cast(x1, kFloat32);
x2 = ib->Cast(x2, kFloat32);
value = ib->Cast(value, kFloat32);
if (op_name == "Addcdiv") {
dinput_data = ib->Cast(dinput_data, kFloat32);
}
NodePtr inner_out = nullptr;
NodePtr dx1 = nullptr;
NodePtr dx2 = nullptr;
NodePtr dvalue = nullptr;
if (op_name == "Addcdiv") {
constexpr int64_t const_val = -2;
inner_out = ib->Add((ib->Mul(value, ib->Cast(ib->Div(x1, x2), ib->GetDtype(x1)))), input_data);
inner_out = ib->Add((ib->Mul(value, ib->Div(x1, x2))), input_data);
dx2 =
ib->Neg(ib->Mul(ib->Mul(ib->Mul(x1, value), ib->Pow(x2, ib->Tensor(const_val, ib->GetDtype(x2)))), dinput_data));
dx1 = ib->Mul(dinput_data, ib->Cast(ib->Div(value, x2), ib->GetDtype(value)));
dvalue = ib->Mul(dinput_data, ib->Cast(ib->Div(x1, x2), ib->GetDtype(x1)));
dx1 = ib->Mul(dinput_data, ib->Div(value, x2));
dvalue = ib->Mul(dinput_data, ib->Div(x1, x2));
} else {
dx1 = ib->Mul(dout, ib->Mul(value, x2));
dx2 = ib->Mul(dout, ib->Mul(value, x1));
@ -182,12 +181,12 @@ NodePtrList BpropAddcCommon(BpropBuilder *ib, const std::string &op_name, const
dx2 = tmp_dx2[1];
auto tmp_dvalue = BinopGradCommon(ib, inner_out, value, dout, dvalue);
dvalue = tmp_dvalue[1];
if (need_cast) {
dinput_data = ib->Cast(dinput_data, dout_typeptr);
dx1 = ib->Cast(dx1, dout_typeptr);
dx2 = ib->Cast(dx2, dout_typeptr);
dvalue = ib->Cast(dvalue, dout_typeptr);
}
dinput_data = ib->Cast(dinput_data, ib->GetDtype(ib->GetInput(kIndex0)));
dx1 = ib->Cast(dx1, ib->GetDtype(ib->GetInput(kIndex1)));
dx2 = ib->Cast(dx2, ib->GetDtype(ib->GetInput(kIndex2)));
dvalue = ib->Cast(dvalue, ib->GetDtype(ib->GetInput(kIndex3)));
return {dinput_data, dx1, dx2, dvalue};
}
@ -2343,18 +2342,9 @@ REG_BPROP_BUILDER("AddV2").SetUnusedInputs({i0, i1, i2}).SetBody(BODYFUNC(ib) {
return {BinopGradCommon(ib, x, y, bc_dx, bc_dy)};
});
// Registers the bprop builder body for "Addcdiv", with input i4 marked as unused.
// type_list holds the dout type ids (int64, float16, float64) that BpropAddcCommon checks
// to decide whether its operands are cast to float32 before the gradient expressions are
// built and the results cast back to dout's dtype afterwards.
REG_BPROP_BUILDER("Addcdiv").SetUnusedInputs({i4}).SetBody(BODYFUNC(ib) {
std::unordered_set<TypeId> type_list{TypeId::kNumberTypeInt64, TypeId::kNumberTypeFloat16,
TypeId::kNumberTypeFloat64};
return BpropAddcCommon(ib, "Addcdiv", type_list);
});
// Registers the bprop builder body for "Addcdiv", with input i4 marked as unused.
// The body forwards to BpropAddcCommon, passing only the builder and the op name.
REG_BPROP_BUILDER("Addcdiv").SetUnusedInputs({i4}).SetBody(BODYFUNC(ib) { return BpropAddcCommon(ib, "Addcdiv"); });
// Registers the bprop builder body for "Addcmul", with input i4 marked as unused.
// type_list holds the dout type ids (int8, int16, int64, uint8, float16, float64) that
// BpropAddcCommon checks to decide whether its operands are cast to float32 before the
// gradient expressions are built and the results cast back to dout's dtype afterwards.
REG_BPROP_BUILDER("Addcmul").SetUnusedInputs({i4}).SetBody(BODYFUNC(ib) {
std::unordered_set<TypeId> type_list{TypeId::kNumberTypeInt8, TypeId::kNumberTypeInt16,
TypeId::kNumberTypeInt64, TypeId::kNumberTypeUInt8,
TypeId::kNumberTypeFloat16, TypeId::kNumberTypeFloat64};
return BpropAddcCommon(ib, "Addcmul", type_list);
});
// Registers the bprop builder body for "Addcmul", with input i4 marked as unused.
// The body forwards to BpropAddcCommon, passing only the builder and the op name.
REG_BPROP_BUILDER("Addcmul").SetUnusedInputs({i4}).SetBody(BODYFUNC(ib) { return BpropAddcCommon(ib, "Addcmul"); });
REG_BPROP_BUILDER("LpNorm").SetBody(BODYFUNC(ib) {
auto p = GetValue<int64_t>(ib->GetAttr("p"));