!40418 Handling bug in CombineLikeGraphs.
Merge pull request !40418 from Margaret_wangrui/batchnorm_combine_like_graphs
This commit is contained in:
commit
69202935e2
|
@ -445,9 +445,9 @@ bool ParseAction(const ResourcePtr &resource) {
|
|||
bool CombineLikeGraphs(const ResourcePtr &resource) {
|
||||
MS_EXCEPTION_IF_NULL(resource);
|
||||
auto &obj_map = parse::data_converter::GetObjGraphs();
|
||||
for (auto it : obj_map) {
|
||||
auto &graphs = it.second;
|
||||
MS_LOG(DEBUG) << "Start combine like graph:" << it.first << ", size:" << graphs.size();
|
||||
for (auto it = obj_map.rbegin(); it != obj_map.rend(); ++it) {
|
||||
auto &graphs = it->second;
|
||||
MS_LOG(DEBUG) << "Start combine like graph:" << it->first << ", size:" << graphs.size();
|
||||
auto fg = graphs[0];
|
||||
FuncGraphVector func_graphs = {fg};
|
||||
Cloner cloner(func_graphs, false, false, true, std::make_shared<TraceCopy>(),
|
||||
|
@ -496,7 +496,7 @@ bool CombineLikeGraphs(const ResourcePtr &resource) {
|
|||
const int recursive_level = 4;
|
||||
MS_LOG(DEBUG) << "Combine graph newout:" << out->DebugString(recursive_level);
|
||||
}
|
||||
MS_LOG(DEBUG) << "End combine graph:" << it.first;
|
||||
MS_LOG(DEBUG) << "End combine graph:" << it->first;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
|
|
@ -611,14 +611,14 @@ FuncGraphPtr ConvertToFuncGraph(const py::object &obj, const std::string &python
|
|||
namespace data_converter {
|
||||
static mindspore::HashMap<std::string, ValuePtr> object_map_;
|
||||
|
||||
static mindspore::HashMap<std::string, std::vector<FuncGraphPtr>> object_graphs_map_;
|
||||
static mindspore::OrderedMap<std::string, std::vector<FuncGraphPtr>> object_graphs_map_;
|
||||
|
||||
void SetObjGraphValue(const std::string &obj_key, const FuncGraphPtr &data) {
|
||||
object_graphs_map_[obj_key].push_back(data);
|
||||
MS_LOG(DEBUG) << "Set func graph size: " << object_graphs_map_.size();
|
||||
}
|
||||
|
||||
const mindspore::HashMap<std::string, std::vector<FuncGraphPtr>> &GetObjGraphs() {
|
||||
const mindspore::OrderedMap<std::string, std::vector<FuncGraphPtr>> &GetObjGraphs() {
|
||||
MS_LOG(DEBUG) << "Obj graphs size: " << object_graphs_map_.size();
|
||||
return object_graphs_map_;
|
||||
}
|
||||
|
|
|
@ -23,6 +23,7 @@
|
|||
#include <memory>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include "utils/ordered_map.h"
|
||||
#include "utils/hash_map.h"
|
||||
#include "pipeline/jit/parse/parse_base.h"
|
||||
#include "include/common/utils/python_adapter.h"
|
||||
|
@ -37,7 +38,7 @@ bool GetObjectValue(const std::string &obj_key, ValuePtr *const data);
|
|||
|
||||
void SetObjGraphValue(const std::string &obj_key, const FuncGraphPtr &data);
|
||||
|
||||
const mindspore::HashMap<std::string, std::vector<FuncGraphPtr>> &GetObjGraphs();
|
||||
const mindspore::OrderedMap<std::string, std::vector<FuncGraphPtr>> &GetObjGraphs();
|
||||
|
||||
std::vector<std::string> GetObjKey(const py::object &obj);
|
||||
ResolveTypeDef GetObjType(const py::object &obj);
|
||||
|
|
Loading…
Reference in New Issue