forked from mindspore-Ecosystem/mindspore
Add parameter-eliminate pass
This commit is contained in:
parent
643a25e03b
commit
4f50b3dfe1
|
@ -86,7 +86,7 @@ bool ExpandJPrim::operator()(const FuncGraphPtr &func_graph, const OptimizerPtr
|
|||
void ExpandJPrim::GetJPrim(const FuncGraphManagerPtr &manager) {
|
||||
j_nodes_.clear();
|
||||
for (auto &fg : manager->func_graphs()) {
|
||||
std::vector<AnfNodePtr> &&toposet = TopoSort(fg->get_return());
|
||||
std::vector<AnfNodePtr> toposet = TopoSort(fg->get_return());
|
||||
for (const auto &node : toposet) {
|
||||
if (IsPrimitiveCNode(node, prim::kPrimJ)) {
|
||||
j_nodes_.push_back(node->cast<CNodePtr>());
|
||||
|
|
|
@ -0,0 +1,133 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_PARAMETER_ELIMINATE_H
|
||||
#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_PARAMETER_ELIMINATE_H
|
||||
#include <vector>
|
||||
#include <utility>
|
||||
#include <unordered_set>
|
||||
#include <memory>
|
||||
|
||||
#include "frontend/optimizer/irpass.h"
|
||||
#include "frontend/optimizer/optimizer.h"
|
||||
#include "frontend/optimizer/anf_visitor.h"
|
||||
#include "ir/manager.h"
|
||||
#include "ir/func_graph.h"
|
||||
#include "frontend/operator/ops.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace irpass {
|
||||
|
||||
class ParameterEliminator {
|
||||
public:
|
||||
ParameterEliminator() = default;
|
||||
virtual ~ParameterEliminator() = default;
|
||||
bool operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer) {
|
||||
const auto &func_graph_callers = SearchFuncGraphCallers(func_graph);
|
||||
const auto &manager = func_graph->manager();
|
||||
auto tr = manager->Transact();
|
||||
bool change = false;
|
||||
for (const auto &fg_and_caller : func_graph_callers) {
|
||||
const auto &fg = fg_and_caller.first;
|
||||
const auto &erase_indexes = EraseUnusedParameters(fg, &tr);
|
||||
// If no parameter unused, do nothing.
|
||||
if (erase_indexes.empty()) {
|
||||
continue;
|
||||
}
|
||||
// Erase the corresponding args.
|
||||
change = true;
|
||||
for (const auto &caller : fg_and_caller.second) {
|
||||
EraseArgs(caller, erase_indexes, &tr);
|
||||
}
|
||||
}
|
||||
tr.Commit();
|
||||
return change;
|
||||
}
|
||||
|
||||
private:
|
||||
static OrderedMap<FuncGraphPtr, std::vector<CNodePtr>> SearchFuncGraphCallers(const FuncGraphPtr &func_graph) {
|
||||
OrderedMap<FuncGraphPtr, std::vector<CNodePtr>> func_graph_callers;
|
||||
for (const auto &fg : func_graph->func_graphs_used_total()) {
|
||||
if (fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE)) {
|
||||
continue;
|
||||
}
|
||||
const auto &fg_caller_and_indexes = fg->func_graph_cnodes_index();
|
||||
std::vector<CNodePtr> caller_cnodes = {};
|
||||
// Find all caller of fg.
|
||||
for (const auto &it : fg_caller_and_indexes) {
|
||||
const auto &fg_caller_and_index = it.first;
|
||||
auto caller_cnode = fg_caller_and_index->first;
|
||||
auto index = fg_caller_and_index->second;
|
||||
// If index != 0, the caller is a indirect caller, can't erase the parameter of graph.
|
||||
if (index != 0) {
|
||||
caller_cnodes.clear();
|
||||
break;
|
||||
}
|
||||
caller_cnodes.push_back(caller_cnode->cast<CNodePtr>());
|
||||
}
|
||||
if (!caller_cnodes.empty()) {
|
||||
func_graph_callers[fg] = caller_cnodes;
|
||||
}
|
||||
}
|
||||
return func_graph_callers;
|
||||
}
|
||||
|
||||
static std::unordered_set<size_t> EraseUnusedParameters(const FuncGraphPtr &fg, FuncGraphTransaction *tr) {
|
||||
const auto &manager_node_users = fg->manager()->node_users();
|
||||
const auto ¶meters = fg->parameters();
|
||||
std::unordered_set<size_t> unused_parameter_indexes;
|
||||
// Traverse to find all unused parameters.
|
||||
size_t index = 0;
|
||||
for (const auto ¶meter : parameters) {
|
||||
const auto &node_users_it = manager_node_users.find(parameter);
|
||||
if (node_users_it == manager_node_users.end() || node_users_it->second.empty()) {
|
||||
unused_parameter_indexes.insert(index);
|
||||
}
|
||||
index++;
|
||||
}
|
||||
// Erase unused parameters.
|
||||
std::vector<AnfNodePtr> new_parameters;
|
||||
for (size_t i = 0; i < parameters.size(); i++) {
|
||||
if (unused_parameter_indexes.find(i) == unused_parameter_indexes.end()) {
|
||||
new_parameters.push_back(parameters[i]);
|
||||
} else {
|
||||
MS_LOG(DEBUG) << "Erase parameter:" << parameters[i]->DebugString() << ",index:" << i;
|
||||
}
|
||||
}
|
||||
tr->SetParameters(fg, new_parameters);
|
||||
return unused_parameter_indexes;
|
||||
}
|
||||
|
||||
static void EraseArgs(const CNodePtr &caller, const std::unordered_set<size_t> &unused_parameter_indexes,
|
||||
FuncGraphTransaction *tr) {
|
||||
std::vector<AnfNodePtr> new_args = {caller->inputs()[0]};
|
||||
for (size_t i = 0; i < caller->inputs().size() - 1; i++) {
|
||||
if (unused_parameter_indexes.find(i) == unused_parameter_indexes.end()) {
|
||||
new_args.push_back(caller->inputs()[i + 1]);
|
||||
} else {
|
||||
MS_LOG(DEBUG) << "Erase arg:" << caller->inputs()[i + 1]->DebugString() << ",index:" << i;
|
||||
}
|
||||
}
|
||||
TraceGuard trace_guard(std::make_shared<TraceCopy>(caller->debug_info()));
|
||||
auto new_caller = caller->func_graph()->NewCNode(new_args);
|
||||
tr->Replace(caller, new_caller);
|
||||
}
|
||||
};
|
||||
} // namespace irpass
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_PARAMETER_ELIMINATE_H
|
|
@ -41,45 +41,29 @@ class SpecializeTransform {
|
|||
SpecializeTransform() : cache_() {}
|
||||
~SpecializeTransform() = default;
|
||||
|
||||
FuncGraphPtr operator()(const FuncGraphPtr &func_graph, std::vector<FuncGraphPtr> graph_args,
|
||||
std::vector<PrimitivePtr> prim_args, std::vector<tensor::TensorPtr> tensor_value_args) {
|
||||
FuncGraphPtr operator()(const FuncGraphPtr &func_graph, const std::vector<ValuePtr> &need_eliminate_args) {
|
||||
if (cache_.count(func_graph) == 0) {
|
||||
cache_[func_graph] = {};
|
||||
}
|
||||
|
||||
auto &cache = cache_[func_graph];
|
||||
auto key = std::make_tuple(graph_args, prim_args, tensor_value_args);
|
||||
const auto &key = need_eliminate_args;
|
||||
if (cache.count(key) == 0) {
|
||||
auto mng = func_graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(mng);
|
||||
|
||||
FuncGraphPtr new_fg = TransformableClone(func_graph, std::make_shared<TraceTransform>("sp"));
|
||||
mng->AddFuncGraph(new_fg);
|
||||
|
||||
std::vector<AnfNodePtr> params = new_fg->parameters();
|
||||
std::vector<AnfNodePtr> new_params;
|
||||
size_t n = graph_args.size();
|
||||
size_t n = need_eliminate_args.size();
|
||||
for (size_t i = 0; i < n; i++) {
|
||||
if (graph_args[i] != nullptr) {
|
||||
auto arg = NewValueNode(graph_args[i]);
|
||||
(void)mng->Replace(params[i], arg);
|
||||
// keep the parameter
|
||||
if (need_eliminate_args[i] == nullptr) {
|
||||
new_params.push_back(params[i]);
|
||||
continue;
|
||||
}
|
||||
if (prim_args[i] != nullptr) {
|
||||
auto arg = NewValueNode(prim_args[i]);
|
||||
(void)mng->Replace(params[i], arg);
|
||||
continue;
|
||||
}
|
||||
if (tensor_value_args[i] != nullptr) {
|
||||
auto &const_tensor = *tensor_value_args[i];
|
||||
auto const_tensor_ptr = std::make_shared<tensor::Tensor>(const_tensor);
|
||||
AnfNodePtr arg = NewValueNode(const_tensor_ptr);
|
||||
(void)mng->Replace(params[i], arg);
|
||||
continue;
|
||||
}
|
||||
new_params.push_back(params[i]);
|
||||
// replace the parameter with arg.
|
||||
mng->Replace(params[i], NewReplaceValueNode(need_eliminate_args[i]));
|
||||
}
|
||||
|
||||
mng->SetParameters(new_fg, new_params);
|
||||
cache[key] = new_fg;
|
||||
}
|
||||
|
@ -87,11 +71,19 @@ class SpecializeTransform {
|
|||
}
|
||||
|
||||
private:
|
||||
std::unordered_map<
|
||||
FuncGraphPtr,
|
||||
std::map<std::tuple<std::vector<FuncGraphPtr>, std::vector<PrimitivePtr>, std::vector<tensor::TensorPtr>>,
|
||||
FuncGraphPtr>>
|
||||
cache_;
|
||||
std::unordered_map<FuncGraphPtr, std::map<std::vector<ValuePtr>, FuncGraphPtr>> cache_;
|
||||
static ValueNodePtr NewReplaceValueNode(const ValuePtr &value) {
|
||||
MS_EXCEPTION_IF_NULL(value);
|
||||
if (value->isa<FuncGraph>() || value->isa<Primitive>() || value->isa<parse::NameSpace>()) {
|
||||
return NewValueNode(value);
|
||||
}
|
||||
if (value->isa<tensor::Tensor>()) {
|
||||
auto &const_tensor = *(value->cast<tensor::TensorPtr>());
|
||||
auto const_tensor_ptr = std::make_shared<tensor::Tensor>(const_tensor);
|
||||
return NewValueNode(const_tensor_ptr);
|
||||
}
|
||||
MS_LOG(EXCEPTION) << "Unexpected value:" << value->ToString();
|
||||
}
|
||||
};
|
||||
} // namespace internal
|
||||
|
||||
|
@ -115,44 +107,23 @@ class SpecializeOnGraphArguments : public AnfVisitor {
|
|||
if (inp0_fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE) || inp0_fg->recursive()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
std::vector<FuncGraphPtr> graph_args;
|
||||
std::vector<PrimitivePtr> prim_args;
|
||||
std::vector<tensor::TensorPtr> tensor_value_args;
|
||||
std::vector<ValuePtr> need_eliminated_args;
|
||||
std::vector<AnfNodePtr> new_xs;
|
||||
bool hasVNode = false;
|
||||
for (size_t i = 1; i < inputs.size(); i++) {
|
||||
if (IsValueNode<FuncGraph>(inputs[i])) {
|
||||
auto fg_vnode = GetValueNode<FuncGraphPtr>(inputs[i]);
|
||||
graph_args.push_back(fg_vnode);
|
||||
prim_args.emplace_back(nullptr);
|
||||
tensor_value_args.emplace_back(nullptr);
|
||||
hasVNode = true;
|
||||
} else if (IsValueNode<Primitive>(inputs[i])) {
|
||||
auto p_vnode = GetValueNode<PrimitivePtr>(inputs[i]);
|
||||
graph_args.emplace_back(nullptr);
|
||||
prim_args.push_back(p_vnode);
|
||||
tensor_value_args.emplace_back(nullptr);
|
||||
hasVNode = true;
|
||||
} else if (IsValueNode<tensor::Tensor>(inputs[i])) {
|
||||
tensor::TensorPtr t_vnode = GetValueNode<tensor::TensorPtr>(inputs[i]);
|
||||
graph_args.emplace_back(nullptr);
|
||||
prim_args.emplace_back(nullptr);
|
||||
tensor_value_args.emplace_back(t_vnode);
|
||||
if (IsValueNode<FuncGraph>(inputs[i]) || IsValueNode<Primitive>(inputs[i]) ||
|
||||
IsValueNode<tensor::Tensor>(inputs[i]) || IsValueNode<parse::NameSpace>(inputs[i])) {
|
||||
need_eliminated_args.push_back(GetValueNode(inputs[i]));
|
||||
hasVNode = true;
|
||||
} else {
|
||||
graph_args.emplace_back(nullptr);
|
||||
prim_args.emplace_back(nullptr);
|
||||
tensor_value_args.emplace_back(nullptr);
|
||||
need_eliminated_args.emplace_back(nullptr);
|
||||
new_xs.push_back(inputs[i]);
|
||||
}
|
||||
}
|
||||
|
||||
if (!hasVNode) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto new_fg = specialize_transform_(inp0_fg, graph_args, prim_args, tensor_value_args);
|
||||
auto new_fg = specialize_transform_(inp0_fg, need_eliminated_args);
|
||||
(void)new_xs.insert(new_xs.begin(), NewValueNode(new_fg));
|
||||
|
||||
return node->func_graph()->NewCNode(new_xs);
|
||||
|
|
|
@ -44,6 +44,7 @@
|
|||
#include "pipeline/pynative/pynative_execute.h"
|
||||
#include "pipeline/jit/static_analysis/auto_monad.h"
|
||||
#include "frontend/optimizer/irpass/gradient_eliminate.h"
|
||||
#include "frontend/optimizer/irpass/parameter_eliminate.h"
|
||||
#if (ENABLE_CPU && !_WIN32)
|
||||
#include "ps/util.h"
|
||||
#include "ps/ps_context.h"
|
||||
|
@ -227,8 +228,8 @@ void AddParallelRenormalize(OptPassGroupMap *map_a) {
|
|||
}
|
||||
}
|
||||
|
||||
OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
|
||||
opt::OptPassConfig a_1 = opt::OptPassConfig({
|
||||
opt::OptPassConfig GetOptPassA1(const opt::irpass::OptimizeIRPassLib &irpass) {
|
||||
return opt::OptPassConfig({
|
||||
irpass.switch_defer_inline_,
|
||||
irpass.switch_layer_defer_inline_,
|
||||
irpass.switch_simplify_,
|
||||
|
@ -277,6 +278,10 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
|
|||
irpass.stopgrad_eliminater_,
|
||||
irpass.sparse_tensor_eliminate_,
|
||||
});
|
||||
}
|
||||
|
||||
OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
|
||||
opt::OptPassConfig a_1 = GetOptPassA1(irpass);
|
||||
opt::OptPassConfig a_2 = opt::OptPassConfig(
|
||||
{
|
||||
irpass.specialize_transform_,
|
||||
|
@ -293,6 +298,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
|
|||
irpass.all_reduce_const_elim_,
|
||||
},
|
||||
false, true);
|
||||
|
||||
opt::OptPassConfig a_after_grad = opt::OptPassConfig({
|
||||
irpass.inline_without_move_,
|
||||
});
|
||||
|
@ -321,6 +327,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
|
|||
|
||||
// Before adjusting map_a, check GetA1A2() and GetOptPynativeGradEpiloguePhases().
|
||||
OptPassGroupMap map_a({{"a_1", a_1},
|
||||
{"parameter_eliminate", opt::OptPassConfig(opt::irpass::ParameterEliminator())},
|
||||
{"a_2", a_2},
|
||||
{"accelerated_algorithm", accelerated_algorithm},
|
||||
{"auto_parallel", opt::OptPassConfig(parallel::StepAutoParallel)},
|
||||
|
@ -342,7 +349,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
|
|||
|
||||
// Returns the first three pass groups of the "A" pipeline
// (a_1, parameter_eliminate, a_2) as a standalone group map.
OptPassGroupMap GetA1A2(const opt::irpass::OptimizeIRPassLib &irpass) {
  const auto &full_map_a = GetOptPassesA(irpass);
  return OptPassGroupMap({full_map_a[0], full_map_a[1], full_map_a[2]});
}
|
||||
|
||||
|
|
|
@ -252,7 +252,9 @@ AbstractBasePtrList FuncGraphEvaluator::NormalizeArgs(const AbstractBasePtrList
|
|||
(void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(broaded_list),
|
||||
[](const AbstractBasePtr &arg) -> AbstractBasePtr {
|
||||
MS_EXCEPTION_IF_NULL(arg);
|
||||
if (arg->isa<AbstractScalar>()) {
|
||||
// Only broaden scalar that data type is number, such as float16,int32 and so on.
|
||||
auto type = arg->BuildType()->type_id();
|
||||
if (arg->isa<AbstractScalar>() && type > kNumberTypeBegin && type < kNumberTypeEnd) {
|
||||
auto config = abstract::AbstractBase::kBroadenScalarParameterOnly;
|
||||
return arg->Broaden(config);
|
||||
} else if (arg->GetValueTrack() != kAnyValue) {
|
||||
|
|
|
@ -80,8 +80,7 @@ CNodePtr Make_Node(Shape x, Shape y, Shape out, int64_t condition = 0) {
|
|||
break;
|
||||
}
|
||||
case 1: {
|
||||
abstract1->set_shape(nullptr);
|
||||
param1->set_abstract(abstract1);
|
||||
// Don't set abstract of param1, expecting a exception raised.
|
||||
param2->set_abstract(abstract2);
|
||||
break;
|
||||
}
|
||||
|
@ -274,15 +273,6 @@ TEST_F(TestStepParallel, ExtractShape3) {
|
|||
ASSERT_EQ(shape_test, shape_expect);
|
||||
}
|
||||
|
||||
/// Feature: step_parallel shape extraction.
/// Description: Build a node whose abstracts are mis-configured (condition 2)
///              and extract its shape.
/// Expectation: ExtractShape throws std::runtime_error.
TEST_F(TestStepParallel, ExtractShape4) {
  Shape inputs_x_dims = {64, 32};
  Shape inputs_y_dims = {32, 64};
  Shape outputs_dims = {64, 64};
  // Removed the unused `inputs_shape` local that was never read by this test.
  CNodePtr node = Make_Node(inputs_x_dims, inputs_y_dims, outputs_dims, 2);
  EXPECT_THROW({ ExtractShape(node); }, std::runtime_error);
}
|
||||
|
||||
TEST_F(TestStepParallel, CreatOpInstance) {
|
||||
ValuePtr attr0_value = MakeValue(REDUCE_OP_SUM);
|
||||
ValuePtr attr1_value = MakeValue("0-1-2");
|
||||
|
|
Loading…
Reference in New Issue