!40418 Handling bug in CombineLikeGraphs.

Merge pull request !40418 from Margaret_wangrui/batchnorm_combine_like_graphs
This commit is contained in:
i-robot 2022-08-16 09:38:20 +00:00 committed by Gitee
commit 69202935e2
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
3 changed files with 8 additions and 7 deletions

View File

@ -445,9 +445,9 @@ bool ParseAction(const ResourcePtr &resource) {
bool CombineLikeGraphs(const ResourcePtr &resource) {
MS_EXCEPTION_IF_NULL(resource);
auto &obj_map = parse::data_converter::GetObjGraphs();
for (auto it : obj_map) {
auto &graphs = it.second;
MS_LOG(DEBUG) << "Start combine like graph:" << it.first << ", size:" << graphs.size();
for (auto it = obj_map.rbegin(); it != obj_map.rend(); ++it) {
auto &graphs = it->second;
MS_LOG(DEBUG) << "Start combine like graph:" << it->first << ", size:" << graphs.size();
auto fg = graphs[0];
FuncGraphVector func_graphs = {fg};
Cloner cloner(func_graphs, false, false, true, std::make_shared<TraceCopy>(),
@ -496,7 +496,7 @@ bool CombineLikeGraphs(const ResourcePtr &resource) {
const int recursive_level = 4;
MS_LOG(DEBUG) << "Combine graph newout:" << out->DebugString(recursive_level);
}
MS_LOG(DEBUG) << "End combine graph:" << it.first;
MS_LOG(DEBUG) << "End combine graph:" << it->first;
}
return true;
}

View File

@ -611,14 +611,14 @@ FuncGraphPtr ConvertToFuncGraph(const py::object &obj, const std::string &python
namespace data_converter {
static mindspore::HashMap<std::string, ValuePtr> object_map_;
static mindspore::HashMap<std::string, std::vector<FuncGraphPtr>> object_graphs_map_;
static mindspore::OrderedMap<std::string, std::vector<FuncGraphPtr>> object_graphs_map_;
void SetObjGraphValue(const std::string &obj_key, const FuncGraphPtr &data) {
object_graphs_map_[obj_key].push_back(data);
MS_LOG(DEBUG) << "Set func graph size: " << object_graphs_map_.size();
}
const mindspore::HashMap<std::string, std::vector<FuncGraphPtr>> &GetObjGraphs() {
const mindspore::OrderedMap<std::string, std::vector<FuncGraphPtr>> &GetObjGraphs() {
MS_LOG(DEBUG) << "Obj graphs size: " << object_graphs_map_.size();
return object_graphs_map_;
}

View File

@ -23,6 +23,7 @@
#include <memory>
#include <vector>
#include <string>
#include "utils/ordered_map.h"
#include "utils/hash_map.h"
#include "pipeline/jit/parse/parse_base.h"
#include "include/common/utils/python_adapter.h"
@ -37,7 +38,7 @@ bool GetObjectValue(const std::string &obj_key, ValuePtr *const data);
void SetObjGraphValue(const std::string &obj_key, const FuncGraphPtr &data);
const mindspore::HashMap<std::string, std::vector<FuncGraphPtr>> &GetObjGraphs();
const mindspore::OrderedMap<std::string, std::vector<FuncGraphPtr>> &GetObjGraphs();
std::vector<std::string> GetObjKey(const py::object &obj);
ResolveTypeDef GetObjType(const py::object &obj);