!47924 fix bias add grad error

Merge pull request !47924 from lianliguang/bprop-mindir
This commit is contained in:
i-robot 2023-01-30 01:22:57 +00:00 committed by Gitee
commit 2b05fd48cb
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed files with 49 additions and 13 deletions

View File

@ -1622,6 +1622,29 @@ EvalResultPtr GetEvaluatedValueForFuncGraphAttrOrMethod(const AbstractBasePtrLis
return nullptr;
}
EvalResultPtr GetEvaluatedValueForPrimitiveAttr(const AbstractBasePtrList &args_abs_list,
const AbstractFunctionPtr &data_args,
const AnfNodeConfigPtr &out_conf) {
MS_EXCEPTION_IF_NULL(data_args);
if (!data_args->isa<PrimitiveAbstractClosure>()) {
return nullptr;
}
auto prim_abs = dyn_cast_ptr<PrimitiveAbstractClosure>(data_args);
const auto &prim = prim_abs->prim();
MS_EXCEPTION_IF_NULL(prim);
constexpr auto item_index = 1;
auto item_arg = args_abs_list.at(item_index);
MS_EXCEPTION_IF_NULL(item_arg);
auto attr_name = GetValue<string>(item_arg->BuildValue());
auto value = prim->GetAttr(attr_name);
if (value == nullptr) {
MS_LOG(INFO) << "The Primitive :" << prim->ToString() << "has not attr " << attr_name;
MS_LOG(INFO) << "PrimAttr :" << prim->GetAttrsText();
return nullptr;
}
return std::make_shared<EvalResult>(value->ToAbstract(), nullptr);
}
EvalResultPtr GetEvaluatedValueForAdapterTensorAttrOrMethod(const AnalysisEnginePtr &engine,
const AbstractBasePtr &data_args,
const AbstractBasePtr &item_args,
@ -1769,6 +1792,27 @@ ValuePtr GetMsClassObject(const AbstractBasePtr &abs) {
return nullptr;
}
EvalResultPtr GetFuncAbstractAttr(const AbstractFunctionPtr &data_args, const AbstractBasePtrList &args_abs_list,
const AnfNodeConfigPtr &out_conf) {
if (data_args == nullptr) {
return nullptr;
}
// Get attribute or method of PartialAbstractClosure, the object is class object decorated with 'jit_class'.
auto class_value = GetMsClassObject(data_args);
if (class_value != nullptr) {
return GetEvaluatedValueForMsClassAttrOrMethod(args_abs_list, class_value, out_conf);
}
// Get attribute or method of FuncGraphAbstractClosure, the object could be Cell/ms_class object.
auto data_func_graph = dyn_cast_ptr<FuncGraphAbstractClosure>(data_args);
if (data_func_graph != nullptr) {
auto res = GetEvaluatedValueForFuncGraphAttrOrMethod(args_abs_list, data_func_graph->func_graph(), out_conf);
if (res != nullptr) {
return res;
}
}
return GetEvaluatedValueForPrimitiveAttr(args_abs_list, data_args, out_conf);
}
EvalResultPtr StaticGetter(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_abs_list,
const ConfigPtr &data_conf, const AnfNodeConfigPtr &out_conf) {
// Inputs: namespace and its static function; or class and its member function
@ -1821,21 +1865,13 @@ EvalResultPtr StaticGetter(const AnalysisEnginePtr &engine, const AbstractBasePt
}
}
// Get attribute or method of PartialAbstractClosure, the object is class object decorated with 'jit_class'.
auto class_value = GetMsClassObject(data_args);
if (class_value != nullptr) {
return GetEvaluatedValueForMsClassAttrOrMethod(args_abs_list, class_value, out_conf);
}
// Get attribute or method of FuncGraphAbstractClosure, the object could be Cell/ms_class object.
auto data_func_graph = dyn_cast_ptr<FuncGraphAbstractClosure>(data_args);
if (data_func_graph != nullptr) {
auto res = GetEvaluatedValueForFuncGraphAttrOrMethod(args_abs_list, data_func_graph->func_graph(), out_conf);
if (res != nullptr) {
return res;
}
auto res = GetFuncAbstractAttr(data_args->cast<AbstractFunctionPtr>(), args_abs_list, out_conf);
if (res != nullptr) {
return res;
}
// Get attribute or method of AdapterTensor object.
auto res = GetEvaluatedValueForAdapterTensorAttrOrMethod(engine, data_args, item_args, data_conf, out_conf);
res = GetEvaluatedValueForAdapterTensorAttrOrMethod(engine, data_args, item_args, data_conf, out_conf);
if (res != nullptr) {
return res;
}