auto insert VirtualDataset node for master

This commit is contained in:
lilei 2021-12-04 17:41:58 +08:00
parent c88da99f77
commit 05189459ab
24 changed files with 203 additions and 102 deletions

View File

@ -50,7 +50,7 @@ std::string GetOpPythonPath(const OperatorName &op_name) {
return functional_op_module;
}
ValuePtr CreatOpInstance(const OperatorAttrs &attrs, const OperatorName &op_name, const std::string &instance_name) {
ValuePtr CreateOpInstance(const OperatorAttrs &attrs, const OperatorName &op_name, const std::string &instance_name) {
std::string op_path = GetOpPythonPath(op_name);
py::module mod = py::module::import(common::SafeCStr(op_path));
if (!py::hasattr(mod, common::SafeCStr(op_name))) {
@ -173,9 +173,9 @@ AnfNodePtr GenerateGraph::PushBack(const std::vector<AnfNodePtr> &inputs) {
AnfNodePtr GenerateGraph::NewOpInst(const OperatorName &op_name, const OperatorAttrs &attrs) {
name_idx_++;
ValuePtr pyop_instance = CreatOpInstance(attrs, op_name, instance_name_base_ + op_name + std::to_string(name_idx_));
ValuePtr pyop_instance = CreateOpInstance(attrs, op_name, instance_name_base_ + op_name + std::to_string(name_idx_));
if (pyop_instance == nullptr) {
MS_LOG(EXCEPTION) << "Failure:" << op_name << " CreatOpInstance failed";
MS_LOG(EXCEPTION) << "Failure:" << op_name << " CreateOpInstance failed";
}
auto value_node = NewValueNode(pyop_instance);
return value_node->cast<AnfNodePtr>();
@ -184,9 +184,9 @@ AnfNodePtr GenerateGraph::NewOpInst(const OperatorName &op_name, const OperatorA
AnfNodePtr GenerateGraph::NewOpInst(const OperatorName &op_name) {
name_idx_++;
OperatorAttrs attrs;
ValuePtr pyop_instance = CreatOpInstance(attrs, op_name, instance_name_base_ + std::to_string(name_idx_));
ValuePtr pyop_instance = CreateOpInstance(attrs, op_name, instance_name_base_ + std::to_string(name_idx_));
if (pyop_instance == nullptr) {
MS_LOG(EXCEPTION) << "Failure:" << op_name << " CreatOpInstance failed";
MS_LOG(EXCEPTION) << "Failure:" << op_name << " CreateOpInstance failed";
}
auto value_node = NewValueNode(pyop_instance);
return value_node->cast<AnfNodePtr>();

View File

@ -35,7 +35,7 @@ namespace parallel {
std::string GetOpPythonPath(const OperatorName &op_name);
// Init python operator Instance
ValuePtr CreatOpInstance(const OperatorAttrs &attrs, const OperatorName &op_name, const std::string &instance_name);
ValuePtr CreateOpInstance(const OperatorAttrs &attrs, const OperatorName &op_name, const std::string &instance_name);
AnfNodePtr CreatTypeInt(int64_t value);
AnfNodePtr CreatInt64Imm(int64_t value);

View File

@ -167,7 +167,7 @@ void InsertVirtualAssignAdd(const std::pair<AnfNodePtr, int> &node_user, const F
args1 = MakeValue(param_ptr->user_data<TensorLayout>()->opt_shard_group());
args2 = MakeValue(param_ptr->param_info()->comm_fusion() + step * PIPELINE_FUSTION_OFFSET);
OperatorAttrs attrs = {};
auto py_instance = CreatOpInstance(attrs, VIRTUAL_ASSIGN_ADD, VIRTUAL_ASSIGN_ADD);
auto py_instance = CreateOpInstance(attrs, VIRTUAL_ASSIGN_ADD, VIRTUAL_ASSIGN_ADD);
auto value_node = NewValueNode(py_instance);
// Set the attribute of the reduce scatter
auto new_prim = GetValueNode<PrimitivePtr>(value_node);
@ -187,7 +187,7 @@ void InsertVirtualAccuGrad(const AnfNodePtr &recv, const FuncGraphManagerPtr &ma
auto cnode = recv->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
OperatorAttrs attrs;
auto py_instance = CreatOpInstance(attrs, VIRTUAL_ACCU_GRAD, VIRTUAL_ACCU_GRAD);
auto py_instance = CreateOpInstance(attrs, VIRTUAL_ACCU_GRAD, VIRTUAL_ACCU_GRAD);
auto value_node = NewValueNode(py_instance);
std::vector<AnfNodePtr> virtual_node_input = {value_node, recv, param};
auto graph = cnode->func_graph();
@ -584,7 +584,7 @@ void LastStageEndNode(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphM
MS_EXCEPTION_IF_NULL(end_cnode);
auto end_prim = GetCNodePrimitive(end_node);
OperatorAttrs attrs_;
auto op = CreatOpInstance(attrs_, end_prim->name(), "");
auto op = CreateOpInstance(attrs_, end_prim->name(), "");
auto value_node = NewValueNode(op);
auto new_prim = GetValueNode(value_node)->cast<PrimitivePtr>();
(void)new_prim->SetAttrs(end_prim->attrs());

View File

@ -407,7 +407,7 @@ static void InsertFullySplitParamGradAccu(const std::pair<AnfNodePtr, int> &node
return;
}
OperatorAttrs attrs;
auto py_instance = CreatOpInstance(attrs, "_VirtualAdd", "grad_accu");
auto py_instance = CreateOpInstance(attrs, "_VirtualAdd", "grad_accu");
auto value_node = NewValueNode(py_instance);
std::vector<AnfNodePtr> virtual_node_input = {value_node, cnode->input(IntToSize(node_user.second)), accu_parameter};
auto graph = cnode->func_graph();

View File

@ -578,7 +578,7 @@ SendAttr PipelineTransformer::InsertSend(const AnfNodePtr &parameter, int64_t us
Attr attr_group = std::make_pair(GROUP, MakeValue(group_[0]));
Attr attr_group_back = std::make_pair(GROUP_BACK, MakeValue(group_[1]));
OperatorAttrs attrs = {attr_tag, attr_rank, attr_group, attr_group_back};
auto send_op = CreatOpInstance(attrs, SEND, SEND);
auto send_op = CreateOpInstance(attrs, SEND, SEND);
auto send_node = NewValueNode(send_op);
auto prim = GetValueNode<PrimitivePtr>(send_node);
std::pair<OperatorInfoPtr, int> op_info_pair;
@ -614,7 +614,7 @@ SendAttr PipelineTransformer::InsertSend(const AnfNodePtr &parameter, int64_t us
}
send->AddPrimalAttr(MICRO, value);
OperatorAttrs depend_attrs;
auto depend_op = CreatOpInstance(depend_attrs, DEPEND, DEPEND);
auto depend_op = CreateOpInstance(depend_attrs, DEPEND, DEPEND);
std::vector<AnfNodePtr> depend_input = {NewValueNode(depend_op), parameter, send};
auto depend = main_graph_->NewCNode(depend_input);
auto abstract = parameter->abstract();
@ -662,7 +662,7 @@ AnfNodePtr PipelineTransformer::InsertReceive(const FuncGraphPtr &graph, const A
Attr attr_group = std::make_pair(GROUP, MakeValue(group_[0]));
Attr attr_group_back = std::make_pair(GROUP_BACK, MakeValue(group_[1]));
OperatorAttrs attrs = {attr_tag, attr_rank, attr_shape, attr_dtype, attr_group, attr_group_back};
auto recv_op = CreatOpInstance(attrs, RECEIVE, RECEIVE);
auto recv_op = CreateOpInstance(attrs, RECEIVE, RECEIVE);
std::vector<AnfNodePtr> recv_input;
if (node->isa<Parameter>()) {
recv_input = {NewValueNode(recv_op), node};
@ -990,7 +990,7 @@ void PipelineTransformer::CoverSensShape() {
auto sens_cnode = sens_graph_pair.first;
MS_EXCEPTION_IF_NULL(sens_cnode);
OperatorAttrs attrs;
auto fill_op = CreatOpInstance(attrs, "Fill", "");
auto fill_op = CreateOpInstance(attrs, "Fill", "");
MS_EXCEPTION_IF_NULL(type_ptr_);
MS_EXCEPTION_IF_NULL(shape_);
std::vector<AnfNodePtr> fill_input = {NewValueNode(fill_op), NewValueNode(type_ptr_),

View File

@ -102,7 +102,7 @@ void SetAllReduceRecomputeFlag(const std::vector<AnfNodePtr> &new_node_input, co
std::vector<AnfNodePtr> CreateInput(const Operator &op, const AnfNodePtr &node, const std::string &instance_name) {
MS_EXCEPTION_IF_NULL(node);
OperatorArgs arg_forward = op.second;
ValuePtr pyop_instance = CreatOpInstance(arg_forward.first, op.first, instance_name);
ValuePtr pyop_instance = CreateOpInstance(arg_forward.first, op.first, instance_name);
MS_EXCEPTION_IF_NULL(pyop_instance);
OperatorParams params = arg_forward.second;
@ -165,7 +165,7 @@ std::vector<AnfNodePtr> CreateMirrorInput(const FuncGraphPtr &root, const Operat
}
}
ValuePtr pyop_instance = CreatOpInstance(arg_forward.first, op_name, instance_name);
ValuePtr pyop_instance = CreateOpInstance(arg_forward.first, op_name, instance_name);
MS_EXCEPTION_IF_NULL(pyop_instance);
OperatorParams params = arg_forward.second;

View File

@ -243,9 +243,9 @@ void SetCommunicationOpGroupLabel(std::vector<AnfNodePtr> new_node_input) {
std::vector<AnfNodePtr> ReplaceOpInput(const Operator &replace_op, const std::string &instance_name,
const CNodePtr &node) {
OperatorArgs arg_replace_op = replace_op.second;
ValuePtr pyop_instance = CreatOpInstance(arg_replace_op.first, replace_op.first, instance_name);
ValuePtr pyop_instance = CreateOpInstance(arg_replace_op.first, replace_op.first, instance_name);
if (pyop_instance == nullptr) {
MS_LOG(EXCEPTION) << "Failure: " << replace_op.first << " CreatOpInstance failed";
MS_LOG(EXCEPTION) << "Failure: " << replace_op.first << " CreateOpInstance failed";
}
OperatorParams params = arg_replace_op.second;
if (node->inputs().size() < 2) {

View File

@ -14,6 +14,9 @@
* limitations under the License.
*/
#include <set>
#include <map>
#include <vector>
#include <string>
#include <memory>
#include "pipeline/jit/pipeline_split.h"
@ -67,6 +70,121 @@ static int64_t InferStage(int64_t rank_id, int64_t stage_num, int64_t device_num
return rank_id / per_stage_rank_num;
}
static bool HasVirtualDataset(const std::vector<AnfNodePtr> &all_nodes) {
for (auto &node : all_nodes) {
auto cnode = node->cast<CNodePtr>();
if (IsPrimitiveCNode(cnode, prim::kPrimVirtualDataset)) {
return true;
}
}
return false;
}
static CNodePtr CreateTupleGetItem(const AnfNodePtr &node, size_t index, const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(func_graph);
auto idx = NewValueNode(SizeToLong(index));
MS_EXCEPTION_IF_NULL(idx);
auto imm = std::make_shared<Int64Imm>(SizeToLong(index));
auto abstract_scalar = std::make_shared<abstract::AbstractScalar>(imm);
idx->set_abstract(abstract_scalar);
CNodePtr tuple_get_item = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), node, idx});
MS_EXCEPTION_IF_NULL(tuple_get_item);
tuple_get_item->set_scope(node->scope());
auto input_abstract_tuple = node->abstract()->cast<abstract::AbstractTuplePtr>();
MS_EXCEPTION_IF_NULL(input_abstract_tuple);
auto tuple_get_item_abstract = input_abstract_tuple->elements()[index];
MS_EXCEPTION_IF_NULL(tuple_get_item_abstract);
tuple_get_item->set_abstract(tuple_get_item_abstract);
return tuple_get_item;
}
static CNodePtr CreateVirtualDataset(const FuncGraphPtr &func_graph) {
mindspore::parallel::OperatorAttrs attrs;
ValuePtr pyop_instance = mindspore::parallel::CreateOpInstance(attrs, mindspore::parallel::VIRTUAL_DATA_SET,
mindspore::parallel::VIRTUAL_DATA_SET);
auto value_node = NewValueNode(pyop_instance);
std::vector<AbstractBasePtr> abstract_list;
std::vector<AnfNodePtr> virtual_dataset_node_inputs = {value_node};
for (size_t index = 0; index < func_graph->get_inputs().size(); index++) {
if (!HasAbstractMonad(func_graph->get_inputs()[index])) {
auto graph_input_index = func_graph->get_inputs()[index];
auto virtual_dataset_abstract = graph_input_index->abstract()->Clone();
MS_EXCEPTION_IF_NULL(virtual_dataset_abstract);
abstract_list.emplace_back(virtual_dataset_abstract);
virtual_dataset_node_inputs.push_back(func_graph->get_inputs()[index]);
}
}
CNodePtr virtual_dataset_node = func_graph->NewCNode(virtual_dataset_node_inputs);
MS_EXCEPTION_IF_NULL(virtual_dataset_node);
virtual_dataset_node->set_in_forward_flag(true);
virtual_dataset_node->set_abstract(std::make_shared<abstract::AbstractTuple>(abstract_list));
return virtual_dataset_node;
}
static std::set<FuncGraphPtr> FindForwardGraph(const std::vector<AnfNodePtr> &all_nodes) {
std::set<FuncGraphPtr> graph_sets;
for (auto &node : all_nodes) {
MS_EXCEPTION_IF_NULL(node);
if (!node->isa<CNode>()) {
continue;
}
auto cnode = node->cast<CNodePtr>();
if ((cnode->size() < NODE_INPUT_NUM) || !IsValueNode<Primitive>(cnode->input(0))) {
continue;
}
auto expect_j_prim = GetValueNode<PrimitivePtr>(cnode->input(0));
FuncGraphPtr fun_graph = nullptr;
if (expect_j_prim->name() == mindspore::parallel::J) {
if (IsValueNode<FuncGraph>(cnode->inputs()[1])) {
fun_graph = GetValueNode<FuncGraphPtr>(cnode->inputs()[1]);
} else {
fun_graph = node->func_graph();
}
graph_sets.insert(fun_graph);
}
}
return graph_sets;
}
static void InsertVirtualDataset(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes) {
MS_EXCEPTION_IF_NULL(root);
std::set<FuncGraphPtr> forward_graph_set = FindForwardGraph(all_nodes);
for (auto forward_graph : forward_graph_set) {
FuncGraphManagerPtr manager = forward_graph->manager();
MS_EXCEPTION_IF_NULL(manager);
std::vector<AnfNodePtr> graph_inputs = forward_graph->get_inputs();
auto node_user_map = manager->node_users();
auto virtual_dataset_node = CreateVirtualDataset(forward_graph);
std::map<size_t, CNodePtr> parameter_index_map;
for (size_t index = 0; index < graph_inputs.size(); index++) {
if (HasAbstractMonad(graph_inputs[index])) {
continue;
}
auto node_users = node_user_map[graph_inputs[index]];
for (auto node_user : node_users) {
auto cnode = node_user.first->cast<CNodePtr>();
for (size_t input_index = 1; input_index < cnode->inputs().size(); input_index++) {
bool is_node_input_flag = !(IsValueNode<mindspore::tensor::Tensor>(cnode->inputs()[input_index]) ||
IsValueNode<ValueList>(cnode->inputs()[input_index]) ||
IsValueNode<ValueTuple>(cnode->inputs()[input_index]));
if (find(graph_inputs.begin(), graph_inputs.end(), cnode->inputs()[input_index]) != graph_inputs.end() &&
is_node_input_flag && !HasAbstractMonad(cnode->inputs()[input_index])) {
auto node_input_iter = find(graph_inputs.begin(), graph_inputs.end(), cnode->inputs()[input_index]);
size_t node_input_index = node_input_iter - graph_inputs.begin();
if (parameter_index_map.empty() || parameter_index_map.count(node_input_index) == 0) {
parameter_index_map[node_input_index] =
CreateTupleGetItem(virtual_dataset_node, node_input_index, forward_graph);
}
manager->SetEdge(cnode, input_index, parameter_index_map[node_input_index]);
manager->SetEdge(parameter_index_map[node_input_index], 1, virtual_dataset_node);
}
}
}
}
}
}
// Only auto_parallel and semi_auto_parallel support PipelineSplit
bool PipelineSplit(const ResourcePtr &res) {
MS_EXCEPTION_IF_NULL(res);
@ -75,13 +193,19 @@ bool PipelineSplit(const ResourcePtr &res) {
MS_LOG(INFO) << "Only auto_parallel and semi_auto_parallel support pipeline split.";
return true;
}
auto manager = res->manager();
auto root = res->func_graph();
AnfNodePtr ret = root->get_return();
MS_EXCEPTION_IF_NULL(ret);
std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);
if (!HasVirtualDataset(all_nodes)) {
InsertVirtualDataset(root, all_nodes);
}
auto stage_num = parallel::ParallelContext::GetInstance()->pipeline_stage_split_num();
if (stage_num <= 1) {
MS_LOG(INFO) << "The parameter 'stage_num' is: " << stage_num << ". No need Pipeline split.";
return true;
}
auto manager = res->manager();
auto root = res->func_graph();
auto global_rank = GetRank();
auto world_group = GetWorldGroup();
uint32_t world_rank_size = 0;

View File

@ -22,6 +22,7 @@
namespace mindspore {
namespace pipeline {
constexpr size_t NODE_INPUT_NUM = 2;
bool PipelineSplit(const ResourcePtr &res);
std::string GetWorldGroup();
} // namespace pipeline

View File

@ -17,12 +17,11 @@ from .. import nn
from .._checkparam import Validator as validator
from .._checkparam import Rel
from ..common import dtype as mstype
from ..nn.wrap.cell_wrapper import _VirtualDatasetCell, _TrainPipelineAccuStepCell
from ..nn.wrap.cell_wrapper import _TrainPipelineAccuStepCell
from ..nn.wrap.loss_scale import _TrainPipelineWithLossScaleCell
from ..ops import functional as F
from ..parallel._utils import _get_parallel_mode, _get_pipeline_stages
from ..parallel._utils import _get_pipeline_stages
from .loss_scale_manager import DynamicLossScaleManager, LossScaleManager
from ..context import ParallelMode
from .. import boost
from .. import context
@ -197,9 +196,6 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', boost_leve
if loss_fn:
network = _add_loss_network(network, loss_fn, config["cast_model_type"])
if _get_parallel_mode() in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
network = _VirtualDatasetCell(network)
loss_scale = 1.0
if config["loss_scale_manager"] is not None:
loss_scale_manager = config["loss_scale_manager"]

View File

@ -34,7 +34,6 @@ from ..parallel._utils import _get_parallel_mode, _get_device_num, _get_global_r
from ..parallel._ps_context import _is_role_pserver, _is_role_sched
from ..nn.metrics import Loss
from .. import nn
from ..nn.wrap.cell_wrapper import _VirtualDatasetCell
from ..boost import AutoBoost
from ..context import ParallelMode
from ..parallel._cost_model_context import _set_multi_subgraphs
@ -304,8 +303,6 @@ class Model:
boost_level=self._boost_level,
keep_batchnorm_fp32=self._keep_bn_fp32)
elif self._loss_fn:
if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
network = _VirtualDatasetCell(network)
network = nn.WithLossCell(network, self._loss_fn)
# If need to check if loss_fn is not None, but optimizer is None
@ -344,8 +341,6 @@ class Model:
self._eval_indexes = [0, 1, 2]
if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
if self._optimizer:
self._eval_network = _VirtualDatasetCell(self._eval_network)
if self._optimizer is None:
# In this case, multiple optimizer(s) is supposed to be included in 'self._network'
_set_multi_subgraphs()
@ -355,7 +350,6 @@ class Model:
"""Build the network for prediction."""
self._predict_network = self._network
if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
self._predict_network = _VirtualDatasetCell(self._network)
# Unlike the cases in build_train_network() and build_eval_network(), 'multi_subgraphs' is not set
self._predict_network.set_auto_parallel()

View File

@ -273,7 +273,10 @@ TEST_F(TestStepParallel, ExtractShape3) {
ASSERT_EQ(shape_test, shape_expect);
}
TEST_F(TestStepParallel, CreatOpInstance) {
/// Feature: test CreateOpInstance in auto parallel.
/// Description: net with MicroBatchInterleaved in semi auto parallel.
/// Expectation: success.
TEST_F(TestStepParallel, CreateOpInstance) {
ValuePtr attr0_value = MakeValue(REDUCE_OP_SUM);
ValuePtr attr1_value = MakeValue("0-1-2");
Attr attr0 = std::make_pair("op", attr0_value);
@ -282,7 +285,7 @@ TEST_F(TestStepParallel, CreatOpInstance) {
OperatorName op_name = "AllReduce";
OperatorParams operator_param;
OperatorArgs args = std::make_pair(attrs, operator_param);
auto op_instance = CreatOpInstance(args.first, op_name, "test");
auto op_instance = CreateOpInstance(args.first, op_name, "test");
ASSERT_TRUE(op_instance);
PrimitivePyPtr allreduce_ptr = dyn_cast<PrimitivePy>(op_instance);
ASSERT_TRUE(allreduce_ptr);
@ -332,12 +335,15 @@ TEST_F(TestStepParallel, CreatOpInstance) {
}
}
TEST_F(TestStepParallel, CreatOpInstance1) {
/// Feature: test CreateOpInstance in auto parallel.
/// Description: net with MicroBatchInterleaved in semi auto parallel.
/// Expectation: success.
TEST_F(TestStepParallel, CreateOpInstance1) {
OperatorAttrs attrs;
OperatorName op_name = "ABC";
OperatorParams operator_param;
OperatorArgs args = std::make_pair(attrs, operator_param);
EXPECT_THROW({ CreatOpInstance(args.first, op_name, "test"); }, std::runtime_error);
EXPECT_THROW({ CreateOpInstance(args.first, op_name, "test"); }, std::runtime_error);
}
TEST_F(TestStepParallel, OperatorInstance) {

View File

@ -131,5 +131,5 @@ def test_double_subgraphs():
net.set_train()
_cell_graph_executor.compile(net, x, phase='train')
num_ops = _cell_graph_executor._get_num_parallel_ops(net)
expected_num = 7
expected_num = 9
assert expected_num == num_ops

View File

@ -66,7 +66,7 @@ def test_four_matmul_linear():
return out
size = 64
context.set_auto_parallel_context(device_num=size, global_rank=0)
context.set_auto_parallel_context(dataset_strategy="full_batch", device_num=size, global_rank=0)
strategy1 = ((2, 4), (4, 8))
x = Tensor(np.ones([128, 32]), dtype=ms.float32)
y = Tensor(np.ones([32, 64]), dtype=ms.float32)

View File

@ -261,6 +261,7 @@ def test_reshape_auto_5():
return out
size = 8
context.set_auto_parallel_context(dataset_strategy="full_batch")
x = Tensor(np.ones([4, 1024 * size, 1]), dtype=ms.float32)
y = Tensor(np.ones([4, 1024 * size,]), dtype=ms.float32)
net = GradWrapTwoInput(NetWithLossTwoInput(Net()))
@ -291,6 +292,7 @@ def test_reshape_auto_6():
return out
size = 8
context.set_auto_parallel_context(dataset_strategy="full_batch")
x = Tensor(np.ones([4, 1024, 1]), dtype=ms.float32)
y = Tensor(np.ones([4, 1024,]), dtype=ms.float32)
net = GradWrapTwoInput(NetWithLossTwoInput(Net()))

View File

@ -392,6 +392,7 @@ def test_matmul_minimum_auto_parallel():
out = self.minimum(out, b)
return out
context.set_auto_parallel_context(dataset_strategy="full_batch")
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="auto_parallel")
net = GradWrap(NetWithLoss(Net()))

View File

@ -119,18 +119,3 @@ def test_embeddinglookup_semi_auto1():
y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
net.set_train()
_cell_graph_executor.compile(net, x, y)
def test_embeddinglookup_semi_auto2():
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
shape = [64, 32]
offset = 0
strategy1 = ((1, 8), (1, 1))
strategy2 = ((4, 1, 2), (4, 2, 1))
net = GradWrap(NetWithLoss(Net(shape, offset, strategy1, strategy2, "CPU")))
net.set_auto_parallel()
x = Tensor(np.ones([64, 64]), dtype=ms.float32)
y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
net.set_train()
_cell_graph_executor.compile(net, x, y)

View File

@ -196,6 +196,7 @@ def test_gatherv2_forward_all_reduce():
"""
strategy1 = ((8, 1), (1, 1))
strategy2 = ((2, 4, 1), (2, 4, 1))
context.set_auto_parallel_context(dataset_strategy="full_batch")
net = GradWrap(NetWithLoss(Net(0, strategy1, strategy2, shape=[2, 64])))
x = Tensor(np.ones([64, 64]), dtype=ms.float32)
y = Tensor(np.ones([2, 64, 64]), dtype=ms.float32)
@ -210,6 +211,7 @@ def test_gatherv2_shard_batch_and_axis():
"""
strategy1 = ((4, 1), (2, 1))
strategy2 = ((2, 4, 1), (2, 4, 1))
context.set_auto_parallel_context(dataset_strategy="full_batch")
net = GradWrap(NetWithLoss(Net(0, strategy1, strategy2, shape=[2, 64])))
x = Tensor(np.ones([64, 64]), dtype=ms.float32)
y = Tensor(np.ones([2, 64, 64]), dtype=ms.float32)
@ -224,6 +226,7 @@ def test_gatherv2_split_axis_0_repeat_calc():
"""
strategy1 = ((4, 1), (1, 1))
strategy2 = ((2, 4, 1), (2, 4, 1))
context.set_auto_parallel_context(dataset_strategy="full_batch")
net = GradWrap(NetWithLoss(Net(0, strategy1, strategy2, shape=[2, 64])))
x = Tensor(np.ones([64, 64]), dtype=ms.float32)
y = Tensor(np.ones([2, 64, 64]), dtype=ms.float32)

View File

@ -73,12 +73,6 @@ def test_gathernd_dim2_default_batch_parallel():
compile_net(net)
def test_gathernd_auto_parallel():
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=16, global_rank=0)
net = Net(1, _w1)
compile_net(net)
def test_gathernd_repeat_calc():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
strategy1 = ((1, 2, 4), (1, 2, 4))

View File

@ -87,6 +87,7 @@ def test_normal_split2():
def test_normal_split3():
context.set_auto_parallel_context(dataset_strategy="full_batch")
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=32, global_rank=17)
strategy1 = ((4, 8), (1, 4))
strategy2 = ((1, 4, 8), (1, 4, 8))
@ -96,6 +97,7 @@ def test_normal_split3():
def test_normal_split_with_offset():
context.set_auto_parallel_context(dataset_strategy="full_batch")
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=2, global_rank=0)
strategy1 = ((2, 1), (1, 2))
strategy2 = ((1, 2, 1), (1, 2, 1))

View File

@ -251,6 +251,7 @@ def test_pack_auto_parallel_axis1():
def test_pack_auto_parallel_3_tensor():
context.set_auto_parallel_context(dataset_strategy="full_batch")
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
net = Net2(_w1, _w2, _w3)
compile_net2(net)

View File

@ -99,15 +99,6 @@ def test_input_same_split():
compile_net(net)
def test_input_different_split():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
strategy1 = ((16, 1), (16, 1))
strategy2 = ((4, 4), (4, 4))
net = Net2(_w, strategy1, strategy2)
with pytest.raises(RuntimeError):
compile_net(net)
def test_parameter_different_group():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
strategy1 = ((1, 2), (2, 1))

View File

@ -65,7 +65,7 @@ class Net(nn.Cell):
def test_gatherv2_semi_auto0():
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
strategy1 = ((1, 8), (1, 1))
strategy1 = ((8, 1), (1, 1))
strategy2 = ((4, 2, 1), (4, 2, 1))
net = GradWrap(NetWithLoss(Net(0, strategy1, strategy2)))
net.set_auto_parallel()
@ -77,32 +77,11 @@ def test_gatherv2_semi_auto0():
def test_gatherv2_semi_auto1():
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
strategy1 = ((8, 1), (1, 1))
strategy2 = ((4, 2, 1), (4, 2, 1))
net = GradWrap(NetWithLoss(Net(0, strategy1, strategy2)))
net.set_auto_parallel()
x = Tensor(np.ones([64, 64]), dtype=ms.float32)
y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
net.set_train()
_cell_graph_executor.compile(net, x, y)
def test_gatherv2_semi_auto2():
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
strategy1 = ((2, 4), (1, 1))
strategy2 = ((4, 2, 1), (4, 2, 1))
net = GradWrap(NetWithLoss(Net(0, strategy1, strategy2)))
net.set_auto_parallel()
x = Tensor(np.ones([64, 64]), dtype=ms.float32)
y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
net.set_train()
_cell_graph_executor.compile(net, x, y)
def test_gatherv2_semi_auto3():
"""
Feature: distribute operator SparseGatherV2 in auto parallel.
Description: gather net with strategy in semi auto parallel, gather axis is 1.
Expectation: compile done without error.
"""
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
strategy1 = ((1, 8), (1, 1))
strategy2 = ((4, 2, 1), (4, 2, 1))
@ -115,7 +94,12 @@ def test_gatherv2_semi_auto3():
_cell_graph_executor.compile(net, x, y)
def test_gatherv2_semi_auto4():
def test_gatherv2_semi_auto2():
"""
Feature: distribute operator SparseGatherV2 in auto parallel.
Description: gather net with strategy in semi auto parallel, gather axis is 1.
Expectation: compile done without error.
"""
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
strategy1 = ((8, 1), (1, 1))
strategy2 = ((4, 2, 1), (4, 2, 1))
@ -128,7 +112,12 @@ def test_gatherv2_semi_auto4():
_cell_graph_executor.compile(net, x, y)
def test_gatherv2_semi_auto5():
def test_gatherv2_semi_auto3():
"""
Feature: distribute operator SparseGatherV2 in auto parallel.
Description: gather net with strategy in semi auto parallel, gather axis is 1.
Expectation: compile done without error.
"""
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
strategy1 = ((2, 4), (1, 1))
strategy2 = ((4, 2, 1), (4, 2, 1))
@ -141,7 +130,13 @@ def test_gatherv2_semi_auto5():
_cell_graph_executor.compile(net, x, y)
def test_gatherv2_semi_auto6():
def test_gatherv2_semi_auto4():
"""
Feature: distribute operator SparseGatherV2 in auto parallel.
Description: gather net with strategy in semi auto parallel, gather axis is 0.
Expectation: compile done without error.
"""
context.set_auto_parallel_context(dataset_strategy="full_batch")
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
strategy2 = ((4, 2, 1), (4, 2, 1))
net = GradWrap(NetWithLoss(Net(0, None, strategy2)))
@ -153,7 +148,12 @@ def test_gatherv2_semi_auto6():
_cell_graph_executor.compile(net, x, y)
def test_gatherv2_semi_auto7():
def test_gatherv2_semi_auto5():
"""
Feature: distribute operator SparseGatherV2 in auto parallel.
Description: gather net with strategy in semi auto parallel, gather axis is 1.
Expectation: compile done without error.
"""
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
strategy2 = ((4, 2, 1), (4, 2, 1))
net = GradWrap(NetWithLoss(Net(1, None, strategy2)))
@ -166,6 +166,7 @@ def test_gatherv2_semi_auto7():
def test_gatherv2_auto0():
context.set_auto_parallel_context(dataset_strategy="full_batch")
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="auto_parallel")
net = GradWrap(NetWithLoss(Net(0)))
net.set_auto_parallel()

View File

@ -124,9 +124,9 @@ def test_grad_sens_parameter_type():
net.set_auto_parallel()
net.set_train()
_cell_graph_executor.compile(net, x, y, b, sens, phase='train', auto_parallel_mode=True)
x_layout = ([8, 8], [1, -1], [16, 32], 0, True, '')
y_layout = ([8, 8], [-1, 0], [32, 8], 0, True, '')
b_layout = ([8, 8], [0, -1], [8, 64], 0, True, '')
x_layout = ([64], [0, -1], [2, 32], 0, True, '')
y_layout = ([64], [-1, -1], [32, 64], 0, True, '')
b_layout = ([64], [0, -1], [1, 64], 0, True, '')
sens_layout = ([8, 8], [1, -1], [16, 64], 0, True, '')
expect_dict = {'x': x_layout, 'y': y_layout, 'b': b_layout, 'sens': sens_layout}
assert net.parameter_layout_dict == expect_dict