revert insert VirtualDataset node for master

This commit is contained in:
lilei 2021-12-08 17:19:08 +08:00
parent b4176e73a7
commit 2edf6ab33b
13 changed files with 19 additions and 144 deletions

View File

@ -14,9 +14,6 @@
* limitations under the License.
*/
#include <set>
#include <map>
#include <vector>
#include <string>
#include <memory>
#include "pipeline/jit/pipeline_split.h"
@ -70,121 +67,6 @@ static int64_t InferStage(int64_t rank_id, int64_t stage_num, int64_t device_num
return rank_id / per_stage_rank_num;
}
static bool HasVirtualDataset(const std::vector<AnfNodePtr> &all_nodes) {
for (auto &node : all_nodes) {
auto cnode = node->cast<CNodePtr>();
if (IsPrimitiveCNode(cnode, prim::kPrimVirtualDataset)) {
return true;
}
}
return false;
}
static CNodePtr CreateTupleGetItem(const AnfNodePtr &node, size_t index, const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(func_graph);
auto idx = NewValueNode(SizeToLong(index));
MS_EXCEPTION_IF_NULL(idx);
auto imm = std::make_shared<Int64Imm>(SizeToLong(index));
auto abstract_scalar = std::make_shared<abstract::AbstractScalar>(imm);
idx->set_abstract(abstract_scalar);
CNodePtr tuple_get_item = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), node, idx});
MS_EXCEPTION_IF_NULL(tuple_get_item);
tuple_get_item->set_scope(node->scope());
auto input_abstract_tuple = node->abstract()->cast<abstract::AbstractTuplePtr>();
MS_EXCEPTION_IF_NULL(input_abstract_tuple);
auto tuple_get_item_abstract = input_abstract_tuple->elements()[index];
MS_EXCEPTION_IF_NULL(tuple_get_item_abstract);
tuple_get_item->set_abstract(tuple_get_item_abstract);
return tuple_get_item;
}
static CNodePtr CreateVirtualDataset(const FuncGraphPtr &func_graph) {
mindspore::parallel::OperatorAttrs attrs;
ValuePtr pyop_instance = mindspore::parallel::CreateOpInstance(attrs, mindspore::parallel::VIRTUAL_DATA_SET,
mindspore::parallel::VIRTUAL_DATA_SET);
auto value_node = NewValueNode(pyop_instance);
std::vector<AbstractBasePtr> abstract_list;
std::vector<AnfNodePtr> virtual_dataset_node_inputs = {value_node};
for (size_t index = 0; index < func_graph->get_inputs().size(); index++) {
if (!HasAbstractMonad(func_graph->get_inputs()[index])) {
auto graph_input_index = func_graph->get_inputs()[index];
auto virtual_dataset_abstract = graph_input_index->abstract()->Clone();
MS_EXCEPTION_IF_NULL(virtual_dataset_abstract);
abstract_list.emplace_back(virtual_dataset_abstract);
virtual_dataset_node_inputs.push_back(func_graph->get_inputs()[index]);
}
}
CNodePtr virtual_dataset_node = func_graph->NewCNode(virtual_dataset_node_inputs);
MS_EXCEPTION_IF_NULL(virtual_dataset_node);
virtual_dataset_node->set_in_forward_flag(true);
virtual_dataset_node->set_abstract(std::make_shared<abstract::AbstractTuple>(abstract_list));
return virtual_dataset_node;
}
static std::set<FuncGraphPtr> FindForwardGraph(const std::vector<AnfNodePtr> &all_nodes) {
std::set<FuncGraphPtr> graph_sets;
for (auto &node : all_nodes) {
MS_EXCEPTION_IF_NULL(node);
if (!node->isa<CNode>()) {
continue;
}
auto cnode = node->cast<CNodePtr>();
if ((cnode->size() < NODE_INPUT_NUM) || !IsValueNode<Primitive>(cnode->input(0))) {
continue;
}
auto expect_j_prim = GetValueNode<PrimitivePtr>(cnode->input(0));
FuncGraphPtr fun_graph = nullptr;
if (expect_j_prim->name() == mindspore::parallel::J) {
if (IsValueNode<FuncGraph>(cnode->inputs()[1])) {
fun_graph = GetValueNode<FuncGraphPtr>(cnode->inputs()[1]);
} else {
fun_graph = node->func_graph();
}
graph_sets.insert(fun_graph);
}
}
return graph_sets;
}
static void InsertVirtualDataset(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes) {
MS_EXCEPTION_IF_NULL(root);
std::set<FuncGraphPtr> forward_graph_set = FindForwardGraph(all_nodes);
for (auto forward_graph : forward_graph_set) {
FuncGraphManagerPtr manager = forward_graph->manager();
MS_EXCEPTION_IF_NULL(manager);
std::vector<AnfNodePtr> graph_inputs = forward_graph->get_inputs();
auto node_user_map = manager->node_users();
auto virtual_dataset_node = CreateVirtualDataset(forward_graph);
std::map<size_t, CNodePtr> parameter_index_map;
for (size_t index = 0; index < graph_inputs.size(); index++) {
if (HasAbstractMonad(graph_inputs[index])) {
continue;
}
auto node_users = node_user_map[graph_inputs[index]];
for (auto node_user : node_users) {
auto cnode = node_user.first->cast<CNodePtr>();
for (size_t input_index = 1; input_index < cnode->inputs().size(); input_index++) {
bool is_node_input_flag = !(IsValueNode<mindspore::tensor::Tensor>(cnode->inputs()[input_index]) ||
IsValueNode<ValueList>(cnode->inputs()[input_index]) ||
IsValueNode<ValueTuple>(cnode->inputs()[input_index]));
if (find(graph_inputs.begin(), graph_inputs.end(), cnode->inputs()[input_index]) != graph_inputs.end() &&
is_node_input_flag && !HasAbstractMonad(cnode->inputs()[input_index])) {
auto node_input_iter = find(graph_inputs.begin(), graph_inputs.end(), cnode->inputs()[input_index]);
size_t node_input_index = node_input_iter - graph_inputs.begin();
if (parameter_index_map.empty() || parameter_index_map.count(node_input_index) == 0) {
parameter_index_map[node_input_index] =
CreateTupleGetItem(virtual_dataset_node, node_input_index, forward_graph);
}
manager->SetEdge(cnode, input_index, parameter_index_map[node_input_index]);
manager->SetEdge(parameter_index_map[node_input_index], 1, virtual_dataset_node);
}
}
}
}
}
}
// Only auto_parallel and semi_auto_parallel support PipelineSplit
bool PipelineSplit(const ResourcePtr &res) {
MS_EXCEPTION_IF_NULL(res);
@ -193,19 +75,13 @@ bool PipelineSplit(const ResourcePtr &res) {
MS_LOG(INFO) << "Only auto_parallel and semi_auto_parallel support pipeline split.";
return true;
}
auto manager = res->manager();
auto root = res->func_graph();
AnfNodePtr ret = root->get_return();
MS_EXCEPTION_IF_NULL(ret);
std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);
if (!HasVirtualDataset(all_nodes)) {
InsertVirtualDataset(root, all_nodes);
}
auto stage_num = parallel::ParallelContext::GetInstance()->pipeline_stage_split_num();
if (stage_num <= 1) {
MS_LOG(INFO) << "The parameter 'stage_num' is: " << stage_num << ". No need Pipeline split.";
return true;
}
auto manager = res->manager();
auto root = res->func_graph();
auto global_rank = GetRank();
auto world_group = GetWorldGroup();
uint32_t world_rank_size = 0;

View File

@ -22,7 +22,6 @@
namespace mindspore {
namespace pipeline {
constexpr size_t NODE_INPUT_NUM = 2;
bool PipelineSplit(const ResourcePtr &res);
std::string GetWorldGroup();
} // namespace pipeline

View File

@ -17,11 +17,12 @@ from .. import nn
from .._checkparam import Validator as validator
from .._checkparam import Rel
from ..common import dtype as mstype
from ..nn.wrap.cell_wrapper import _TrainPipelineAccuStepCell
from ..nn.wrap.cell_wrapper import _VirtualDatasetCell, _TrainPipelineAccuStepCell
from ..nn.wrap.loss_scale import _TrainPipelineWithLossScaleCell
from ..ops import functional as F
from ..parallel._utils import _get_pipeline_stages
from ..parallel._utils import _get_parallel_mode, _get_pipeline_stages
from .loss_scale_manager import DynamicLossScaleManager, LossScaleManager
from ..context import ParallelMode
from .. import boost
from .. import context
@ -196,6 +197,9 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', boost_leve
if loss_fn:
network = _add_loss_network(network, loss_fn, config["cast_model_type"])
if _get_parallel_mode() in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
network = _VirtualDatasetCell(network)
loss_scale = 1.0
if config["loss_scale_manager"] is not None:
loss_scale_manager = config["loss_scale_manager"]

View File

@ -34,6 +34,7 @@ from ..parallel._utils import _get_parallel_mode, _get_device_num, _get_global_r
from ..parallel._ps_context import _is_role_pserver, _is_role_sched
from ..nn.metrics import Loss
from .. import nn
from ..nn.wrap.cell_wrapper import _VirtualDatasetCell
from ..boost import AutoBoost
from ..context import ParallelMode
from ..parallel._cost_model_context import _set_multi_subgraphs
@ -303,6 +304,8 @@ class Model:
boost_level=self._boost_level,
keep_batchnorm_fp32=self._keep_bn_fp32)
elif self._loss_fn:
if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
network = _VirtualDatasetCell(network)
network = nn.WithLossCell(network, self._loss_fn)
# If need to check if loss_fn is not None, but optimizer is None
@ -341,6 +344,8 @@ class Model:
self._eval_indexes = [0, 1, 2]
if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
if self._optimizer:
self._eval_network = _VirtualDatasetCell(self._eval_network)
if self._optimizer is None:
# In this case, multiple optimizer(s) is supposed to be included in 'self._network'
_set_multi_subgraphs()
@ -350,6 +355,7 @@ class Model:
"""Build the network for prediction."""
self._predict_network = self._network
if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
self._predict_network = _VirtualDatasetCell(self._network)
# Unlike the cases in build_train_network() and build_eval_network(), 'multi_subgraphs' is not set
self._predict_network.set_auto_parallel()

View File

@ -131,5 +131,5 @@ def test_double_subgraphs():
net.set_train()
_cell_graph_executor.compile(net, x, phase='train')
num_ops = _cell_graph_executor._get_num_parallel_ops(net)
expected_num = 9
expected_num = 7
assert expected_num == num_ops

View File

@ -66,7 +66,7 @@ def test_four_matmul_linear():
return out
size = 64
context.set_auto_parallel_context(dataset_strategy="full_batch", device_num=size, global_rank=0)
context.set_auto_parallel_context(device_num=size, global_rank=0)
strategy1 = ((2, 4), (4, 8))
x = Tensor(np.ones([128, 32]), dtype=ms.float32)
y = Tensor(np.ones([32, 64]), dtype=ms.float32)

View File

@ -261,7 +261,6 @@ def test_reshape_auto_5():
return out
size = 8
context.set_auto_parallel_context(dataset_strategy="full_batch")
x = Tensor(np.ones([4, 1024 * size, 1]), dtype=ms.float32)
y = Tensor(np.ones([4, 1024 * size,]), dtype=ms.float32)
net = GradWrapTwoInput(NetWithLossTwoInput(Net()))
@ -292,7 +291,6 @@ def test_reshape_auto_6():
return out
size = 8
context.set_auto_parallel_context(dataset_strategy="full_batch")
x = Tensor(np.ones([4, 1024, 1]), dtype=ms.float32)
y = Tensor(np.ones([4, 1024,]), dtype=ms.float32)
net = GradWrapTwoInput(NetWithLossTwoInput(Net()))

View File

@ -392,7 +392,6 @@ def test_matmul_minimum_auto_parallel():
out = self.minimum(out, b)
return out
context.set_auto_parallel_context(dataset_strategy="full_batch")
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="auto_parallel")
net = GradWrap(NetWithLoss(Net()))

View File

@ -196,7 +196,6 @@ def test_gatherv2_forward_all_reduce():
"""
strategy1 = ((8, 1), (1, 1))
strategy2 = ((2, 4, 1), (2, 4, 1))
context.set_auto_parallel_context(dataset_strategy="full_batch")
net = GradWrap(NetWithLoss(Net(0, strategy1, strategy2, shape=[2, 64])))
x = Tensor(np.ones([64, 64]), dtype=ms.float32)
y = Tensor(np.ones([2, 64, 64]), dtype=ms.float32)
@ -211,7 +210,6 @@ def test_gatherv2_shard_batch_and_axis():
"""
strategy1 = ((4, 1), (2, 1))
strategy2 = ((2, 4, 1), (2, 4, 1))
context.set_auto_parallel_context(dataset_strategy="full_batch")
net = GradWrap(NetWithLoss(Net(0, strategy1, strategy2, shape=[2, 64])))
x = Tensor(np.ones([64, 64]), dtype=ms.float32)
y = Tensor(np.ones([2, 64, 64]), dtype=ms.float32)
@ -226,7 +224,6 @@ def test_gatherv2_split_axis_0_repeat_calc():
"""
strategy1 = ((4, 1), (1, 1))
strategy2 = ((2, 4, 1), (2, 4, 1))
context.set_auto_parallel_context(dataset_strategy="full_batch")
net = GradWrap(NetWithLoss(Net(0, strategy1, strategy2, shape=[2, 64])))
x = Tensor(np.ones([64, 64]), dtype=ms.float32)
y = Tensor(np.ones([2, 64, 64]), dtype=ms.float32)

View File

@ -87,7 +87,6 @@ def test_normal_split2():
def test_normal_split3():
context.set_auto_parallel_context(dataset_strategy="full_batch")
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=32, global_rank=17)
strategy1 = ((4, 8), (1, 4))
strategy2 = ((1, 4, 8), (1, 4, 8))
@ -97,7 +96,6 @@ def test_normal_split3():
def test_normal_split_with_offset():
context.set_auto_parallel_context(dataset_strategy="full_batch")
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=2, global_rank=0)
strategy1 = ((2, 1), (1, 2))
strategy2 = ((1, 2, 1), (1, 2, 1))

View File

@ -251,7 +251,6 @@ def test_pack_auto_parallel_axis1():
def test_pack_auto_parallel_3_tensor():
context.set_auto_parallel_context(dataset_strategy="full_batch")
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
net = Net2(_w1, _w2, _w3)
compile_net2(net)

View File

@ -166,7 +166,6 @@ def test_gatherv2_semi_auto5():
def test_gatherv2_auto0():
context.set_auto_parallel_context(dataset_strategy="full_batch")
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="auto_parallel")
net = GradWrap(NetWithLoss(Net(0)))
net.set_auto_parallel()

View File

@ -124,9 +124,9 @@ def test_grad_sens_parameter_type():
net.set_auto_parallel()
net.set_train()
_cell_graph_executor.compile(net, x, y, b, sens, phase='train', auto_parallel_mode=True)
x_layout = ([64], [0, -1], [2, 32], 0, True, '')
y_layout = ([64], [-1, -1], [32, 64], 0, True, '')
b_layout = ([64], [0, -1], [1, 64], 0, True, '')
x_layout = ([8, 8], [1, -1], [16, 32], 0, True, '')
y_layout = ([8, 8], [-1, 0], [32, 8], 0, True, '')
b_layout = ([8, 8], [0, -1], [8, 64], 0, True, '')
sens_layout = ([8, 8], [1, -1], [16, 64], 0, True, '')
expect_dict = {'x': x_layout, 'y': y_layout, 'b': b_layout, 'sens': sens_layout}
assert net.parameter_layout_dict == expect_dict