add parameter eliminate pass

chenfei 2021-06-04 11:52:02 +08:00
parent 643a25e03b
commit 4f50b3dfe1
6 changed files with 175 additions and 72 deletions

View File

@@ -86,7 +86,7 @@ bool ExpandJPrim::operator()(const FuncGraphPtr &func_graph, const OptimizerPtr
void ExpandJPrim::GetJPrim(const FuncGraphManagerPtr &manager) {
j_nodes_.clear();
for (auto &fg : manager->func_graphs()) {
std::vector<AnfNodePtr> &&toposet = TopoSort(fg->get_return());
std::vector<AnfNodePtr> toposet = TopoSort(fg->get_return());
for (const auto &node : toposet) {
if (IsPrimitiveCNode(node, prim::kPrimJ)) {
j_nodes_.push_back(node->cast<CNodePtr>());

View File

@@ -0,0 +1,133 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_PARAMETER_ELIMINATE_H
#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_PARAMETER_ELIMINATE_H
#include <vector>
#include <utility>
#include <unordered_set>
#include <memory>
#include "frontend/optimizer/irpass.h"
#include "frontend/optimizer/optimizer.h"
#include "frontend/optimizer/anf_visitor.h"
#include "ir/manager.h"
#include "ir/func_graph.h"
#include "frontend/operator/ops.h"
namespace mindspore {
namespace opt {
namespace irpass {
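// Erases unused parameters of func graphs and removes the corresponding arguments at every direct call site.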
class ParameterEliminator {
public:
ParameterEliminator() = default;
virtual ~ParameterEliminator() = default;
bool operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer) {
const auto &func_graph_callers = SearchFuncGraphCallers(func_graph);
const auto &manager = func_graph->manager();
auto tr = manager->Transact();
bool change = false;
for (const auto &fg_and_caller : func_graph_callers) {
const auto &fg = fg_and_caller.first;
const auto &erase_indexes = EraseUnusedParameters(fg, &tr);
// If no parameter is unused, do nothing.
if (erase_indexes.empty()) {
continue;
}
// Erase the corresponding args.
change = true;
for (const auto &caller : fg_and_caller.second) {
EraseArgs(caller, erase_indexes, &tr);
}
}
tr.Commit();
return change;
}
private:
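// Collect every func graph used by func_graph together with its direct caller CNodes; graphs that are deferred-inline or have any indirect use are skipped.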
static OrderedMap<FuncGraphPtr, std::vector<CNodePtr>> SearchFuncGraphCallers(const FuncGraphPtr &func_graph) {
OrderedMap<FuncGraphPtr, std::vector<CNodePtr>> func_graph_callers;
for (const auto &fg : func_graph->func_graphs_used_total()) {
if (fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE)) {
continue;
}
const auto &fg_caller_and_indexes = fg->func_graph_cnodes_index();
std::vector<CNodePtr> caller_cnodes = {};
// Find all callers of fg.
for (const auto &it : fg_caller_and_indexes) {
const auto &fg_caller_and_index = it.first;
auto caller_cnode = fg_caller_and_index->first;
auto index = fg_caller_and_index->second;
// If index != 0, the caller is an indirect caller, so the parameters of the graph cannot be erased.
if (index != 0) {
caller_cnodes.clear();
break;
}
caller_cnodes.push_back(caller_cnode->cast<CNodePtr>());
}
if (!caller_cnodes.empty()) {
func_graph_callers[fg] = caller_cnodes;
}
}
return func_graph_callers;
}
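// Drop the parameters of fg that have no users and return the indexes of the dropped parameters.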
static std::unordered_set<size_t> EraseUnusedParameters(const FuncGraphPtr &fg, FuncGraphTransaction *tr) {
const auto &manager_node_users = fg->manager()->node_users();
const auto &parameters = fg->parameters();
std::unordered_set<size_t> unused_parameter_indexes;
// Traverse to find all unused parameters.
size_t index = 0;
for (const auto &parameter : parameters) {
const auto &node_users_it = manager_node_users.find(parameter);
if (node_users_it == manager_node_users.end() || node_users_it->second.empty()) {
unused_parameter_indexes.insert(index);
}
index++;
}
// Erase unused parameters.
std::vector<AnfNodePtr> new_parameters;
for (size_t i = 0; i < parameters.size(); i++) {
if (unused_parameter_indexes.find(i) == unused_parameter_indexes.end()) {
new_parameters.push_back(parameters[i]);
} else {
MS_LOG(DEBUG) << "Erase parameter:" << parameters[i]->DebugString() << ",index:" << i;
}
}
tr->SetParameters(fg, new_parameters);
return unused_parameter_indexes;
}
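// Rebuild the caller CNode without the arguments whose positions were erased from the callee's parameter list.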
static void EraseArgs(const CNodePtr &caller, const std::unordered_set<size_t> &unused_parameter_indexes,
FuncGraphTransaction *tr) {
std::vector<AnfNodePtr> new_args = {caller->inputs()[0]};
for (size_t i = 0; i < caller->inputs().size() - 1; i++) {
if (unused_parameter_indexes.find(i) == unused_parameter_indexes.end()) {
new_args.push_back(caller->inputs()[i + 1]);
} else {
MS_LOG(DEBUG) << "Erase arg:" << caller->inputs()[i + 1]->DebugString() << ",index:" << i;
}
}
TraceGuard trace_guard(std::make_shared<TraceCopy>(caller->debug_info()));
auto new_caller = caller->func_graph()->NewCNode(new_args);
tr->Replace(caller, new_caller);
}
};
} // namespace irpass
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_PARAMETER_ELIMINATE_H
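
The pruning logic above keeps the parameter list and every caller's argument list aligned through one shared set of erased indexes. A minimal standalone sketch of that bookkeeping, in plain C++ with made-up names and no MindSpore APIs (assuming the second parameter is the unused one, and that inputs()[0] of a caller is the callee itself), looks like this:

#include <cstddef>
#include <iostream>
#include <string>
#include <unordered_set>
#include <vector>

int main() {
  // Parameters of a hypothetical func graph; suppose only "y" has no users.
  std::vector<std::string> params = {"x", "y", "z"};
  std::unordered_set<std::size_t> erased = {1};

  // Keep the parameters whose index was not erased (mirrors EraseUnusedParameters).
  std::vector<std::string> new_params;
  for (std::size_t i = 0; i < params.size(); ++i) {
    if (erased.count(i) == 0) {
      new_params.push_back(params[i]);
    }
  }

  // A call site: element 0 is the callee, element i + 1 maps to params[i] (mirrors EraseArgs).
  std::vector<std::string> caller = {"fg", "a", "b", "c"};
  std::vector<std::string> new_caller = {caller[0]};
  for (std::size_t i = 0; i + 1 < caller.size(); ++i) {
    if (erased.count(i) == 0) {
      new_caller.push_back(caller[i + 1]);
    }
  }

  // Prints: params: x z | caller: fg a c
  std::cout << "params:";
  for (const auto &p : new_params) std::cout << ' ' << p;
  std::cout << " | caller:";
  for (const auto &a : new_caller) std::cout << ' ' << a;
  std::cout << std::endl;
  return 0;
}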

View File

@@ -41,45 +41,29 @@ class SpecializeTransform {
SpecializeTransform() : cache_() {}
~SpecializeTransform() = default;
FuncGraphPtr operator()(const FuncGraphPtr &func_graph, std::vector<FuncGraphPtr> graph_args,
std::vector<PrimitivePtr> prim_args, std::vector<tensor::TensorPtr> tensor_value_args) {
FuncGraphPtr operator()(const FuncGraphPtr &func_graph, const std::vector<ValuePtr> &need_eliminate_args) {
if (cache_.count(func_graph) == 0) {
cache_[func_graph] = {};
}
auto &cache = cache_[func_graph];
auto key = std::make_tuple(graph_args, prim_args, tensor_value_args);
const auto &key = need_eliminate_args;
if (cache.count(key) == 0) {
auto mng = func_graph->manager();
MS_EXCEPTION_IF_NULL(mng);
FuncGraphPtr new_fg = TransformableClone(func_graph, std::make_shared<TraceTransform>("sp"));
mng->AddFuncGraph(new_fg);
std::vector<AnfNodePtr> params = new_fg->parameters();
std::vector<AnfNodePtr> new_params;
size_t n = graph_args.size();
size_t n = need_eliminate_args.size();
for (size_t i = 0; i < n; i++) {
if (graph_args[i] != nullptr) {
auto arg = NewValueNode(graph_args[i]);
(void)mng->Replace(params[i], arg);
// Keep the parameter if there is no value to replace it with.
if (need_eliminate_args[i] == nullptr) {
new_params.push_back(params[i]);
continue;
}
if (prim_args[i] != nullptr) {
auto arg = NewValueNode(prim_args[i]);
(void)mng->Replace(params[i], arg);
continue;
}
if (tensor_value_args[i] != nullptr) {
auto &const_tensor = *tensor_value_args[i];
auto const_tensor_ptr = std::make_shared<tensor::Tensor>(const_tensor);
AnfNodePtr arg = NewValueNode(const_tensor_ptr);
(void)mng->Replace(params[i], arg);
continue;
}
new_params.push_back(params[i]);
// Replace the parameter with a value node built from the arg.
mng->Replace(params[i], NewReplaceValueNode(need_eliminate_args[i]));
}
mng->SetParameters(new_fg, new_params);
cache[key] = new_fg;
}
@@ -87,11 +71,19 @@ class SpecializeTransform {
}
private:
std::unordered_map<
FuncGraphPtr,
std::map<std::tuple<std::vector<FuncGraphPtr>, std::vector<PrimitivePtr>, std::vector<tensor::TensorPtr>>,
FuncGraphPtr>>
cache_;
std::unordered_map<FuncGraphPtr, std::map<std::vector<ValuePtr>, FuncGraphPtr>> cache_;
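// Build the value node that replaces a parameter: func graphs, primitives and namespaces are wrapped directly, tensors are copied into a new Tensor first, anything else is an error.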
static ValueNodePtr NewReplaceValueNode(const ValuePtr &value) {
MS_EXCEPTION_IF_NULL(value);
if (value->isa<FuncGraph>() || value->isa<Primitive>() || value->isa<parse::NameSpace>()) {
return NewValueNode(value);
}
if (value->isa<tensor::Tensor>()) {
auto &const_tensor = *(value->cast<tensor::TensorPtr>());
auto const_tensor_ptr = std::make_shared<tensor::Tensor>(const_tensor);
return NewValueNode(const_tensor_ptr);
}
MS_LOG(EXCEPTION) << "Unexpected value:" << value->ToString();
}
};
} // namespace internal
@@ -115,44 +107,23 @@ class SpecializeOnGraphArguments : public AnfVisitor {
if (inp0_fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE) || inp0_fg->recursive()) {
return nullptr;
}
std::vector<FuncGraphPtr> graph_args;
std::vector<PrimitivePtr> prim_args;
std::vector<tensor::TensorPtr> tensor_value_args;
std::vector<ValuePtr> need_eliminated_args;
std::vector<AnfNodePtr> new_xs;
bool hasVNode = false;
for (size_t i = 1; i < inputs.size(); i++) {
if (IsValueNode<FuncGraph>(inputs[i])) {
auto fg_vnode = GetValueNode<FuncGraphPtr>(inputs[i]);
graph_args.push_back(fg_vnode);
prim_args.emplace_back(nullptr);
tensor_value_args.emplace_back(nullptr);
hasVNode = true;
} else if (IsValueNode<Primitive>(inputs[i])) {
auto p_vnode = GetValueNode<PrimitivePtr>(inputs[i]);
graph_args.emplace_back(nullptr);
prim_args.push_back(p_vnode);
tensor_value_args.emplace_back(nullptr);
hasVNode = true;
} else if (IsValueNode<tensor::Tensor>(inputs[i])) {
tensor::TensorPtr t_vnode = GetValueNode<tensor::TensorPtr>(inputs[i]);
graph_args.emplace_back(nullptr);
prim_args.emplace_back(nullptr);
tensor_value_args.emplace_back(t_vnode);
if (IsValueNode<FuncGraph>(inputs[i]) || IsValueNode<Primitive>(inputs[i]) ||
IsValueNode<tensor::Tensor>(inputs[i]) || IsValueNode<parse::NameSpace>(inputs[i])) {
need_eliminated_args.push_back(GetValueNode(inputs[i]));
hasVNode = true;
} else {
graph_args.emplace_back(nullptr);
prim_args.emplace_back(nullptr);
tensor_value_args.emplace_back(nullptr);
need_eliminated_args.emplace_back(nullptr);
new_xs.push_back(inputs[i]);
}
}
if (!hasVNode) {
return nullptr;
}
auto new_fg = specialize_transform_(inp0_fg, graph_args, prim_args, tensor_value_args);
auto new_fg = specialize_transform_(inp0_fg, need_eliminated_args);
(void)new_xs.insert(new_xs.begin(), NewValueNode(new_fg));
return node->func_graph()->NewCNode(new_xs);

View File

@@ -44,6 +44,7 @@
#include "pipeline/pynative/pynative_execute.h"
#include "pipeline/jit/static_analysis/auto_monad.h"
#include "frontend/optimizer/irpass/gradient_eliminate.h"
#include "frontend/optimizer/irpass/parameter_eliminate.h"
#if (ENABLE_CPU && !_WIN32)
#include "ps/util.h"
#include "ps/ps_context.h"
@@ -227,8 +228,8 @@ void AddParallelRenormalize(OptPassGroupMap *map_a) {
}
}
OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
opt::OptPassConfig a_1 = opt::OptPassConfig({
opt::OptPassConfig GetOptPassA1(const opt::irpass::OptimizeIRPassLib &irpass) {
return opt::OptPassConfig({
irpass.switch_defer_inline_,
irpass.switch_layer_defer_inline_,
irpass.switch_simplify_,
@@ -277,6 +278,10 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
irpass.stopgrad_eliminater_,
irpass.sparse_tensor_eliminate_,
});
}
OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
opt::OptPassConfig a_1 = GetOptPassA1(irpass);
opt::OptPassConfig a_2 = opt::OptPassConfig(
{
irpass.specialize_transform_,
@@ -293,6 +298,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
irpass.all_reduce_const_elim_,
},
false, true);
opt::OptPassConfig a_after_grad = opt::OptPassConfig({
irpass.inline_without_move_,
});
@@ -321,6 +327,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
// Before adjusting map_a, check GetA1A2() and GetOptPynativeGradEpiloguePhases().
OptPassGroupMap map_a({{"a_1", a_1},
{"parameter_eliminate", opt::OptPassConfig(opt::irpass::ParameterEliminator())},
{"a_2", a_2},
{"accelerated_algorithm", accelerated_algorithm},
{"auto_parallel", opt::OptPassConfig(parallel::StepAutoParallel)},
@@ -342,7 +349,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
OptPassGroupMap GetA1A2(const opt::irpass::OptimizeIRPassLib &irpass) {
auto opt_a = GetOptPassesA(irpass);
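// map_a now begins with {a_1, parameter_eliminate, a_2}, so the first three pass groups are taken.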
OptPassGroupMap a1_a2({opt_a[0], opt_a[1]});
OptPassGroupMap a1_a2({opt_a[0], opt_a[1], opt_a[2]});
return a1_a2;
}

View File

@@ -252,7 +252,9 @@ AbstractBasePtrList FuncGraphEvaluator::NormalizeArgs(const AbstractBasePtrList
(void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(broaded_list),
[](const AbstractBasePtr &arg) -> AbstractBasePtr {
MS_EXCEPTION_IF_NULL(arg);
if (arg->isa<AbstractScalar>()) {
// Only broaden scalars whose data type is a number, such as float16, int32 and so on.
auto type = arg->BuildType()->type_id();
if (arg->isa<AbstractScalar>() && type > kNumberTypeBegin && type < kNumberTypeEnd) {
auto config = abstract::AbstractBase::kBroadenScalarParameterOnly;
return arg->Broaden(config);
} else if (arg->GetValueTrack() != kAnyValue) {

View File

@@ -80,8 +80,7 @@ CNodePtr Make_Node(Shape x, Shape y, Shape out, int64_t condition = 0) {
break;
}
case 1: {
abstract1->set_shape(nullptr);
param1->set_abstract(abstract1);
// Don't set the abstract of param1; an exception is expected to be raised.
param2->set_abstract(abstract2);
break;
}
@@ -274,15 +273,6 @@ TEST_F(TestStepParallel, ExtractShape3) {
ASSERT_EQ(shape_test, shape_expect);
}
TEST_F(TestStepParallel, ExtractShape4) {
Shape inputs_x_dims = {64, 32};
Shape inputs_y_dims = {32, 64};
Shape outputs_dims = {64, 64};
CNodePtr node = Make_Node(inputs_x_dims, inputs_y_dims, outputs_dims, 2);
Shapes inputs_shape = std::vector<Shape>{inputs_x_dims, inputs_y_dims};
EXPECT_THROW({ ExtractShape(node); }, std::runtime_error);
}
TEST_F(TestStepParallel, CreatOpInstance) {
ValuePtr attr0_value = MakeValue(REDUCE_OP_SUM);
ValuePtr attr1_value = MakeValue("0-1-2");