forked from mindspore-Ecosystem/mindspore
!46344 fix mindir subcell bug
Merge pull request !46344 from lianliguang/r1.10
This commit is contained in:
commit
46d371da92
|
@ -258,9 +258,13 @@ abstract::AnalysisResult AbstractAnalyze(const ResourcePtr &resource, const Func
|
|||
if (resource->is_load()) {
|
||||
// If the primitive is not defined in front end, keep the inferred value loaded from MindIR.
|
||||
auto primitive = GetCNodePrimitive(node);
|
||||
if (primitive != nullptr && abstract::GetPrimEvaluator(primitive, engine) == nullptr) {
|
||||
MS_LOG(INFO) << "The primitive is not defined in front end. Primitive: " << primitive->ToString();
|
||||
continue;
|
||||
if (primitive != nullptr) {
|
||||
auto is_load = primitive->GetAttr("is_load");
|
||||
if (abstract::GetPrimEvaluator(primitive, engine) == nullptr && is_load != nullptr &&
|
||||
GetValue<bool>(is_load)) {
|
||||
MS_LOG(ERROR) << "The primitive is not defined in front end. Primitive: " << primitive->ToString();
|
||||
continue;
|
||||
}
|
||||
}
|
||||
if (!clear && node->isa<Parameter>()) {
|
||||
continue;
|
||||
|
@ -325,7 +329,7 @@ FuncGraphPtr Renormalize(const ResourcePtr &resource, const FuncGraphPtr &func_g
|
|||
return res;
|
||||
}
|
||||
|
||||
const FuncGraphPtr GetLoadedGraph(const ResourcePtr &resource) {
|
||||
void SetLoadFlag(const ResourcePtr &resource) {
|
||||
MS_EXCEPTION_IF_NULL(resource);
|
||||
auto manager = resource->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
|
@ -338,65 +342,7 @@ const FuncGraphPtr GetLoadedGraph(const ResourcePtr &resource) {
|
|||
loaded_graph = graph;
|
||||
loaded_graph_num += 1;
|
||||
resource->set_is_load(true);
|
||||
}
|
||||
}
|
||||
if (loaded_graph_num == 0) {
|
||||
return nullptr;
|
||||
}
|
||||
if (loaded_graph_num == 1) {
|
||||
return loaded_graph;
|
||||
}
|
||||
MS_LOG(EXCEPTION) << "The loaded sub graph currently should be less than 2, but got " << loaded_graph_num;
|
||||
}
|
||||
|
||||
void CheckRootInputShapeAndType(const ResourcePtr &resource, const FuncGraphPtr &loaded_graph) {
|
||||
MS_EXCEPTION_IF_NULL(resource);
|
||||
auto manager = resource->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
FuncGraphPtr root_graph = *(manager->roots().begin());
|
||||
auto root_inputs = root_graph->get_inputs();
|
||||
auto loaded_inputs = loaded_graph->get_inputs();
|
||||
MS_LOG(DEBUG) << "root_graph: " << root_graph->ToString();
|
||||
MS_LOG(DEBUG) << "loaded_graph: " << loaded_graph->ToString();
|
||||
size_t root_inputs_num = root_inputs.size();
|
||||
size_t loaded_inputs_num = loaded_inputs.size();
|
||||
if (root_inputs_num != loaded_inputs_num) {
|
||||
MS_LOG(EXCEPTION) << "The inputs number " << root_inputs_num << " not equal to the inputs number of loaded graph "
|
||||
<< loaded_inputs_num;
|
||||
}
|
||||
|
||||
for (size_t index = 0; index < root_inputs_num; index++) {
|
||||
auto root_input = root_inputs[index];
|
||||
auto loaded_input = loaded_inputs[index];
|
||||
|
||||
MS_LOG(DEBUG) << "root_input[" << index << "]: " << root_input->DebugString(1);
|
||||
MS_LOG(DEBUG) << "loaded_input[" << index << "]: " << loaded_input->DebugString(1);
|
||||
MS_LOG(DEBUG) << "root_input abstract[" << index
|
||||
<< "]: " << (root_input->abstract() ? root_input->abstract()->ToString() : "NULL");
|
||||
MS_LOG(DEBUG) << "loaded_input abstract [" << index
|
||||
<< "]: " << (loaded_input->abstract() ? loaded_input->abstract()->ToString() : "NULL");
|
||||
|
||||
auto root_shape = root_input->Shape() == nullptr ? nullptr : dyn_cast<abstract::Shape>(root_input->Shape());
|
||||
auto loaded_shape = loaded_input->Shape() == nullptr ? nullptr : dyn_cast<abstract::Shape>(loaded_input->Shape());
|
||||
auto root_type = root_input->Type() == nullptr ? nullptr : dyn_cast<Type>(root_input->Type());
|
||||
auto loaded_type = loaded_input->Type() == nullptr ? nullptr : dyn_cast<Type>(loaded_input->Type());
|
||||
|
||||
MS_EXCEPTION_IF_NULL(root_shape);
|
||||
MS_EXCEPTION_IF_NULL(loaded_shape);
|
||||
MS_EXCEPTION_IF_NULL(root_type);
|
||||
MS_EXCEPTION_IF_NULL(loaded_type);
|
||||
|
||||
auto shapeEqu = (root_shape->shape() == loaded_shape->shape()) ||
|
||||
(root_shape->shape().size() <= 1 && loaded_shape->shape().size() <= 1);
|
||||
if (!shapeEqu) {
|
||||
MS_EXCEPTION(ValueError) << "The " << index
|
||||
<< " th input shape differ from loaded graph. Input shape: " << root_shape->ToString()
|
||||
<< ", input shape of loaded graph: " << loaded_shape->ToString();
|
||||
}
|
||||
if (root_type->type_id() != loaded_type->type_id()) {
|
||||
MS_EXCEPTION(TypeError) << "The " << std::to_string(index)
|
||||
<< " th input type differ from loaded graph. Input type: " << root_type->ToString()
|
||||
<< ", input type of loaded graph: " << loaded_type->ToString();
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -784,8 +730,7 @@ bool AbstractSpecializeAction(const ResourcePtr &resource) {
|
|||
MS_LOG(EXCEPTION) << "AbstractSpecialize error";
|
||||
}
|
||||
|
||||
// Get original loaded graph to check inputs later
|
||||
auto loaded_graph_ptr = GetLoadedGraph(resource);
|
||||
SetLoadFlag(resource);
|
||||
|
||||
// Abstract analyze
|
||||
auto engine = resource->engine();
|
||||
|
@ -810,10 +755,6 @@ bool AbstractSpecializeAction(const ResourcePtr &resource) {
|
|||
}
|
||||
}
|
||||
}
|
||||
// Check input after abstract when there is a loaded graph
|
||||
if (loaded_graph_ptr != nullptr) {
|
||||
CheckRootInputShapeAndType(resource, loaded_graph_ptr);
|
||||
}
|
||||
|
||||
UpdateFuncGraphParameter(new_fg, resource->arguments());
|
||||
MS_LOG(DEBUG) << "End graph: " << new_fg->ToString() << ", return: " << new_fg->get_return()->DebugString(true);
|
||||
|
|
Loading…
Reference in New Issue