revert insert VirtualDataset node for master
This commit is contained in:
parent
b4176e73a7
commit
2edf6ab33b
|
@ -14,9 +14,6 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <set>
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include "pipeline/jit/pipeline_split.h"
|
||||
|
@ -70,121 +67,6 @@ static int64_t InferStage(int64_t rank_id, int64_t stage_num, int64_t device_num
|
|||
return rank_id / per_stage_rank_num;
|
||||
}
|
||||
|
||||
static bool HasVirtualDataset(const std::vector<AnfNodePtr> &all_nodes) {
|
||||
for (auto &node : all_nodes) {
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
if (IsPrimitiveCNode(cnode, prim::kPrimVirtualDataset)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
static CNodePtr CreateTupleGetItem(const AnfNodePtr &node, size_t index, const FuncGraphPtr &func_graph) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
auto idx = NewValueNode(SizeToLong(index));
|
||||
MS_EXCEPTION_IF_NULL(idx);
|
||||
auto imm = std::make_shared<Int64Imm>(SizeToLong(index));
|
||||
auto abstract_scalar = std::make_shared<abstract::AbstractScalar>(imm);
|
||||
idx->set_abstract(abstract_scalar);
|
||||
CNodePtr tuple_get_item = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), node, idx});
|
||||
MS_EXCEPTION_IF_NULL(tuple_get_item);
|
||||
tuple_get_item->set_scope(node->scope());
|
||||
auto input_abstract_tuple = node->abstract()->cast<abstract::AbstractTuplePtr>();
|
||||
MS_EXCEPTION_IF_NULL(input_abstract_tuple);
|
||||
auto tuple_get_item_abstract = input_abstract_tuple->elements()[index];
|
||||
MS_EXCEPTION_IF_NULL(tuple_get_item_abstract);
|
||||
tuple_get_item->set_abstract(tuple_get_item_abstract);
|
||||
return tuple_get_item;
|
||||
}
|
||||
|
||||
static CNodePtr CreateVirtualDataset(const FuncGraphPtr &func_graph) {
|
||||
mindspore::parallel::OperatorAttrs attrs;
|
||||
ValuePtr pyop_instance = mindspore::parallel::CreateOpInstance(attrs, mindspore::parallel::VIRTUAL_DATA_SET,
|
||||
mindspore::parallel::VIRTUAL_DATA_SET);
|
||||
auto value_node = NewValueNode(pyop_instance);
|
||||
std::vector<AbstractBasePtr> abstract_list;
|
||||
std::vector<AnfNodePtr> virtual_dataset_node_inputs = {value_node};
|
||||
for (size_t index = 0; index < func_graph->get_inputs().size(); index++) {
|
||||
if (!HasAbstractMonad(func_graph->get_inputs()[index])) {
|
||||
auto graph_input_index = func_graph->get_inputs()[index];
|
||||
auto virtual_dataset_abstract = graph_input_index->abstract()->Clone();
|
||||
MS_EXCEPTION_IF_NULL(virtual_dataset_abstract);
|
||||
abstract_list.emplace_back(virtual_dataset_abstract);
|
||||
virtual_dataset_node_inputs.push_back(func_graph->get_inputs()[index]);
|
||||
}
|
||||
}
|
||||
CNodePtr virtual_dataset_node = func_graph->NewCNode(virtual_dataset_node_inputs);
|
||||
MS_EXCEPTION_IF_NULL(virtual_dataset_node);
|
||||
virtual_dataset_node->set_in_forward_flag(true);
|
||||
virtual_dataset_node->set_abstract(std::make_shared<abstract::AbstractTuple>(abstract_list));
|
||||
return virtual_dataset_node;
|
||||
}
|
||||
|
||||
static std::set<FuncGraphPtr> FindForwardGraph(const std::vector<AnfNodePtr> &all_nodes) {
|
||||
std::set<FuncGraphPtr> graph_sets;
|
||||
for (auto &node : all_nodes) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (!node->isa<CNode>()) {
|
||||
continue;
|
||||
}
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
if ((cnode->size() < NODE_INPUT_NUM) || !IsValueNode<Primitive>(cnode->input(0))) {
|
||||
continue;
|
||||
}
|
||||
auto expect_j_prim = GetValueNode<PrimitivePtr>(cnode->input(0));
|
||||
FuncGraphPtr fun_graph = nullptr;
|
||||
if (expect_j_prim->name() == mindspore::parallel::J) {
|
||||
if (IsValueNode<FuncGraph>(cnode->inputs()[1])) {
|
||||
fun_graph = GetValueNode<FuncGraphPtr>(cnode->inputs()[1]);
|
||||
} else {
|
||||
fun_graph = node->func_graph();
|
||||
}
|
||||
graph_sets.insert(fun_graph);
|
||||
}
|
||||
}
|
||||
return graph_sets;
|
||||
}
|
||||
|
||||
static void InsertVirtualDataset(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes) {
|
||||
MS_EXCEPTION_IF_NULL(root);
|
||||
std::set<FuncGraphPtr> forward_graph_set = FindForwardGraph(all_nodes);
|
||||
for (auto forward_graph : forward_graph_set) {
|
||||
FuncGraphManagerPtr manager = forward_graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
std::vector<AnfNodePtr> graph_inputs = forward_graph->get_inputs();
|
||||
auto node_user_map = manager->node_users();
|
||||
auto virtual_dataset_node = CreateVirtualDataset(forward_graph);
|
||||
std::map<size_t, CNodePtr> parameter_index_map;
|
||||
for (size_t index = 0; index < graph_inputs.size(); index++) {
|
||||
if (HasAbstractMonad(graph_inputs[index])) {
|
||||
continue;
|
||||
}
|
||||
auto node_users = node_user_map[graph_inputs[index]];
|
||||
for (auto node_user : node_users) {
|
||||
auto cnode = node_user.first->cast<CNodePtr>();
|
||||
for (size_t input_index = 1; input_index < cnode->inputs().size(); input_index++) {
|
||||
bool is_node_input_flag = !(IsValueNode<mindspore::tensor::Tensor>(cnode->inputs()[input_index]) ||
|
||||
IsValueNode<ValueList>(cnode->inputs()[input_index]) ||
|
||||
IsValueNode<ValueTuple>(cnode->inputs()[input_index]));
|
||||
if (find(graph_inputs.begin(), graph_inputs.end(), cnode->inputs()[input_index]) != graph_inputs.end() &&
|
||||
is_node_input_flag && !HasAbstractMonad(cnode->inputs()[input_index])) {
|
||||
auto node_input_iter = find(graph_inputs.begin(), graph_inputs.end(), cnode->inputs()[input_index]);
|
||||
size_t node_input_index = node_input_iter - graph_inputs.begin();
|
||||
if (parameter_index_map.empty() || parameter_index_map.count(node_input_index) == 0) {
|
||||
parameter_index_map[node_input_index] =
|
||||
CreateTupleGetItem(virtual_dataset_node, node_input_index, forward_graph);
|
||||
}
|
||||
manager->SetEdge(cnode, input_index, parameter_index_map[node_input_index]);
|
||||
manager->SetEdge(parameter_index_map[node_input_index], 1, virtual_dataset_node);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Only auto_parallel and semi_auto_parallel support PipelineSplit
|
||||
bool PipelineSplit(const ResourcePtr &res) {
|
||||
MS_EXCEPTION_IF_NULL(res);
|
||||
|
@ -193,19 +75,13 @@ bool PipelineSplit(const ResourcePtr &res) {
|
|||
MS_LOG(INFO) << "Only auto_parallel and semi_auto_parallel support pipeline split.";
|
||||
return true;
|
||||
}
|
||||
auto manager = res->manager();
|
||||
auto root = res->func_graph();
|
||||
AnfNodePtr ret = root->get_return();
|
||||
MS_EXCEPTION_IF_NULL(ret);
|
||||
std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);
|
||||
if (!HasVirtualDataset(all_nodes)) {
|
||||
InsertVirtualDataset(root, all_nodes);
|
||||
}
|
||||
auto stage_num = parallel::ParallelContext::GetInstance()->pipeline_stage_split_num();
|
||||
if (stage_num <= 1) {
|
||||
MS_LOG(INFO) << "The parameter 'stage_num' is: " << stage_num << ". No need Pipeline split.";
|
||||
return true;
|
||||
}
|
||||
auto manager = res->manager();
|
||||
auto root = res->func_graph();
|
||||
auto global_rank = GetRank();
|
||||
auto world_group = GetWorldGroup();
|
||||
uint32_t world_rank_size = 0;
|
||||
|
|
|
@ -22,7 +22,6 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace pipeline {
|
||||
constexpr size_t NODE_INPUT_NUM = 2;
|
||||
bool PipelineSplit(const ResourcePtr &res);
|
||||
std::string GetWorldGroup();
|
||||
} // namespace pipeline
|
||||
|
|
|
@ -17,11 +17,12 @@ from .. import nn
|
|||
from .._checkparam import Validator as validator
|
||||
from .._checkparam import Rel
|
||||
from ..common import dtype as mstype
|
||||
from ..nn.wrap.cell_wrapper import _TrainPipelineAccuStepCell
|
||||
from ..nn.wrap.cell_wrapper import _VirtualDatasetCell, _TrainPipelineAccuStepCell
|
||||
from ..nn.wrap.loss_scale import _TrainPipelineWithLossScaleCell
|
||||
from ..ops import functional as F
|
||||
from ..parallel._utils import _get_pipeline_stages
|
||||
from ..parallel._utils import _get_parallel_mode, _get_pipeline_stages
|
||||
from .loss_scale_manager import DynamicLossScaleManager, LossScaleManager
|
||||
from ..context import ParallelMode
|
||||
from .. import boost
|
||||
from .. import context
|
||||
|
||||
|
@ -196,6 +197,9 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', boost_leve
|
|||
if loss_fn:
|
||||
network = _add_loss_network(network, loss_fn, config["cast_model_type"])
|
||||
|
||||
if _get_parallel_mode() in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
|
||||
network = _VirtualDatasetCell(network)
|
||||
|
||||
loss_scale = 1.0
|
||||
if config["loss_scale_manager"] is not None:
|
||||
loss_scale_manager = config["loss_scale_manager"]
|
||||
|
|
|
@ -34,6 +34,7 @@ from ..parallel._utils import _get_parallel_mode, _get_device_num, _get_global_r
|
|||
from ..parallel._ps_context import _is_role_pserver, _is_role_sched
|
||||
from ..nn.metrics import Loss
|
||||
from .. import nn
|
||||
from ..nn.wrap.cell_wrapper import _VirtualDatasetCell
|
||||
from ..boost import AutoBoost
|
||||
from ..context import ParallelMode
|
||||
from ..parallel._cost_model_context import _set_multi_subgraphs
|
||||
|
@ -303,6 +304,8 @@ class Model:
|
|||
boost_level=self._boost_level,
|
||||
keep_batchnorm_fp32=self._keep_bn_fp32)
|
||||
elif self._loss_fn:
|
||||
if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
|
||||
network = _VirtualDatasetCell(network)
|
||||
network = nn.WithLossCell(network, self._loss_fn)
|
||||
# If need to check if loss_fn is not None, but optimizer is None
|
||||
|
||||
|
@ -341,6 +344,8 @@ class Model:
|
|||
self._eval_indexes = [0, 1, 2]
|
||||
|
||||
if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
|
||||
if self._optimizer:
|
||||
self._eval_network = _VirtualDatasetCell(self._eval_network)
|
||||
if self._optimizer is None:
|
||||
# In this case, multiple optimizer(s) is supposed to be included in 'self._network'
|
||||
_set_multi_subgraphs()
|
||||
|
@ -350,6 +355,7 @@ class Model:
|
|||
"""Build the network for prediction."""
|
||||
self._predict_network = self._network
|
||||
if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
|
||||
self._predict_network = _VirtualDatasetCell(self._network)
|
||||
# Unlike the cases in build_train_network() and build_eval_network(), 'multi_subgraphs' is not set
|
||||
self._predict_network.set_auto_parallel()
|
||||
|
||||
|
|
|
@ -131,5 +131,5 @@ def test_double_subgraphs():
|
|||
net.set_train()
|
||||
_cell_graph_executor.compile(net, x, phase='train')
|
||||
num_ops = _cell_graph_executor._get_num_parallel_ops(net)
|
||||
expected_num = 9
|
||||
expected_num = 7
|
||||
assert expected_num == num_ops
|
||||
|
|
|
@ -66,7 +66,7 @@ def test_four_matmul_linear():
|
|||
return out
|
||||
|
||||
size = 64
|
||||
context.set_auto_parallel_context(dataset_strategy="full_batch", device_num=size, global_rank=0)
|
||||
context.set_auto_parallel_context(device_num=size, global_rank=0)
|
||||
strategy1 = ((2, 4), (4, 8))
|
||||
x = Tensor(np.ones([128, 32]), dtype=ms.float32)
|
||||
y = Tensor(np.ones([32, 64]), dtype=ms.float32)
|
||||
|
|
|
@ -261,7 +261,6 @@ def test_reshape_auto_5():
|
|||
return out
|
||||
|
||||
size = 8
|
||||
context.set_auto_parallel_context(dataset_strategy="full_batch")
|
||||
x = Tensor(np.ones([4, 1024 * size, 1]), dtype=ms.float32)
|
||||
y = Tensor(np.ones([4, 1024 * size,]), dtype=ms.float32)
|
||||
net = GradWrapTwoInput(NetWithLossTwoInput(Net()))
|
||||
|
@ -292,7 +291,6 @@ def test_reshape_auto_6():
|
|||
return out
|
||||
|
||||
size = 8
|
||||
context.set_auto_parallel_context(dataset_strategy="full_batch")
|
||||
x = Tensor(np.ones([4, 1024, 1]), dtype=ms.float32)
|
||||
y = Tensor(np.ones([4, 1024,]), dtype=ms.float32)
|
||||
net = GradWrapTwoInput(NetWithLossTwoInput(Net()))
|
||||
|
|
|
@ -392,7 +392,6 @@ def test_matmul_minimum_auto_parallel():
|
|||
out = self.minimum(out, b)
|
||||
return out
|
||||
|
||||
context.set_auto_parallel_context(dataset_strategy="full_batch")
|
||||
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="auto_parallel")
|
||||
net = GradWrap(NetWithLoss(Net()))
|
||||
|
||||
|
|
|
@ -196,7 +196,6 @@ def test_gatherv2_forward_all_reduce():
|
|||
"""
|
||||
strategy1 = ((8, 1), (1, 1))
|
||||
strategy2 = ((2, 4, 1), (2, 4, 1))
|
||||
context.set_auto_parallel_context(dataset_strategy="full_batch")
|
||||
net = GradWrap(NetWithLoss(Net(0, strategy1, strategy2, shape=[2, 64])))
|
||||
x = Tensor(np.ones([64, 64]), dtype=ms.float32)
|
||||
y = Tensor(np.ones([2, 64, 64]), dtype=ms.float32)
|
||||
|
@ -211,7 +210,6 @@ def test_gatherv2_shard_batch_and_axis():
|
|||
"""
|
||||
strategy1 = ((4, 1), (2, 1))
|
||||
strategy2 = ((2, 4, 1), (2, 4, 1))
|
||||
context.set_auto_parallel_context(dataset_strategy="full_batch")
|
||||
net = GradWrap(NetWithLoss(Net(0, strategy1, strategy2, shape=[2, 64])))
|
||||
x = Tensor(np.ones([64, 64]), dtype=ms.float32)
|
||||
y = Tensor(np.ones([2, 64, 64]), dtype=ms.float32)
|
||||
|
@ -226,7 +224,6 @@ def test_gatherv2_split_axis_0_repeat_calc():
|
|||
"""
|
||||
strategy1 = ((4, 1), (1, 1))
|
||||
strategy2 = ((2, 4, 1), (2, 4, 1))
|
||||
context.set_auto_parallel_context(dataset_strategy="full_batch")
|
||||
net = GradWrap(NetWithLoss(Net(0, strategy1, strategy2, shape=[2, 64])))
|
||||
x = Tensor(np.ones([64, 64]), dtype=ms.float32)
|
||||
y = Tensor(np.ones([2, 64, 64]), dtype=ms.float32)
|
||||
|
|
|
@ -87,7 +87,6 @@ def test_normal_split2():
|
|||
|
||||
|
||||
def test_normal_split3():
|
||||
context.set_auto_parallel_context(dataset_strategy="full_batch")
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=32, global_rank=17)
|
||||
strategy1 = ((4, 8), (1, 4))
|
||||
strategy2 = ((1, 4, 8), (1, 4, 8))
|
||||
|
@ -97,7 +96,6 @@ def test_normal_split3():
|
|||
|
||||
|
||||
def test_normal_split_with_offset():
|
||||
context.set_auto_parallel_context(dataset_strategy="full_batch")
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=2, global_rank=0)
|
||||
strategy1 = ((2, 1), (1, 2))
|
||||
strategy2 = ((1, 2, 1), (1, 2, 1))
|
||||
|
|
|
@ -251,7 +251,6 @@ def test_pack_auto_parallel_axis1():
|
|||
|
||||
|
||||
def test_pack_auto_parallel_3_tensor():
|
||||
context.set_auto_parallel_context(dataset_strategy="full_batch")
|
||||
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
|
||||
net = Net2(_w1, _w2, _w3)
|
||||
compile_net2(net)
|
||||
|
|
|
@ -166,7 +166,6 @@ def test_gatherv2_semi_auto5():
|
|||
|
||||
|
||||
def test_gatherv2_auto0():
|
||||
context.set_auto_parallel_context(dataset_strategy="full_batch")
|
||||
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="auto_parallel")
|
||||
net = GradWrap(NetWithLoss(Net(0)))
|
||||
net.set_auto_parallel()
|
||||
|
|
|
@ -124,9 +124,9 @@ def test_grad_sens_parameter_type():
|
|||
net.set_auto_parallel()
|
||||
net.set_train()
|
||||
_cell_graph_executor.compile(net, x, y, b, sens, phase='train', auto_parallel_mode=True)
|
||||
x_layout = ([64], [0, -1], [2, 32], 0, True, '')
|
||||
y_layout = ([64], [-1, -1], [32, 64], 0, True, '')
|
||||
b_layout = ([64], [0, -1], [1, 64], 0, True, '')
|
||||
x_layout = ([8, 8], [1, -1], [16, 32], 0, True, '')
|
||||
y_layout = ([8, 8], [-1, 0], [32, 8], 0, True, '')
|
||||
b_layout = ([8, 8], [0, -1], [8, 64], 0, True, '')
|
||||
sens_layout = ([8, 8], [1, -1], [16, 64], 0, True, '')
|
||||
expect_dict = {'x': x_layout, 'y': y_layout, 'b': b_layout, 'sens': sens_layout}
|
||||
assert net.parameter_layout_dict == expect_dict
|
||||
|
|
Loading…
Reference in New Issue