!46344 fix mindir subcell bug

Merge pull request !46344 from lianliguang/r1.10
This commit is contained in:
i-robot 2022-12-02 09:56:10 +00:00 committed by Gitee
commit 46d371da92
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed files with 10 additions and 69 deletions

View File

@ -258,9 +258,13 @@ abstract::AnalysisResult AbstractAnalyze(const ResourcePtr &resource, const Func
if (resource->is_load()) {
// If the primitive is not defined in front end, keep the inferred value loaded from MindIR.
auto primitive = GetCNodePrimitive(node);
if (primitive != nullptr && abstract::GetPrimEvaluator(primitive, engine) == nullptr) {
MS_LOG(INFO) << "The primitive is not defined in front end. Primitive: " << primitive->ToString();
continue;
if (primitive != nullptr) {
auto is_load = primitive->GetAttr("is_load");
if (abstract::GetPrimEvaluator(primitive, engine) == nullptr && is_load != nullptr &&
GetValue<bool>(is_load)) {
MS_LOG(ERROR) << "The primitive is not defined in front end. Primitive: " << primitive->ToString();
continue;
}
}
if (!clear && node->isa<Parameter>()) {
continue;
@ -325,7 +329,7 @@ FuncGraphPtr Renormalize(const ResourcePtr &resource, const FuncGraphPtr &func_g
return res;
}
const FuncGraphPtr GetLoadedGraph(const ResourcePtr &resource) {
void SetLoadFlag(const ResourcePtr &resource) {
MS_EXCEPTION_IF_NULL(resource);
auto manager = resource->manager();
MS_EXCEPTION_IF_NULL(manager);
@ -338,65 +342,7 @@ const FuncGraphPtr GetLoadedGraph(const ResourcePtr &resource) {
loaded_graph = graph;
loaded_graph_num += 1;
resource->set_is_load(true);
}
}
if (loaded_graph_num == 0) {
return nullptr;
}
if (loaded_graph_num == 1) {
return loaded_graph;
}
MS_LOG(EXCEPTION) << "The loaded sub graph currently should be less than 2, but got " << loaded_graph_num;
}
void CheckRootInputShapeAndType(const ResourcePtr &resource, const FuncGraphPtr &loaded_graph) {
MS_EXCEPTION_IF_NULL(resource);
auto manager = resource->manager();
MS_EXCEPTION_IF_NULL(manager);
FuncGraphPtr root_graph = *(manager->roots().begin());
auto root_inputs = root_graph->get_inputs();
auto loaded_inputs = loaded_graph->get_inputs();
MS_LOG(DEBUG) << "root_graph: " << root_graph->ToString();
MS_LOG(DEBUG) << "loaded_graph: " << loaded_graph->ToString();
size_t root_inputs_num = root_inputs.size();
size_t loaded_inputs_num = loaded_inputs.size();
if (root_inputs_num != loaded_inputs_num) {
MS_LOG(EXCEPTION) << "The inputs number " << root_inputs_num << " not equal to the inputs number of loaded graph "
<< loaded_inputs_num;
}
for (size_t index = 0; index < root_inputs_num; index++) {
auto root_input = root_inputs[index];
auto loaded_input = loaded_inputs[index];
MS_LOG(DEBUG) << "root_input[" << index << "]: " << root_input->DebugString(1);
MS_LOG(DEBUG) << "loaded_input[" << index << "]: " << loaded_input->DebugString(1);
MS_LOG(DEBUG) << "root_input abstract[" << index
<< "]: " << (root_input->abstract() ? root_input->abstract()->ToString() : "NULL");
MS_LOG(DEBUG) << "loaded_input abstract [" << index
<< "]: " << (loaded_input->abstract() ? loaded_input->abstract()->ToString() : "NULL");
auto root_shape = root_input->Shape() == nullptr ? nullptr : dyn_cast<abstract::Shape>(root_input->Shape());
auto loaded_shape = loaded_input->Shape() == nullptr ? nullptr : dyn_cast<abstract::Shape>(loaded_input->Shape());
auto root_type = root_input->Type() == nullptr ? nullptr : dyn_cast<Type>(root_input->Type());
auto loaded_type = loaded_input->Type() == nullptr ? nullptr : dyn_cast<Type>(loaded_input->Type());
MS_EXCEPTION_IF_NULL(root_shape);
MS_EXCEPTION_IF_NULL(loaded_shape);
MS_EXCEPTION_IF_NULL(root_type);
MS_EXCEPTION_IF_NULL(loaded_type);
auto shapeEqu = (root_shape->shape() == loaded_shape->shape()) ||
(root_shape->shape().size() <= 1 && loaded_shape->shape().size() <= 1);
if (!shapeEqu) {
MS_EXCEPTION(ValueError) << "The " << index
<< " th input shape differ from loaded graph. Input shape: " << root_shape->ToString()
<< ", input shape of loaded graph: " << loaded_shape->ToString();
}
if (root_type->type_id() != loaded_type->type_id()) {
MS_EXCEPTION(TypeError) << "The " << std::to_string(index)
<< " th input type differ from loaded graph. Input type: " << root_type->ToString()
<< ", input type of loaded graph: " << loaded_type->ToString();
return;
}
}
}
@ -784,8 +730,7 @@ bool AbstractSpecializeAction(const ResourcePtr &resource) {
MS_LOG(EXCEPTION) << "AbstractSpecialize error";
}
// Get original loaded graph to check inputs later
auto loaded_graph_ptr = GetLoadedGraph(resource);
SetLoadFlag(resource);
// Abstract analyze
auto engine = resource->engine();
@ -810,10 +755,6 @@ bool AbstractSpecializeAction(const ResourcePtr &resource) {
}
}
}
// Check input after abstract when there is a loaded graph
if (loaded_graph_ptr != nullptr) {
CheckRootInputShapeAndType(resource, loaded_graph_ptr);
}
UpdateFuncGraphParameter(new_fg, resource->arguments());
MS_LOG(DEBUG) << "End graph: " << new_fg->ToString() << ", return: " << new_fg->get_return()->DebugString(true);