forked from mindspore-Ecosystem/mindspore
!47924 fix bias add grad error
Merge pull request !47924 from lianliguang/bprop-mindir
This commit is contained in:
commit
2b05fd48cb
|
@ -1622,6 +1622,29 @@ EvalResultPtr GetEvaluatedValueForFuncGraphAttrOrMethod(const AbstractBasePtrLis
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
EvalResultPtr GetEvaluatedValueForPrimitiveAttr(const AbstractBasePtrList &args_abs_list,
|
||||
const AbstractFunctionPtr &data_args,
|
||||
const AnfNodeConfigPtr &out_conf) {
|
||||
MS_EXCEPTION_IF_NULL(data_args);
|
||||
if (!data_args->isa<PrimitiveAbstractClosure>()) {
|
||||
return nullptr;
|
||||
}
|
||||
auto prim_abs = dyn_cast_ptr<PrimitiveAbstractClosure>(data_args);
|
||||
const auto &prim = prim_abs->prim();
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
constexpr auto item_index = 1;
|
||||
auto item_arg = args_abs_list.at(item_index);
|
||||
MS_EXCEPTION_IF_NULL(item_arg);
|
||||
auto attr_name = GetValue<string>(item_arg->BuildValue());
|
||||
auto value = prim->GetAttr(attr_name);
|
||||
if (value == nullptr) {
|
||||
MS_LOG(INFO) << "The Primitive :" << prim->ToString() << "has not attr " << attr_name;
|
||||
MS_LOG(INFO) << "PrimAttr :" << prim->GetAttrsText();
|
||||
return nullptr;
|
||||
}
|
||||
return std::make_shared<EvalResult>(value->ToAbstract(), nullptr);
|
||||
}
|
||||
|
||||
EvalResultPtr GetEvaluatedValueForAdapterTensorAttrOrMethod(const AnalysisEnginePtr &engine,
|
||||
const AbstractBasePtr &data_args,
|
||||
const AbstractBasePtr &item_args,
|
||||
|
@ -1769,6 +1792,27 @@ ValuePtr GetMsClassObject(const AbstractBasePtr &abs) {
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
EvalResultPtr GetFuncAbstractAttr(const AbstractFunctionPtr &data_args, const AbstractBasePtrList &args_abs_list,
|
||||
const AnfNodeConfigPtr &out_conf) {
|
||||
if (data_args == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
// Get attribute or method of PartialAbstractClosure, the object is class object decorated with 'jit_class'.
|
||||
auto class_value = GetMsClassObject(data_args);
|
||||
if (class_value != nullptr) {
|
||||
return GetEvaluatedValueForMsClassAttrOrMethod(args_abs_list, class_value, out_conf);
|
||||
}
|
||||
// Get attribute or method of FuncGraphAbstractClosure, the object could be Cell/ms_class object.
|
||||
auto data_func_graph = dyn_cast_ptr<FuncGraphAbstractClosure>(data_args);
|
||||
if (data_func_graph != nullptr) {
|
||||
auto res = GetEvaluatedValueForFuncGraphAttrOrMethod(args_abs_list, data_func_graph->func_graph(), out_conf);
|
||||
if (res != nullptr) {
|
||||
return res;
|
||||
}
|
||||
}
|
||||
return GetEvaluatedValueForPrimitiveAttr(args_abs_list, data_args, out_conf);
|
||||
}
|
||||
|
||||
EvalResultPtr StaticGetter(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_abs_list,
|
||||
const ConfigPtr &data_conf, const AnfNodeConfigPtr &out_conf) {
|
||||
// Inputs: namespace and its static function; or class and its member function
|
||||
|
@ -1821,21 +1865,13 @@ EvalResultPtr StaticGetter(const AnalysisEnginePtr &engine, const AbstractBasePt
|
|||
}
|
||||
}
|
||||
|
||||
// Get attribute or method of PartialAbstractClosure, the object is class object decorated with 'jit_class'.
|
||||
auto class_value = GetMsClassObject(data_args);
|
||||
if (class_value != nullptr) {
|
||||
return GetEvaluatedValueForMsClassAttrOrMethod(args_abs_list, class_value, out_conf);
|
||||
}
|
||||
// Get attribute or method of FuncGraphAbstractClosure, the object could be Cell/ms_class object.
|
||||
auto data_func_graph = dyn_cast_ptr<FuncGraphAbstractClosure>(data_args);
|
||||
if (data_func_graph != nullptr) {
|
||||
auto res = GetEvaluatedValueForFuncGraphAttrOrMethod(args_abs_list, data_func_graph->func_graph(), out_conf);
|
||||
if (res != nullptr) {
|
||||
return res;
|
||||
}
|
||||
auto res = GetFuncAbstractAttr(data_args->cast<AbstractFunctionPtr>(), args_abs_list, out_conf);
|
||||
if (res != nullptr) {
|
||||
return res;
|
||||
}
|
||||
|
||||
// Get attribute or method of AdapterTensor object.
|
||||
auto res = GetEvaluatedValueForAdapterTensorAttrOrMethod(engine, data_args, item_args, data_conf, out_conf);
|
||||
res = GetEvaluatedValueForAdapterTensorAttrOrMethod(engine, data_args, item_args, data_conf, out_conf);
|
||||
if (res != nullptr) {
|
||||
return res;
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue