forked from OSSInnovation/mindspore
insert u monad parameter before io monad parameter in auto_monad
This commit is contained in:
parent
0d1d043d80
commit
86b52c4107
|
@ -41,8 +41,11 @@ using RefInputs = OrderedMap<AnfNodePtr, std::vector<size_t>>;
|
||||||
// Add or get a monad parameter.
|
// Add or get a monad parameter.
|
||||||
AnfNodePtr AddMonadParameter(const FuncGraphPtr &func_graph, const std::string &name,
|
AnfNodePtr AddMonadParameter(const FuncGraphPtr &func_graph, const std::string &name,
|
||||||
const abstract::AbstractBasePtr &abs) {
|
const abstract::AbstractBasePtr &abs) {
|
||||||
|
size_t params_size = func_graph->parameters().size();
|
||||||
|
size_t io_monad_location = params_size;
|
||||||
// Search for existed parameters, return it if found.
|
// Search for existed parameters, return it if found.
|
||||||
for (auto &node : func_graph->parameters()) {
|
for (size_t i = 0; i < params_size; i++) {
|
||||||
|
auto &node = func_graph->parameters()[i];
|
||||||
auto para = dyn_cast<Parameter>(node);
|
auto para = dyn_cast<Parameter>(node);
|
||||||
if (para == nullptr) {
|
if (para == nullptr) {
|
||||||
continue;
|
continue;
|
||||||
|
@ -51,13 +54,23 @@ AnfNodePtr AddMonadParameter(const FuncGraphPtr &func_graph, const std::string &
|
||||||
if (para_abs && *para_abs == *abs) {
|
if (para_abs && *para_abs == *abs) {
|
||||||
return para;
|
return para;
|
||||||
}
|
}
|
||||||
|
if (HasAbstractIOMonad(para)) {
|
||||||
|
io_monad_location = i;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
// Create a new parameter if not existed.
|
// Create a new parameter if not existed.
|
||||||
auto para = std::make_shared<Parameter>(func_graph);
|
auto para = std::make_shared<Parameter>(func_graph);
|
||||||
para->set_name(name);
|
para->set_name(name);
|
||||||
para->debug_info()->set_name(name);
|
para->debug_info()->set_name(name);
|
||||||
para->set_abstract(abs);
|
para->set_abstract(abs);
|
||||||
|
// If io monad parameter added before u monad parameter, should insert u monad before io monad in parameters
|
||||||
|
if (io_monad_location != params_size && abs->isa<abstract::AbstractUMonad>()) {
|
||||||
|
std::vector<AnfNodePtr> params = func_graph->parameters();
|
||||||
|
params.insert(params.begin() + io_monad_location, para);
|
||||||
|
func_graph->set_parameters(params);
|
||||||
|
} else {
|
||||||
func_graph->add_parameter(para);
|
func_graph->add_parameter(para);
|
||||||
|
}
|
||||||
return para;
|
return para;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue