Refactor_part_of_pipeline

lichenever 2021-08-20 11:03:03 +08:00
parent 6c7ffb7f1d
commit 5812076512
7 changed files with 343 additions and 68 deletions

View File

@@ -17,6 +17,7 @@
 #include <iterator>
 #include <memory>
 #include <list>
+#include <set>
 #include <algorithm>
 #include "frontend/parallel/graph_util/pipeline_split_utils.h"
 #include "frontend/parallel/graph_util/generate_graph.h"
@@ -502,6 +503,14 @@ void HandleMicroBatch(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphM
   }
 }
 
+AnfNodePtr GetActualOp(const AnfNodePtr &node) {
+  if (IsPrimitiveCNode(node, prim::kPrimDepend)) {
+    auto cnode = node->cast<CNodePtr>();
+    return cnode->input(1);
+  }
+  return node;
+}
+
 void GetBorderNode(std::vector<AnfNodePtr> *forward_start, std::vector<AnfNodePtr> *forward_end,
                    std::vector<AnfNodePtr> *backward_start, std::vector<AnfNodePtr> *backward_end,
                    std::vector<AnfNodePtr> *forward_params, std::vector<AnfNodePtr> *backward_params,
@@ -611,11 +620,15 @@ void Reorder(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager) {
   if (!IsLastStage()) {
     for (auto &node : forward_end_pair.first) {
       auto cnode = node->cast<CNodePtr>();
-      forward_end_before_pair.first.push_back(cnode->input(1));
+      auto temp_node = GetActualOp(cnode->input(1));
+      MS_EXCEPTION_IF_NULL(temp_node);
+      forward_end_before_pair.first.push_back(temp_node);
    }
     for (auto &node : forward_end_pair.second) {
       auto cnode = node->cast<CNodePtr>();
-      forward_end_before_pair.second.push_back(cnode->input(1));
+      auto temp_node = GetActualOp(cnode->input(1));
+      MS_EXCEPTION_IF_NULL(temp_node);
+      forward_end_before_pair.second.push_back(temp_node);
     }
   } else {
     forward_end_before_pair = forward_end_pair;
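Review note: the Reorder change above exists because the first input of a forward_end border node may be a Depend wrapper rather than the real compute op, and the schedule bookkeeping wants the real op. GetActualOp peels exactly one Depend level. A minimal Python sketch of the idea; Node, kind and inputs are illustrative stand-ins for MindSpore's CNode API (where input 0 holds the primitive and input 1 the first data argument), not the real classes:

# Sketch only: hypothetical mini-IR, not the MindSpore API.
class Node:
    def __init__(self, kind, inputs=()):
        self.kind = kind              # e.g. "Depend", "MatMul", "UpdateState"
        self.inputs = list(inputs)    # inputs[0] plays the role of CNode input(1)

def get_actual_op(node):
    # Mirrors GetActualOp: Depend(real, monad) forwards its first input,
    # so look through a single Depend wrapper.
    if node.kind == "Depend":
        return node.inputs[0]
    return node

matmul = Node("MatMul")
wrapped = Node("Depend", [matmul, Node("UpdateState")])
assert get_actual_op(wrapped) is matmul
assert get_actual_op(matmul) is matmul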

View File

@@ -48,6 +48,7 @@ int64_t GetMicroBatch(const AnfNodePtr &node);
 void InsertDepend(const AnfNodePtr &prior_node, const AnfNodePtr &post_node, const FuncGraphManagerPtr &manager,
                   const FuncGraphPtr &root);
 PipelinePair Deduplicate(const std::vector<AnfNodePtr> &node_vector, const FuncGraphPtr &root, int64_t micro_max);
+AnfNodePtr GetActualOp(const AnfNodePtr &node);
 void GetBorderNode(std::vector<AnfNodePtr> *forward_start, std::vector<AnfNodePtr> *forward_end,
                    std::vector<AnfNodePtr> *backward_start, std::vector<AnfNodePtr> *backward_end,
                    std::vector<AnfNodePtr> *forward_params, std::vector<AnfNodePtr> *backward_params,

View File

@@ -336,58 +336,63 @@ std::pair<OperatorInfoPtr, int> PipelineTransformer::GetOpInfo(const AnfNodePtr
   return std::make_pair(op_info, tensor_info_index);
 }
 
+AnfNodeIndexSet PipelineTransformer::GetActualOpUsers(const AnfNodePtr &node, NodeUsersMap *node_users_map) {
+  auto temp_users = (*node_users_map)[node];
+  auto temp_node = temp_users.front().first;
+  if (IsPrimitiveCNode(temp_node, prim::kPrimLoad) || IsPrimitiveCNode(temp_node, prim::kPrimCast)) {
+    return GetActualOpUsers(temp_node, node_users_map);
+  }
+  return temp_users;
+}
+
 std::pair<OperatorInfoPtr, int> PipelineTransformer::GetParameterPair(const AnfNodePtr &node) {
   MS_EXCEPTION_IF_NULL(node);
   auto node_users_map = manager_->node_users();
   auto node_users = node_users_map[node];
   for (auto &node_user : node_users) {
-    auto load = node_user.first->cast<CNodePtr>();
-    if (IsPrimitiveCNode(load, prim::kPrimLoad)) {
-      node_users = node_users_map[load];
-      break;
-    }
-  }
-  for (auto &user_pair : node_users) {
-    auto user_node = user_pair.first->cast<CNodePtr>();
-    MS_EXCEPTION_IF_NULL(user_node);
-    auto user_node_graph = user_node->func_graph();
-    MS_EXCEPTION_IF_NULL(user_node_graph);
-    if (user_node_graph->stage() == -1) {
-      continue;
-    }
-    auto care_node = user_node;
-    auto index = user_pair.second;
-    if (IsValueNode<FuncGraph>(user_node->input(0))) {
-      auto graph = GetValueNode<FuncGraphPtr>(user_node->input(0));
-      auto temp_params = graph->parameters();
-      if (temp_params.size() < IntToSize(user_pair.second)) {
-        MS_LOG(EXCEPTION) << "parameter: " << node->DebugString() << " is out of graph: " << graph->ToString()
-                          << "'s range.";
-      }
-      auto temp_param = temp_params[user_pair.second - 1];
-      auto temp_users = node_users_map[temp_param];
-      for (auto &temp_user : temp_users) {
-        auto load_temp = temp_user.first->cast<CNodePtr>();
-        if (IsPrimitiveCNode(load_temp, prim::kPrimLoad)) {
-          temp_users = node_users_map[load_temp];
-          break;
-        }
-      }
-      for (auto &temp_pair : temp_users) {
-        auto temp_cnode = temp_pair.first->cast<CNodePtr>();
-        if (!IsPipelineCareNode(temp_cnode)) {
-          continue;
-        }
-        care_node = temp_cnode;
-        index = temp_pair.second;
-        break;
-      }
-    }
-    if (!IsPipelineCareNode(care_node)) {
-      continue;
-    }
-    auto op_info = CreateOpInfo(care_node);
-    return std::make_pair(op_info, index - 1);
+    auto load_users = GetActualOpUsers(node_user.first, &node_users_map);
+    for (auto &user_pair : load_users) {
+      auto user_node = user_pair.first->cast<CNodePtr>();
+      MS_EXCEPTION_IF_NULL(user_node);
+      auto user_node_graph = user_node->func_graph();
+      MS_EXCEPTION_IF_NULL(user_node_graph);
+      if (user_node_graph->stage() == -1) {
+        continue;
+      }
+      auto care_node = user_node;
+      auto index = user_pair.second;
+      if (IsValueNode<FuncGraph>(user_node->input(0))) {
+        auto graph = GetValueNode<FuncGraphPtr>(user_node->input(0));
+        auto temp_params = graph->parameters();
+        if (temp_params.size() < IntToSize(user_pair.second)) {
+          MS_LOG(EXCEPTION) << "parameter: " << node->DebugString() << " is out of graph: " << graph->ToString()
+                            << "'s range.";
+        }
+        auto temp_param = temp_params[user_pair.second - 1];
+        auto temp_users = node_users_map[temp_param];
+        for (auto &temp_user : temp_users) {
+          auto load_temp = temp_user.first->cast<CNodePtr>();
+          if (IsPrimitiveCNode(load_temp, prim::kPrimLoad)) {
+            temp_users = node_users_map[load_temp];
+            break;
+          }
+        }
+        for (auto &temp_pair : temp_users) {
+          auto temp_cnode = temp_pair.first->cast<CNodePtr>();
+          if (!IsPipelineCareNode(temp_cnode)) {
+            continue;
+          }
+          care_node = temp_cnode;
+          index = temp_pair.second;
+          break;
+        }
+      }
+      if (!IsPipelineCareNode(care_node)) {
+        continue;
+      }
+      auto op_info = CreateOpInfo(care_node);
+      return std::make_pair(op_info, index - 1);
+    }
   }
   return std::make_pair(nullptr, 0);
 }
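Review note: the rewrite above folds the old one-hop Load lookup into GetActualOpUsers, which looks through chains of framework-inserted Load and Cast nodes to reach the real consumers. Note that it only follows the first user of each intermediate node (temp_users.front()), so it assumes that chain is linear. A rough Python sketch of the traversal; the users dict and kind_of table are hypothetical stand-ins for the manager's NodeUsersMap:

# Sketch only: users maps node -> [(user, input_index), ...].
VIRTUAL = {"Load", "Cast"}

def actual_op_users(node, users, kind_of):
    # Mirrors GetActualOpUsers: if the first user is a Load/Cast,
    # recurse through it; otherwise these are the real users.
    node_users = users[node]
    first_user = node_users[0][0]
    if kind_of[first_user] in VIRTUAL:
        return actual_op_users(first_user, users, kind_of)
    return node_users

kind_of = {"load": "Load", "matmul": "MatMul"}
users = {"param": [("load", 1)], "load": [("matmul", 2)]}
assert actual_op_users("param", users, kind_of) == [("matmul", 2)]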
@@ -432,7 +437,8 @@ std::vector<AnfNodePtr> PipelineTransformer::HandleSharedParameter() {
       if (receive) {
         manager_->SetEdge(node, user.second, receive);
       } else {
-        auto recv = InsertReceive(main_graph_, parameter, node, user.second, stage_, *parameter_stage.begin(), micro);
+        auto recv = InsertReceive(main_graph_, parameter, node, user.second, stage_, *parameter_stage.begin(), micro,
+                                  parameter);
         recvs.push_back(recv);
       }
     }
@@ -594,7 +600,8 @@ SendAttr PipelineTransformer::InsertSend(const FuncGraphPtr &graph, const AnfNod
 AnfNodePtr PipelineTransformer::InsertReceive(const FuncGraphPtr &graph, const AnfNodePtr &node,
                                               const AnfNodePtr &use_node, int index, int64_t user_node_stage,
-                                              int64_t node_stage, const ValuePtr &value) {
+                                              int64_t node_stage, const ValuePtr &value,
+                                              const AnfNodePtr &graph_param) {
   auto src_rank = global_rank_ - (user_node_stage - node_stage) * per_stage_rank_num_;
   int64_t recv_tag;
   if (recv_tag_map.find(src_rank) != recv_tag_map.end()) {
@@ -610,18 +617,13 @@ AnfNodePtr PipelineTransformer::InsertReceive(const FuncGraphPtr &graph, const A
   bool is_param = true;
   TensorInfo tensor_info;
   if (node->isa<Parameter>()) {
-    op_info_pair = GetParameterPair(node);
+    op_info_pair = GetParameterPair(graph_param);
     tensor_info = op_info_pair.first->inputs_tensor_info().at(IntToSize(op_info_pair.second));
   } else {
     auto care_node = FindPipelineCareNode(node);
-    if (care_node->isa<Parameter>()) {
-      op_info_pair = GetParameterPair(care_node);
-      tensor_info = op_info_pair.first->inputs_tensor_info().at(IntToSize(op_info_pair.second));
-    } else {
-      op_info_pair = GetOpInfo(care_node);
-      tensor_info = op_info_pair.first->outputs_tensor_info().at(IntToSize(op_info_pair.second));
-      is_param = false;
-    }
+    op_info_pair = GetOpInfo(care_node);
+    tensor_info = op_info_pair.first->outputs_tensor_info().at(IntToSize(op_info_pair.second));
+    is_param = false;
   }
   auto tensor_layout = tensor_info.tensor_layout();
   Shape slice_shape = tensor_info.slice_shape();
@@ -694,6 +696,119 @@ AnfNodePtr PipelineTransformer::Reuse(const AnfNodePtr &node, int64_t stage, con
   return nullptr;
 }
 
+AnfNodePtr PipelineTransformer::ActualOp(const AnfNodePtr &node) {
+  // Skip virtual ops such as Depend, Load and Cast.
+  if (IsPrimitiveCNode(node, prim::kPrimDepend) || IsPrimitiveCNode(node, prim::kPrimCast) ||
+      IsPrimitiveCNode(node, prim::kPrimLoad)) {
+    auto cnode = node->cast<CNodePtr>();
+    MS_EXCEPTION_IF_NULL(cnode);
+    return ActualOp(cnode->input(1));
+  }
+  return node;
+}
+
+bool PipelineTransformer::IsParameterGraph(const AnfNodePtr &node) {
+  // ParameterGraph: a graph whose output is a parameter.
+  MS_EXCEPTION_IF_NULL(node);
+  auto cnode = node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(cnode);
+  // Pattern: parameter_graph -> return -> load -> graph
+  if (IsPrimitiveCNode(node, prim::kPrimLoad)) {
+    auto graph_cnode = cnode->input(1)->cast<CNodePtr>();
+    if (!graph_cnode) {
+      return false;
+    }
+    if (!IsValueNode<FuncGraph>(graph_cnode->input(0))) {
+      return false;
+    }
+    // At this point the Load's input must be a parameter.
+    return true;
+  }
+  // Pattern: parameter_graph -> return -> graph
+  if (!IsValueNode<FuncGraph>(cnode->input(0))) {
+    return false;
+  }
+  auto graph = GetValueNode<FuncGraphPtr>(cnode->input(0));
+  MS_EXCEPTION_IF_NULL(graph);
+  auto graph_out = graph->output();
+  MS_EXCEPTION_IF_NULL(graph_out);
+  auto actual_op = ActualOp(graph_out);
+  MS_EXCEPTION_IF_NULL(actual_op);
+  if (actual_op->isa<Parameter>()) {
+    auto parameter_list = graph->parameters();
+    // Pattern: parameter_graph -> parameter -> return -> graph
+    auto parameter_iter = std::find(parameter_list.begin(), parameter_list.end(), actual_op);
+    if (parameter_iter == parameter_list.end()) {
+      return true;
+    }
+    // Pattern: parameter -> graph -> return -> graph
+    auto pos = std::distance(parameter_list.begin(), parameter_iter);
+    if (!cnode->input(pos + 1)->isa<Parameter>()) {
+      return false;
+    }
+    return true;
+  }
+  return false;
+}
+
+AnfNodePtr PipelineTransformer::HandleParameterGraph(const AnfNodePtr &node, const AnfNodePtr &use_node, int64_t stage,
+                                                     int64_t user_stage, const ValuePtr &micro, size_t pos,
+                                                     const std::vector<AnfNodePtr> ops) {
+  auto cnode = node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(cnode);
+  AnfNodePtr argument;
+  AnfNodePtr parameter;
+  FuncGraphPtr graph;
+  // Pattern: parameter_graph -> load -> graph
+  if (IsPrimitiveCNode(node, prim::kPrimLoad)) {
+    auto graph_cnode = cnode->input(1)->cast<CNodePtr>();
+    MS_EXCEPTION_IF_NULL(graph_cnode);
+    graph = GetValueNode<FuncGraphPtr>(graph_cnode->input(0));
+  } else {
+    graph = GetValueNode<FuncGraphPtr>(cnode->input(0));
+  }
+  MS_EXCEPTION_IF_NULL(graph);
+  auto graph_out = ActualOp(graph->output());
+  MS_EXCEPTION_IF_NULL(graph_out);
+  auto parameter_list = graph->parameters();
+  auto param_iter = std::find(parameter_list.begin(), parameter_list.end(), graph_out);
+  auto use_cnode = use_node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(use_cnode);
+  if (!IsValueNode<FuncGraph>(use_cnode->input(0))) {
+    MS_LOG(EXCEPTION) << "Parameter must be used by a graph, but got: " << use_cnode->DebugString();
+  }
+  auto use_graph = GetValueNode<FuncGraphPtr>(use_cnode->input(0));
+  auto use_parameter_list = use_graph->parameters();
+  parameter = use_parameter_list.at(pos - 1);
+  // Pattern: argument -> load -> graph
+  if (param_iter == parameter_list.end()) {
+    argument = graph_out;
+  } else {
+    auto param_pos = std::distance(parameter_list.begin(), param_iter);
+    argument = cnode->input(param_pos + 1);
+  }
+  // Insert receive.
+  if (stage_ == user_stage) {
+    auto recv = Reuse(argument, stage, ops, SRC_RANK);
+    if (recv) {
+      manager_->SetEdge(use_node, pos, recv);
+      return nullptr;
+    }
+    return InsertReceive(main_graph_, argument, use_node, pos, user_stage, stage, micro, parameter);
+  }
+  // Insert send.
+  if (Reuse(argument, user_stage, ops, DEST_RANK)) {
+    return nullptr;
+  }
+  auto send_out = InsertSend(main_graph_, argument, user_stage, stage_, micro);
+  send_out.depend->set_user_data<Type>(DTYPE, send_out.type);
+  send_out.depend->set_user_data<ValueList>(SHAPE, send_out.shape);
+  return send_out.depend;
+}
+
 std::pair<std::vector<AnfNodePtr>, std::vector<AnfNodePtr>> PipelineTransformer::CutBorder(const FuncGraphPtr &graph) {
   OperatorAttrs depend_attrs;
   auto depend_op = CreatOpInstance(depend_attrs, DEPEND, DEPEND);
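Review note: the heart of HandleParameterGraph is deciding which node really crosses the stage border. If the sub-graph's output (after ActualOp strips Depend/Cast/Load) is one of its own parameters, the data comes from the matching call-site argument at input param_pos + 1 (input 0 of a call CNode is the graph itself); otherwise the returned node is used directly. The resolved use-side parameter is also threaded into InsertReceive as graph_param, so GetParameterPair can recover the tensor layout on the receive side. A small self-contained Python sketch of that index mapping; strings stand in for node pointers:

# Sketch only: plain lists stand in for the graph's parameter list
# and the call CNode's input list.
def resolve_argument(graph_out, graph_params, call_inputs):
    # Mirrors the argument selection in HandleParameterGraph.
    if graph_out in graph_params:
        pos = graph_params.index(graph_out)
        return call_inputs[pos + 1]   # input 0 is the graph itself
    return graph_out

# graph(x, w) whose body returns its second parameter -> argument is w
assert resolve_argument("p1", ["p0", "p1"], ["graph", "x", "w"]) == "w"
# graph returns a Parameter captured from outside -> use it as-is
assert resolve_argument("shared_w", ["p0"], ["graph", "x"]) == "shared_w"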
@@ -708,7 +823,7 @@ std::pair<std::vector<AnfNodePtr>, std::vector<AnfNodePtr>> PipelineTransformer:
     MS_LOG(EXCEPTION) << "MicroBatch size: " << micro_size_ << " can't be less than stage num: " << stage_num;
   }
   for (auto &node : all_nodes) {
-    if (!node->isa<CNode>() || node->stage() == -1) {
+    if (!node->isa<CNode>() || node->stage() == -1 || IsPrimitiveCNode(node, prim::kPrimUpdateState)) {
       continue;
     }
     auto node_users = manager_->node_users()[node];
@@ -727,6 +842,15 @@ std::pair<std::vector<AnfNodePtr>, std::vector<AnfNodePtr>> PipelineTransformer:
       }
       if (node_stage < user_node_stage) {
         if (node_stage == stage_) {
+          if (IsParameterGraph(node)) {
+            auto send_depend =
+              HandleParameterGraph(node, user_node, node_stage, user_node_stage, micro, user_pair.second, send_ops);
+            if (!send_depend) {
+              continue;
+            }
+            send_ops.insert(send_ops.begin(), send_depend);
+            continue;
+          }
           if (Reuse(node, user_node_stage, send_ops, DEST_RANK)) {
             continue;
           }
@@ -737,8 +861,18 @@ std::pair<std::vector<AnfNodePtr>, std::vector<AnfNodePtr>> PipelineTransformer:
           send_out.depend->set_user_data<ValueList>(SHAPE, send_out.shape);
         } else {
           if (!receive) {
-            receive = InsertReceive(graph, node, user_node, user_pair.second, user_node_stage, node_stage, micro);
-            receive_ops.push_back(receive);
+            if (IsParameterGraph(node)) {
+              receive = HandleParameterGraph(node, user_node, node_stage, user_node_stage, micro, user_pair.second,
+                                             receive_ops);
+              if (!receive) {
+                continue;
+              }
+              receive_ops.push_back(receive);
+            } else {
+              receive =
+                InsertReceive(graph, node, user_node, user_pair.second, user_node_stage, node_stage, micro, node);
+              receive_ops.push_back(receive);
+            }
           } else {
             manager_->SetEdge(user_node, user_pair.second, receive);
           }

View File

@@ -62,12 +62,18 @@ class PipelineTransformer {
  private:
   void CreateForwardGroup();
+  AnfNodePtr ActualOp(const AnfNodePtr &node);
+  bool IsParameterGraph(const AnfNodePtr &node);
+  AnfNodeIndexSet GetActualOpUsers(const AnfNodePtr &node, NodeUsersMap *node_users_map);
+  AnfNodePtr HandleParameterGraph(const AnfNodePtr &node, const AnfNodePtr &use_node, int64_t stage, int64_t user_stage,
+                                  const ValuePtr &micro, size_t pos, const std::vector<AnfNodePtr> ops);
   ValuePtr SetMicroBatch(const AnfNodePtr &node, int64_t micro_size);
   std::vector<AnfNodePtr> HandleSharedParameter();
   SendAttr InsertSend(const FuncGraphPtr &graph, const AnfNodePtr &parameter, int64_t user_node_stage,
                       int64_t node_stage, const ValuePtr &value);
   AnfNodePtr InsertReceive(const FuncGraphPtr &graph, const AnfNodePtr &node, const AnfNodePtr &use_node, int index,
-                           int64_t user_node_stage, int64_t node_stage, const ValuePtr &value);
+                           int64_t user_node_stage, int64_t node_stage, const ValuePtr &value,
+                           const AnfNodePtr &graph_param);
   std::pair<std::vector<AnfNodePtr>, std::vector<AnfNodePtr>> CutBorder(const FuncGraphPtr &graph);
   AnfNodePtr Reuse(const AnfNodePtr &node, int64_t stage, const std::vector<AnfNodePtr> &out_input,
                    const std::string &tag);

View File

@@ -326,17 +326,7 @@ ValuePtr ConvertCellObjToFuncGraph(const py::object &obj) {
     }
   }
   if (py::hasattr(obj, STAGE_NAME)) {
-    auto obj_stage = py::getattr(obj, STAGE_NAME);
-    if (py::isinstance<py::bool_>(obj_stage)) {
-      MS_LOG(EXCEPTION) << "The type of pipeline stage must be int, but got bool.";
-    }
-    if (!py::isinstance<py::int_>(obj_stage)) {
-      MS_LOG(EXCEPTION) << "The type of pipeline stage must be int.";
-    }
     auto stage = py::cast<int>(py::getattr(obj, STAGE_NAME));
-    if (stage < 0) {
-      MS_LOG(EXCEPTION) << "Pipeline stage can't be less than 0, but got: " << stage;
-    }
     func_graph->set_stage(stage);
   }
   return func_graph;

View File

@@ -242,8 +242,12 @@ class Cell(Cell_):
     @pipeline_stage.setter
     def pipeline_stage(self, value):
+        if isinstance(value, bool):
+            raise TypeError("'pipeline_stage' must be int type, but got bool.")
         if not isinstance(value, int):
             raise TypeError("'pipeline_stage' must be int type.")
+        if value < 0:
+            raise TypeError("'pipeline_stage' can not be less than 0.")
         self._pipeline_stage = value
         for item in self.trainable_params():
             item.add_pipeline_stage(value)
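Review note: the checks deleted from ConvertCellObjToFuncGraph above now run here, so a bad stage fails at assignment time instead of at parse time. The bool branch must come first because bool is a subclass of int in Python and would otherwise pass the int check. A quick illustration of the setter's contract; Stage is a throwaway Cell defined only for this sketch, and the behavior assumes a build containing this change:

import mindspore.nn as nn

class Stage(nn.Cell):
    def construct(self, x):
        return x

net = Stage()
net.pipeline_stage = 1         # OK: a plain non-negative int
try:
    net.pipeline_stage = True  # bool is rejected before the int check
except TypeError as err:
    print(err)                 # 'pipeline_stage' must be int type, but got bool.
try:
    net.pipeline_stage = -1    # negative stages are rejected too
except TypeError as err:
    print(err)                 # 'pipeline_stage' can not be less than 0.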

View File

@@ -0,0 +1,127 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import mindspore as ms
import mindspore.nn as nn
from mindspore import context
from mindspore import Tensor
from mindspore.ops import operations as P
from mindspore.common.parameter import Parameter
from mindspore.common.initializer import initializer
from mindspore.train.model import Model
from mindspore.nn.wrap.cell_wrapper import PipelineCell

class DatasetLenet():
    # Minimal in-memory dataset stub that yields the same (data, label)
    # pair a fixed number of times.
    def __init__(self, data, label, length=3):
        self.data = data
        self.label = label
        self.index = 1
        self.length = length

    def __iter__(self):
        return self

    def __next__(self):
        if self.index >= self.length:
            raise StopIteration
        self.index += 1
        return self.data, self.label

    def reset(self):
        self.index = 0

    def get_dataset_size(self):
        return 32

    def get_repeat_count(self):
        return 1

    def get_batch_size(self):
        return 32

    def create_tuple_iterator(self, num_epochs=1, do_copy=True):
        return self


class MatMulCell(nn.Cell):
    def __init__(self, strategy1, strategy2):
        super().__init__()
        self.param = Parameter(initializer("zeros", [64, 64]), name="param")
        self.param1 = Parameter(initializer("zeros", [64, 64]), name="param1")
        self.matmul = P.MatMul().shard(strategy1)
        self.matmul1 = P.MatMul().shard(strategy2)

    def construct(self, x):
        out = self.matmul(x, self.param)
        out = self.matmul1(out, self.param1)
        # Returning self.param makes this cell's graph output a Parameter,
        # which exercises the new parameter-graph handling.
        return out, self.param


class MatMulCell2(nn.Cell):
    def __init__(self, strategy1, strategy2):
        super().__init__()
        self.param1 = Parameter(initializer("zeros", [64, 64]), name="param1")
        self.matmul = P.MatMul().shard(strategy1)
        self.matmul1 = P.MatMul().shard(strategy2)

    def construct(self, x, param):
        out = self.matmul(x, param)
        out = self.matmul1(out, self.param1)
        return out


class Net(nn.Cell):
    def __init__(self, strategy1, strategy2, param=None):
        super().__init__()
        self.cell1 = MatMulCell(strategy1, strategy2)
        self.cell1.pipeline_stage = 0
        self.cell2 = MatMulCell2(strategy1, strategy2)
        self.cell2.pipeline_stage = 1

    def construct(self, x, label):
        # cell1 (stage 0) hands its Parameter to cell2 (stage 1), so the
        # pipeline transformer must cut a Send/Receive pair for it.
        out, param = self.cell1(x)
        out = self.cell2(out, param)
        return out


def test_pipeline_split_stage0():
    context.set_auto_parallel_context(device_num=8, global_rank=0, pipeline_stages=2)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    data = Tensor(np.ones([32, 64]), dtype=ms.float32)
    label = Tensor(np.ones([64, 64]), dtype=ms.float32)
    strategy1 = ((4, 1), (1, 1))
    strategy2 = ((2, 1), (1, 1))
    net = PipelineCell(Net(strategy1, strategy2), 4)
    params = net.network.cell1.trainable_params()
    dataset = DatasetLenet(data, label, 3)
    optimizer = nn.Lamb(params, learning_rate=0.01)
    model = Model(net, optimizer=optimizer)
    model.train(2, dataset, dataset_sink_mode=False)


def test_pipeline_split_stage1():
    context.set_auto_parallel_context(device_num=8, global_rank=4, pipeline_stages=2)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    data = Tensor(np.ones([32, 64]), dtype=ms.float32)
    label = Tensor(np.ones([64, 64]), dtype=ms.float32)
    strategy1 = ((4, 1), (1, 1))
    strategy2 = ((2, 1), (1, 1))
    net = PipelineCell(Net(strategy1, strategy2), 4)
    params = net.network.cell2.trainable_params()
    dataset = DatasetLenet(data, label, 3)
    optimizer = nn.Lamb(params, learning_rate=0.01)
    model = Model(net, optimizer=optimizer)
    model.train(2, dataset, dataset_sink_mode=False)