Refactor_part_of_pipeline
parent 6c7ffb7f1d
commit 5812076512
@@ -17,6 +17,7 @@
 #include <iterator>
 #include <memory>
 #include <list>
 #include <set>
 #include <algorithm>
 #include "frontend/parallel/graph_util/pipeline_split_utils.h"
+#include "frontend/parallel/graph_util/generate_graph.h"
|
@@ -502,6 +503,14 @@ void HandleMicroBatch(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphM
   }
 }
 
+AnfNodePtr GetActualOp(const AnfNodePtr &node) {
+  if (IsPrimitiveCNode(node, prim::kPrimDepend)) {
+    auto cnode = node->cast<CNodePtr>();
+    return cnode->input(1);
+  }
+  return node;
+}
+
 void GetBorderNode(std::vector<AnfNodePtr> *forward_start, std::vector<AnfNodePtr> *forward_end,
                    std::vector<AnfNodePtr> *backward_start, std::vector<AnfNodePtr> *backward_end,
                    std::vector<AnfNodePtr> *forward_params, std::vector<AnfNodePtr> *backward_params,
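Note: the pattern GetActualOp introduces (look through a Depend wrapper to the operator it guards) is easiest to see on a toy IR. A minimal Python sketch, with an illustrative Node class rather than MindSpore's C++ ANF IR:

# Toy sketch of the GetActualOp pattern: a Depend node wraps the real
# operator; peeling one wrapper level recovers it.
from dataclasses import dataclass, field

@dataclass
class Node:
    op: str                       # e.g. "Depend", "MatMul"
    inputs: list = field(default_factory=list)

def get_actual_op(node: Node) -> Node:
    # Mirrors the C++ above: peel a single Depend wrapper, else return as-is.
    if node.op == "Depend":
        return node.inputs[0]     # the C++ uses input(1); input(0) there holds the primitive itself
    return node

matmul = Node("MatMul")
wrapped = Node("Depend", [matmul, Node("UpdateState")])
assert get_actual_op(wrapped) is matmul
assert get_actual_op(matmul) is matmul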
@@ -611,11 +620,15 @@ void Reorder(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager) {
   if (!IsLastStage()) {
     for (auto &node : forward_end_pair.first) {
       auto cnode = node->cast<CNodePtr>();
-      forward_end_before_pair.first.push_back(cnode->input(1));
+      auto temp_node = GetActualOp(cnode->input(1));
+      MS_EXCEPTION_IF_NULL(temp_node);
+      forward_end_before_pair.first.push_back(temp_node);
     }
     for (auto &node : forward_end_pair.second) {
       auto cnode = node->cast<CNodePtr>();
-      forward_end_before_pair.second.push_back(cnode->input(1));
+      auto temp_node = GetActualOp(cnode->input(1));
+      MS_EXCEPTION_IF_NULL(temp_node);
+      forward_end_before_pair.second.push_back(temp_node);
     }
   } else {
     forward_end_before_pair = forward_end_pair;
@@ -48,6 +48,7 @@ int64_t GetMicroBatch(const AnfNodePtr &node);
 void InsertDepend(const AnfNodePtr &prior_node, const AnfNodePtr &post_node, const FuncGraphManagerPtr &manager,
                   const FuncGraphPtr &root);
 PipelinePair Deduplicate(const std::vector<AnfNodePtr> &node_vector, const FuncGraphPtr &root, int64_t micro_max);
+AnfNodePtr GetActualOp(const AnfNodePtr &node);
 void GetBorderNode(std::vector<AnfNodePtr> *forward_start, std::vector<AnfNodePtr> *forward_end,
                    std::vector<AnfNodePtr> *backward_start, std::vector<AnfNodePtr> *backward_end,
                    std::vector<AnfNodePtr> *forward_params, std::vector<AnfNodePtr> *backward_params,
@@ -336,58 +336,63 @@ std::pair<OperatorInfoPtr, int> PipelineTransformer::GetOpInfo(const AnfNodePtr
   return std::make_pair(op_info, tensor_info_index);
 }
 
+AnfNodeIndexSet PipelineTransformer::GetActualOpUsers(const AnfNodePtr &node, NodeUsersMap *node_users_map) {
+  auto temp_users = (*node_users_map)[node];
+  auto temp_node = temp_users.front().first;
+  if (IsPrimitiveCNode(temp_node, prim::kPrimLoad) || IsPrimitiveCNode(temp_node, prim::kPrimCast)) {
+    return GetActualOpUsers(temp_node, node_users_map);
+  }
+  return temp_users;
+}
+
 std::pair<OperatorInfoPtr, int> PipelineTransformer::GetParameterPair(const AnfNodePtr &node) {
   MS_EXCEPTION_IF_NULL(node);
   auto node_users_map = manager_->node_users();
   auto node_users = node_users_map[node];
   for (auto &node_user : node_users) {
-    auto load = node_user.first->cast<CNodePtr>();
-    if (IsPrimitiveCNode(load, prim::kPrimLoad)) {
-      node_users = node_users_map[load];
-      break;
-    }
-  }
-  for (auto &user_pair : node_users) {
-    auto user_node = user_pair.first->cast<CNodePtr>();
-    MS_EXCEPTION_IF_NULL(user_node);
-    auto user_node_graph = user_node->func_graph();
-    MS_EXCEPTION_IF_NULL(user_node_graph);
-    if (user_node_graph->stage() == -1) {
-      continue;
-    }
-    auto care_node = user_node;
-    auto index = user_pair.second;
-    if (IsValueNode<FuncGraph>(user_node->input(0))) {
-      auto graph = GetValueNode<FuncGraphPtr>(user_node->input(0));
-      auto temp_params = graph->parameters();
-      if (temp_params.size() < IntToSize(user_pair.second)) {
-        MS_LOG(EXCEPTION) << "parameter:" << node->DebugString() << " out of graph: " << graph->ToString()
-                          << "'s range.";
-      }
-      auto temp_param = temp_params[user_pair.second - 1];
-      auto temp_users = node_users_map[temp_param];
-      for (auto &temp_user : temp_users) {
-        auto load_temp = temp_user.first->cast<CNodePtr>();
-        if (IsPrimitiveCNode(load_temp, prim::kPrimLoad)) {
-          temp_users = node_users_map[load_temp];
-          break;
-        }
-      }
-      for (auto &temp_pair : temp_users) {
-        auto temp_cnode = temp_pair.first->cast<CNodePtr>();
-        if (!IsPipelineCareNode(temp_cnode)) {
-          continue;
-        }
-        care_node = temp_cnode;
-        index = temp_pair.second;
-        break;
-      }
-    }
-    if (!IsPipelineCareNode(care_node)) {
-      continue;
-    }
-    auto op_info = CreateOpInfo(care_node);
-    return std::make_pair(op_info, index - 1);
+    auto load_users = GetActualOpUsers(node_user.first, &node_users_map);
+    for (auto &user_pair : load_users) {
+      auto user_node = user_pair.first->cast<CNodePtr>();
+      MS_EXCEPTION_IF_NULL(user_node);
+      auto user_node_graph = user_node->func_graph();
+      MS_EXCEPTION_IF_NULL(user_node_graph);
+      if (user_node_graph->stage() == -1) {
+        continue;
+      }
+      auto care_node = user_node;
+      auto index = user_pair.second;
+      if (IsValueNode<FuncGraph>(user_node->input(0))) {
+        auto graph = GetValueNode<FuncGraphPtr>(user_node->input(0));
+        auto temp_params = graph->parameters();
+        if (temp_params.size() < IntToSize(user_pair.second)) {
+          MS_LOG(EXCEPTION) << "parameter:" << node->DebugString() << " out of graph: " << graph->ToString()
+                            << "'s range.";
+        }
+        auto temp_param = temp_params[user_pair.second - 1];
+        auto temp_users = node_users_map[temp_param];
+        for (auto &temp_user : temp_users) {
+          auto load_temp = temp_user.first->cast<CNodePtr>();
+          if (IsPrimitiveCNode(load_temp, prim::kPrimLoad)) {
+            temp_users = node_users_map[load_temp];
+            break;
+          }
+        }
+        for (auto &temp_pair : temp_users) {
+          auto temp_cnode = temp_pair.first->cast<CNodePtr>();
+          if (!IsPipelineCareNode(temp_cnode)) {
+            continue;
+          }
+          care_node = temp_cnode;
+          index = temp_pair.second;
+          break;
+        }
+      }
+      if (!IsPipelineCareNode(care_node)) {
+        continue;
+      }
+      auto op_info = CreateOpInfo(care_node);
+      return std::make_pair(op_info, index - 1);
+    }
   }
   return std::make_pair(nullptr, 0);
 }
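Note: the new GetActualOpUsers replaces the old one-shot Load lookup with a recursion that skips pass-through ops (Load, Cast) until a real consumer is found. A toy Python sketch of the recursion, with illustrative types rather than MindSpore's NodeUsersMap; like the C++, it only inspects the first user when deciding whether to recurse:

# Toy sketch of the GetActualOpUsers recursion.
class Node:
    def __init__(self, op):
        self.op = op

PASS_THROUGH = {"Load", "Cast"}

def actual_op_users(node, users_map):
    users = users_map[node]            # [(user_node, input_index), ...]
    first_user, _ = users[0]
    if first_user.op in PASS_THROUGH:
        return actual_op_users(first_user, users_map)
    return users

param, load, matmul = Node("Parameter"), Node("Load"), Node("MatMul")
users_map = {param: [(load, 1)], load: [(matmul, 2)]}
assert actual_op_users(param, users_map) == [(matmul, 2)]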
@@ -432,7 +437,8 @@ std::vector<AnfNodePtr> PipelineTransformer::HandleSharedParameter() {
       if (receive) {
         manager_->SetEdge(node, user.second, receive);
       } else {
-        auto recv = InsertReceive(main_graph_, parameter, node, user.second, stage_, *parameter_stage.begin(), micro);
+        auto recv = InsertReceive(main_graph_, parameter, node, user.second, stage_, *parameter_stage.begin(), micro,
+                                  parameter);
         recvs.push_back(recv);
       }
     }
@@ -594,7 +600,8 @@ SendAttr PipelineTransformer::InsertSend(const FuncGraphPtr &graph, const AnfNod
 
 AnfNodePtr PipelineTransformer::InsertReceive(const FuncGraphPtr &graph, const AnfNodePtr &node,
                                               const AnfNodePtr &use_node, int index, int64_t user_node_stage,
-                                              int64_t node_stage, const ValuePtr &value) {
+                                              int64_t node_stage, const ValuePtr &value,
+                                              const AnfNodePtr &graph_param) {
   auto src_rank = global_rank_ - (user_node_stage - node_stage) * per_stage_rank_num_;
   int64_t recv_tag;
   if (recv_tag_map.find(src_rank) != recv_tag_map.end()) {
@@ -610,18 +617,13 @@ AnfNodePtr PipelineTransformer::InsertReceive(const FuncGraphPtr &graph, const A
   bool is_param = true;
   TensorInfo tensor_info;
   if (node->isa<Parameter>()) {
-    op_info_pair = GetParameterPair(node);
+    op_info_pair = GetParameterPair(graph_param);
     tensor_info = op_info_pair.first->inputs_tensor_info().at(IntToSize(op_info_pair.second));
   } else {
     auto care_node = FindPipelineCareNode(node);
-    if (care_node->isa<Parameter>()) {
-      op_info_pair = GetParameterPair(care_node);
-      tensor_info = op_info_pair.first->inputs_tensor_info().at(IntToSize(op_info_pair.second));
-    } else {
-      op_info_pair = GetOpInfo(care_node);
-      tensor_info = op_info_pair.first->outputs_tensor_info().at(IntToSize(op_info_pair.second));
-      is_param = false;
-    }
+    op_info_pair = GetOpInfo(care_node);
+    tensor_info = op_info_pair.first->outputs_tensor_info().at(IntToSize(op_info_pair.second));
+    is_param = false;
   }
   auto tensor_layout = tensor_info.tensor_layout();
   Shape slice_shape = tensor_info.slice_shape();
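Note: the new graph_param argument lets a received shared parameter resolve its operator info from the original Parameter node passed down by the caller, instead of re-deriving it from the local node. The branch it feeds then selects input vs. output tensor info; a toy Python sketch of that selection, with plain dicts standing in for MindSpore's OperatorInfo/TensorInfo:

# Toy sketch of the branch above: parameters take their layout from the
# consumer's inputs_tensor_info, ordinary ops from outputs_tensor_info.
def pick_tensor_info(op_info, index, is_parameter):
    if is_parameter:
        return op_info["inputs_tensor_info"][index]
    return op_info["outputs_tensor_info"][index]

op_info = {"inputs_tensor_info": ["layout_in0", "layout_in1"],
           "outputs_tensor_info": ["layout_out0"]}
assert pick_tensor_info(op_info, 1, True) == "layout_in1"
assert pick_tensor_info(op_info, 0, False) == "layout_out0"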
@@ -694,6 +696,119 @@ AnfNodePtr PipelineTransformer::Reuse(const AnfNodePtr &node, int64_t stage, con
   return nullptr;
 }
 
+AnfNodePtr PipelineTransformer::ActualOp(const AnfNodePtr &node) {
+  // Skip virtual ops such as Depend, Load and Cast.
+  if (IsPrimitiveCNode(node, prim::kPrimDepend) || IsPrimitiveCNode(node, prim::kPrimCast) ||
+      IsPrimitiveCNode(node, prim::kPrimLoad)) {
+    auto cnode = node->cast<CNodePtr>();
+    MS_EXCEPTION_IF_NULL(cnode);
+    return ActualOp(cnode->input(1));
+  }
+  return node;
+}
+
+bool PipelineTransformer::IsParameterGraph(const AnfNodePtr &node) {
+  // ParameterGraph: a graph whose output is a parameter.
+  MS_EXCEPTION_IF_NULL(node);
+  auto cnode = node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(cnode);
+  // parameter_graph->return->load->graph
+  if (IsPrimitiveCNode(node, prim::kPrimLoad)) {
+    auto graph_cnode = cnode->input(1)->cast<CNodePtr>();
+    if (!graph_cnode) {
+      return false;
+    }
+    if (!IsValueNode<FuncGraph>(graph_cnode->input(0))) {
+      return false;
+    }
+    // Now the load's input must be a parameter.
+    return true;
+  }
+
+  // parameter_graph->return->graph
+  if (!IsValueNode<FuncGraph>(cnode->input(0))) {
+    return false;
+  }
+  auto graph = GetValueNode<FuncGraphPtr>(cnode->input(0));
+  MS_EXCEPTION_IF_NULL(graph);
+  auto graph_out = graph->output();
+  MS_EXCEPTION_IF_NULL(graph_out);
+  auto actual_op = ActualOp(graph_out);
+  MS_EXCEPTION_IF_NULL(actual_op);
+  if (actual_op->isa<Parameter>()) {
+    auto parameter_list = graph->parameters();
+    // parameter_graph->parameter->return->graph
+    auto parameter_iter = std::find(parameter_list.begin(), parameter_list.end(), actual_op);
+    if (parameter_iter == parameter_list.end()) {
+      return true;
+    }
+    // parameter->graph->return->graph
+    auto pos = std::distance(parameter_list.begin(), parameter_iter);
+    if (!cnode->input(pos + 1)->isa<Parameter>()) {
+      return false;
+    }
+    return true;
+  }
+  return false;
+}
+
+AnfNodePtr PipelineTransformer::HandleParameterGraph(const AnfNodePtr &node, const AnfNodePtr &use_node, int64_t stage,
+                                                     int64_t user_stage, const ValuePtr &micro, size_t pos,
+                                                     const std::vector<AnfNodePtr> ops) {
+  auto cnode = node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(cnode);
+  AnfNodePtr argument;
+  AnfNodePtr parameter;
+  FuncGraphPtr graph;
+  // parameter_graph->load->graph
+  if (IsPrimitiveCNode(node, prim::kPrimLoad)) {
+    auto graph_cnode = cnode->input(1)->cast<CNodePtr>();
+    MS_EXCEPTION_IF_NULL(graph_cnode);
+    graph = GetValueNode<FuncGraphPtr>(graph_cnode->input(0));
+  } else {
+    graph = GetValueNode<FuncGraphPtr>(cnode->input(0));
+  }
+  MS_EXCEPTION_IF_NULL(graph);
+
+  auto graph_out = ActualOp(graph->output());
+  MS_EXCEPTION_IF_NULL(graph_out);
+  auto parameter_list = graph->parameters();
+  auto param_iter = std::find(parameter_list.begin(), parameter_list.end(), graph_out);
+  auto use_cnode = use_node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(use_cnode);
+  if (!IsValueNode<FuncGraph>(use_cnode->input(0))) {
+    MS_LOG(EXCEPTION) << "Parameter must be used by a graph, but got: " << use_cnode->DebugString();
+  }
+  auto use_graph = GetValueNode<FuncGraphPtr>(use_cnode->input(0));
+  auto use_parameter_list = use_graph->parameters();
+  parameter = use_parameter_list.at(pos - 1);
+  // argument->load->graph
+  if (param_iter == parameter_list.end()) {
+    argument = graph_out;
+  } else {
+    auto param_pos = std::distance(parameter_list.begin(), param_iter);
+    argument = cnode->input(param_pos + 1);
+  }
+
+  // insert receive
+  if (stage_ == user_stage) {
+    auto recv = Reuse(argument, stage, ops, SRC_RANK);
+    if (recv) {
+      manager_->SetEdge(use_node, pos, recv);
+      return nullptr;
+    }
+    return InsertReceive(main_graph_, argument, use_node, pos, user_stage, stage, micro, parameter);
+  }
+  // insert send
+  if (Reuse(argument, user_stage, ops, DEST_RANK)) {
+    return nullptr;
+  }
+  auto send_out = InsertSend(main_graph_, argument, user_stage, stage_, micro);
+  send_out.depend->set_user_data<Type>(DTYPE, send_out.type);
+  send_out.depend->set_user_data<ValueList>(SHAPE, send_out.shape);
+  return send_out.depend;
+}
+
 std::pair<std::vector<AnfNodePtr>, std::vector<AnfNodePtr>> PipelineTransformer::CutBorder(const FuncGraphPtr &graph) {
   OperatorAttrs depend_attrs;
   auto depend_op = CreatOpInstance(depend_attrs, DEPEND, DEPEND);
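Note: the core idea of IsParameterGraph is that a called subgraph "returns a parameter" if, after peeling virtual wrappers, its output node is a Parameter. A toy Python sketch of that check; toy IR again, and it deliberately skips the Load-wrapped and argument-forwarding cases the real C++ also handles:

# Toy sketch of the IsParameterGraph idea.
class Node:
    def __init__(self, op, inputs=()):
        self.op = op
        self.inputs = list(inputs)

VIRTUAL = {"Depend", "Load", "Cast"}

def actual_op(node):
    # Recursively peel virtual wrappers (mirrors PipelineTransformer::ActualOp).
    if node.op in VIRTUAL:
        return actual_op(node.inputs[0])
    return node

def is_parameter_graph(graph_output):
    return actual_op(graph_output).op == "Parameter"

param = Node("Parameter")
out = Node("Depend", [Node("Load", [param])])
assert is_parameter_graph(out)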
@@ -708,7 +823,7 @@ std::pair<std::vector<AnfNodePtr>, std::vector<AnfNodePtr>> PipelineTransformer:
     MS_LOG(EXCEPTION) << "MicroBatch size: " << micro_size_ << " can't be less than stage num: " << stage_num;
   }
   for (auto &node : all_nodes) {
-    if (!node->isa<CNode>() || node->stage() == -1) {
+    if (!node->isa<CNode>() || node->stage() == -1 || IsPrimitiveCNode(node, prim::kPrimUpdateState)) {
       continue;
     }
     auto node_users = manager_->node_users()[node];
@@ -727,6 +842,15 @@ std::pair<std::vector<AnfNodePtr>, std::vector<AnfNodePtr>> PipelineTransformer:
       }
       if (node_stage < user_node_stage) {
         if (node_stage == stage_) {
+          if (IsParameterGraph(node)) {
+            auto send_depend =
+              HandleParameterGraph(node, user_node, node_stage, user_node_stage, micro, user_pair.second, send_ops);
+            if (!send_depend) {
+              continue;
+            }
+            send_ops.insert(send_ops.begin(), send_depend);
+            continue;
+          }
           if (Reuse(node, user_node_stage, send_ops, DEST_RANK)) {
             continue;
           }
@@ -737,8 +861,18 @@ std::pair<std::vector<AnfNodePtr>, std::vector<AnfNodePtr>> PipelineTransformer:
           send_out.depend->set_user_data<ValueList>(SHAPE, send_out.shape);
         } else {
           if (!receive) {
-            receive = InsertReceive(graph, node, user_node, user_pair.second, user_node_stage, node_stage, micro);
-            receive_ops.push_back(receive);
+            if (IsParameterGraph(node)) {
+              receive = HandleParameterGraph(node, user_node, node_stage, user_node_stage, micro, user_pair.second,
+                                             receive_ops);
+              if (!receive) {
+                continue;
+              }
+              receive_ops.push_back(receive);
+            } else {
+              receive =
+                InsertReceive(graph, node, user_node, user_pair.second, user_node_stage, node_stage, micro, node);
+              receive_ops.push_back(receive);
+            }
           } else {
             manager_->SetEdge(user_node, user_pair.second, receive);
           }
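Note: the routing rule these two CutBorder hunks implement is that an edge whose producer stage is lower than its consumer stage gets cut, with the producer-side rank inserting a Send (or reusing one) and the consumer-side rank inserting (or reusing) a Receive; parameter-returning subgraphs detour through HandleParameterGraph. A toy Python sketch of the stage test only, deliberately ignoring reuse, micro-batches, and the parameter-graph detour:

# Toy sketch of CutBorder's send/receive routing decision.
def cut_edge(node_stage, user_stage, my_stage, sends, receives):
    if node_stage >= user_stage:
        return  # not a forward-crossing edge (toy simplification)
    if my_stage == node_stage:
        sends.append((node_stage, user_stage))
    elif my_stage == user_stage:
        receives.append((node_stage, user_stage))

sends, receives = [], []
cut_edge(0, 1, my_stage=0, sends=sends, receives=receives)
cut_edge(0, 1, my_stage=1, sends=sends, receives=receives)
assert sends == [(0, 1)] and receives == [(0, 1)]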
@@ -62,12 +62,18 @@ class PipelineTransformer {
 
  private:
   void CreateForwardGroup();
+  AnfNodePtr ActualOp(const AnfNodePtr &node);
+  bool IsParameterGraph(const AnfNodePtr &node);
+  AnfNodeIndexSet GetActualOpUsers(const AnfNodePtr &node, NodeUsersMap *node_users_map);
+  AnfNodePtr HandleParameterGraph(const AnfNodePtr &node, const AnfNodePtr &use_node, int64_t stage, int64_t user_stage,
+                                  const ValuePtr &micro, size_t pos, const std::vector<AnfNodePtr> ops);
   ValuePtr SetMicroBatch(const AnfNodePtr &node, int64_t micro_size);
   std::vector<AnfNodePtr> HandleSharedParameter();
   SendAttr InsertSend(const FuncGraphPtr &graph, const AnfNodePtr &parameter, int64_t user_node_stage,
                       int64_t node_stage, const ValuePtr &value);
   AnfNodePtr InsertReceive(const FuncGraphPtr &graph, const AnfNodePtr &node, const AnfNodePtr &use_node, int index,
-                           int64_t user_node_stage, int64_t node_stage, const ValuePtr &value);
+                           int64_t user_node_stage, int64_t node_stage, const ValuePtr &value,
+                           const AnfNodePtr &graph_param);
   std::pair<std::vector<AnfNodePtr>, std::vector<AnfNodePtr>> CutBorder(const FuncGraphPtr &graph);
   AnfNodePtr Reuse(const AnfNodePtr &node, int64_t stage, const std::vector<AnfNodePtr> &out_input,
                    const std::string &tag);
@@ -326,17 +326,7 @@ ValuePtr ConvertCellObjToFuncGraph(const py::object &obj) {
     }
   }
   if (py::hasattr(obj, STAGE_NAME)) {
-    auto obj_stage = py::getattr(obj, STAGE_NAME);
-    if (py::isinstance<py::bool_>(obj_stage)) {
-      MS_LOG(EXCEPTION) << "The type of pipeline stage must be int, but got bool.";
-    }
-    if (!py::isinstance<py::int_>(obj_stage)) {
-      MS_LOG(EXCEPTION) << "The type of pipeline stage must be int.";
-    }
     auto stage = py::cast<int>(py::getattr(obj, STAGE_NAME));
-    if (stage < 0) {
-      MS_LOG(EXCEPTION) << "Pipeline stage can't be less than 0, but got: " << stage;
-    }
     func_graph->set_stage(stage);
   }
   return func_graph;
@@ -242,8 +242,12 @@ class Cell(Cell_):
 
     @pipeline_stage.setter
     def pipeline_stage(self, value):
+        if isinstance(value, bool):
+            raise TypeError("'pipeline_stage' must be int type, but got bool.")
         if not isinstance(value, int):
             raise TypeError("'pipeline_stage' must be int type.")
+        if value < 0:
+            raise TypeError("'pipeline_stage' can not be less than 0.")
         self._pipeline_stage = value
         for item in self.trainable_params():
            item.add_pipeline_stage(value)
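Note: the explicit bool check comes first because bool is a subclass of int in Python, so isinstance(True, int) is True. A quick usage sketch of the stricter setter (assumes MindSpore with this change applied; Net here is a trivial placeholder Cell):

import mindspore.nn as nn

class Net(nn.Cell):
    def construct(self, x):
        return x

net = Net()
net.pipeline_stage = 1          # OK: non-negative int

for bad in (True, -1, 1.5):     # bool, negative, and non-int all raise
    try:
        net.pipeline_stage = bad
    except TypeError as e:
        print(type(bad).__name__, "->", e)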
@@ -0,0 +1,127 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import numpy as np
+import mindspore as ms
+import mindspore.nn as nn
+from mindspore import context
+from mindspore import Tensor
+from mindspore.ops import operations as P
+from mindspore.common.parameter import Parameter
+from mindspore.common.initializer import initializer
+from mindspore.train.model import Model
+from mindspore.nn.wrap.cell_wrapper import PipelineCell
+
+
+class DatasetLenet():
+    def __init__(self, data, label, length=3):
+        self.data = data
+        self.label = label
+        self.index = 1
+        self.length = length
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        if self.index >= self.length:
+            raise StopIteration
+        self.index += 1
+        return self.data, self.label
+
+    def reset(self):
+        self.index = 0
+
+    def get_dataset_size(self):
+        return 32
+
+    def get_repeat_count(self):
+        return 1
+
+    def get_batch_size(self):
+        return 32
+
+    def create_tuple_iterator(self, num_epochs=1, do_copy=True):
+        return self
+
+
+class MatMulCell(nn.Cell):
+    def __init__(self, strategy1, strategy2):
+        super().__init__()
+        self.param = Parameter(initializer("zeros", [64, 64]), name="param")
+        self.param1 = Parameter(initializer("zeros", [64, 64]), name="param1")
+        self.matmul = P.MatMul().shard(strategy1)
+        self.matmul1 = P.MatMul().shard(strategy2)
+
+    def construct(self, x):
+        out = self.matmul(x, self.param)
+        out = self.matmul1(out, self.param1)
+        return out, self.param
+
+
+class MatMulCell2(nn.Cell):
+    def __init__(self, strategy1, strategy2):
+        super().__init__()
+        self.param1 = Parameter(initializer("zeros", [64, 64]), name="param1")
+        self.matmul = P.MatMul().shard(strategy1)
+        self.matmul1 = P.MatMul().shard(strategy2)
+
+    def construct(self, x, param):
+        out = self.matmul(x, param)
+        out = self.matmul1(out, self.param1)
+        return out
+
+
+class Net(nn.Cell):
+    def __init__(self, strategy1, strategy2, param=None):
+        super().__init__()
+        self.cell1 = MatMulCell(strategy1, strategy2)
+        self.cell1.pipeline_stage = 0
+        self.cell2 = MatMulCell2(strategy1, strategy2)
+        self.cell2.pipeline_stage = 1
+
+    def construct(self, x, label):
+        out, param = self.cell1(x)
+        out = self.cell2(out, param)
+        return out
+
+
+def test_pipeline_split_stage0():
+    context.set_auto_parallel_context(device_num=8, global_rank=0, pipeline_stages=2)
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
+    data = Tensor(np.ones([32, 64]), dtype=ms.float32)
+    label = Tensor(np.ones([64, 64]), dtype=ms.float32)
+    strategy1 = ((4, 1), (1, 1))
+    strategy2 = ((2, 1), (1, 1))
+    net = PipelineCell(Net(strategy1, strategy2), 4)
+    params = net.network.cell1.trainable_params()
+    dataset = DatasetLenet(data, label, 3)
+    optimizer = nn.Lamb(params, learning_rate=0.01)
+    model = Model(net, optimizer=optimizer)
+    model.train(2, dataset, dataset_sink_mode=False)
+
+
+def test_pipeline_split_stage1():
+    context.set_auto_parallel_context(device_num=8, global_rank=4, pipeline_stages=2)
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
+    data = Tensor(np.ones([32, 64]), dtype=ms.float32)
+    label = Tensor(np.ones([64, 64]), dtype=ms.float32)
+    strategy1 = ((4, 1), (1, 1))
+    strategy2 = ((2, 1), (1, 1))
+    net = PipelineCell(Net(strategy1, strategy2), 4)
+    params = net.network.cell2.trainable_params()
+    dataset = DatasetLenet(data, label, 3)
+    optimizer = nn.Lamb(params, learning_rate=0.01)
+    model = Model(net, optimizer=optimizer)
+    model.train(2, dataset, dataset_sink_mode=False)