!69041 修复Addcdiv算子fp32+fp16场景下反向精度问题
Merge pull request !69041 from hedongdong/addcdiv_bug_master
This commit is contained in:
commit
ad7098c9a8
|
@ -139,7 +139,7 @@ ShapeVector MatrixDeterminantInferFunc(const ShapeArray &inputs, const HashSet<s
|
|||
return {IsDynamicRank(new_shape) ? -1 : SizeToLong(new_shape.size()) + 2};
|
||||
}
|
||||
|
||||
NodePtrList BpropAddcCommon(BpropBuilder *ib, const std::string &op_name, const std::unordered_set<TypeId> &type_list) {
|
||||
NodePtrList BpropAddcCommon(BpropBuilder *ib, const std::string &op_name) {
|
||||
auto input_data = ib->GetInput(kIndex0);
|
||||
auto x1 = ib->GetInput(kIndex1);
|
||||
auto x2 = ib->GetInput(kIndex2);
|
||||
|
@ -147,27 +147,26 @@ NodePtrList BpropAddcCommon(BpropBuilder *ib, const std::string &op_name, const
|
|||
auto dout = ib->GetInput(kIndex5);
|
||||
auto dinput_data = dout;
|
||||
auto dout_typeptr = ib->GetDtype(dout);
|
||||
bool need_cast = type_list.count(dout_typeptr->type_id()) > 0;
|
||||
if (need_cast) {
|
||||
input_data = ib->Cast(input_data, kFloat32);
|
||||
x1 = ib->Cast(x1, kFloat32);
|
||||
x2 = ib->Cast(x2, kFloat32);
|
||||
value = ib->Cast(value, kFloat32);
|
||||
if (op_name == "Addcdiv") {
|
||||
dinput_data = ib->Cast(dinput_data, kFloat32);
|
||||
}
|
||||
|
||||
input_data = ib->Cast(input_data, kFloat32);
|
||||
x1 = ib->Cast(x1, kFloat32);
|
||||
x2 = ib->Cast(x2, kFloat32);
|
||||
value = ib->Cast(value, kFloat32);
|
||||
if (op_name == "Addcdiv") {
|
||||
dinput_data = ib->Cast(dinput_data, kFloat32);
|
||||
}
|
||||
|
||||
NodePtr inner_out = nullptr;
|
||||
NodePtr dx1 = nullptr;
|
||||
NodePtr dx2 = nullptr;
|
||||
NodePtr dvalue = nullptr;
|
||||
if (op_name == "Addcdiv") {
|
||||
constexpr int64_t const_val = -2;
|
||||
inner_out = ib->Add((ib->Mul(value, ib->Cast(ib->Div(x1, x2), ib->GetDtype(x1)))), input_data);
|
||||
inner_out = ib->Add((ib->Mul(value, ib->Div(x1, x2))), input_data);
|
||||
dx2 =
|
||||
ib->Neg(ib->Mul(ib->Mul(ib->Mul(x1, value), ib->Pow(x2, ib->Tensor(const_val, ib->GetDtype(x2)))), dinput_data));
|
||||
dx1 = ib->Mul(dinput_data, ib->Cast(ib->Div(value, x2), ib->GetDtype(value)));
|
||||
dvalue = ib->Mul(dinput_data, ib->Cast(ib->Div(x1, x2), ib->GetDtype(x1)));
|
||||
dx1 = ib->Mul(dinput_data, ib->Div(value, x2));
|
||||
dvalue = ib->Mul(dinput_data, ib->Div(x1, x2));
|
||||
} else {
|
||||
dx1 = ib->Mul(dout, ib->Mul(value, x2));
|
||||
dx2 = ib->Mul(dout, ib->Mul(value, x1));
|
||||
|
@ -182,12 +181,12 @@ NodePtrList BpropAddcCommon(BpropBuilder *ib, const std::string &op_name, const
|
|||
dx2 = tmp_dx2[1];
|
||||
auto tmp_dvalue = BinopGradCommon(ib, inner_out, value, dout, dvalue);
|
||||
dvalue = tmp_dvalue[1];
|
||||
if (need_cast) {
|
||||
dinput_data = ib->Cast(dinput_data, dout_typeptr);
|
||||
dx1 = ib->Cast(dx1, dout_typeptr);
|
||||
dx2 = ib->Cast(dx2, dout_typeptr);
|
||||
dvalue = ib->Cast(dvalue, dout_typeptr);
|
||||
}
|
||||
|
||||
dinput_data = ib->Cast(dinput_data, ib->GetDtype(ib->GetInput(kIndex0)));
|
||||
dx1 = ib->Cast(dx1, ib->GetDtype(ib->GetInput(kIndex1)));
|
||||
dx2 = ib->Cast(dx2, ib->GetDtype(ib->GetInput(kIndex2)));
|
||||
dvalue = ib->Cast(dvalue, ib->GetDtype(ib->GetInput(kIndex3)));
|
||||
|
||||
return {dinput_data, dx1, dx2, dvalue};
|
||||
}
|
||||
|
||||
|
@ -2343,18 +2342,9 @@ REG_BPROP_BUILDER("AddV2").SetUnusedInputs({i0, i1, i2}).SetBody(BODYFUNC(ib) {
|
|||
return {BinopGradCommon(ib, x, y, bc_dx, bc_dy)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("Addcdiv").SetUnusedInputs({i4}).SetBody(BODYFUNC(ib) {
|
||||
std::unordered_set<TypeId> type_list{TypeId::kNumberTypeInt64, TypeId::kNumberTypeFloat16,
|
||||
TypeId::kNumberTypeFloat64};
|
||||
return BpropAddcCommon(ib, "Addcdiv", type_list);
|
||||
});
|
||||
REG_BPROP_BUILDER("Addcdiv").SetUnusedInputs({i4}).SetBody(BODYFUNC(ib) { return BpropAddcCommon(ib, "Addcdiv"); });
|
||||
|
||||
REG_BPROP_BUILDER("Addcmul").SetUnusedInputs({i4}).SetBody(BODYFUNC(ib) {
|
||||
std::unordered_set<TypeId> type_list{TypeId::kNumberTypeInt8, TypeId::kNumberTypeInt16,
|
||||
TypeId::kNumberTypeInt64, TypeId::kNumberTypeUInt8,
|
||||
TypeId::kNumberTypeFloat16, TypeId::kNumberTypeFloat64};
|
||||
return BpropAddcCommon(ib, "Addcmul", type_list);
|
||||
});
|
||||
REG_BPROP_BUILDER("Addcmul").SetUnusedInputs({i4}).SetBody(BODYFUNC(ib) { return BpropAddcCommon(ib, "Addcmul"); });
|
||||
|
||||
REG_BPROP_BUILDER("LpNorm").SetBody(BODYFUNC(ib) {
|
||||
auto p = GetValue<int64_t>(ib->GetAttr("p"));
|
||||
|
|
Loading…
Reference in New Issue