forked from mindspore-Ecosystem/mindspore

auto insert VirtualDataset node for master

parent c88da99f77
commit 05189459ab

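In short: the pipeline-split pass now inserts the VirtualDataset node over the forward graph's inputs automatically (see HasVirtualDataset / InsertVirtualDataset below), the Python-side _VirtualDatasetCell wrapping in amp.py and Model is removed accordingly, the misspelled helper CreatOpInstance is renamed to CreateOpInstance throughout, and the affected parallel tests switch to an explicit dataset_strategy="full_batch".
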
@@ -50,7 +50,7 @@ std::string GetOpPythonPath(const OperatorName &op_name) {
   return functional_op_module;
 }
 
-ValuePtr CreatOpInstance(const OperatorAttrs &attrs, const OperatorName &op_name, const std::string &instance_name) {
+ValuePtr CreateOpInstance(const OperatorAttrs &attrs, const OperatorName &op_name, const std::string &instance_name) {
   std::string op_path = GetOpPythonPath(op_name);
   py::module mod = py::module::import(common::SafeCStr(op_path));
   if (!py::hasattr(mod, common::SafeCStr(op_name))) {

@@ -173,9 +173,9 @@ AnfNodePtr GenerateGraph::PushBack(const std::vector<AnfNodePtr> &inputs) {
 
 AnfNodePtr GenerateGraph::NewOpInst(const OperatorName &op_name, const OperatorAttrs &attrs) {
   name_idx_++;
-  ValuePtr pyop_instance = CreatOpInstance(attrs, op_name, instance_name_base_ + op_name + std::to_string(name_idx_));
+  ValuePtr pyop_instance = CreateOpInstance(attrs, op_name, instance_name_base_ + op_name + std::to_string(name_idx_));
   if (pyop_instance == nullptr) {
-    MS_LOG(EXCEPTION) << "Failure:" << op_name << " CreatOpInstance failed";
+    MS_LOG(EXCEPTION) << "Failure:" << op_name << " CreateOpInstance failed";
   }
   auto value_node = NewValueNode(pyop_instance);
   return value_node->cast<AnfNodePtr>();

@@ -184,9 +184,9 @@ AnfNodePtr GenerateGraph::NewOpInst(const OperatorName &op_name, const OperatorA
 AnfNodePtr GenerateGraph::NewOpInst(const OperatorName &op_name) {
   name_idx_++;
   OperatorAttrs attrs;
-  ValuePtr pyop_instance = CreatOpInstance(attrs, op_name, instance_name_base_ + std::to_string(name_idx_));
+  ValuePtr pyop_instance = CreateOpInstance(attrs, op_name, instance_name_base_ + std::to_string(name_idx_));
   if (pyop_instance == nullptr) {
-    MS_LOG(EXCEPTION) << "Failure:" << op_name << " CreatOpInstance failed";
+    MS_LOG(EXCEPTION) << "Failure:" << op_name << " CreateOpInstance failed";
   }
   auto value_node = NewValueNode(pyop_instance);
   return value_node->cast<AnfNodePtr>();

@@ -35,7 +35,7 @@ namespace parallel {
 std::string GetOpPythonPath(const OperatorName &op_name);
 
 // Init python operator Instance
-ValuePtr CreatOpInstance(const OperatorAttrs &attrs, const OperatorName &op_name, const std::string &instance_name);
+ValuePtr CreateOpInstance(const OperatorAttrs &attrs, const OperatorName &op_name, const std::string &instance_name);
 
 AnfNodePtr CreatTypeInt(int64_t value);
 AnfNodePtr CreatInt64Imm(int64_t value);

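For orientation: CreateOpInstance resolves the op's Python module via GetOpPythonPath and instantiates the primitive by name, raising if the attribute is missing. A hedged pure-Python sketch of that lookup-and-instantiate pattern; the fixed module path and keyword attrs are illustrative assumptions, not the real GetOpPythonPath logic:

import importlib

def create_op_instance(op_name, attrs=None):
    # Assumption: primitives live in mindspore.ops.operations; the real
    # GetOpPythonPath() chooses the module per op.
    mod = importlib.import_module("mindspore.ops.operations")
    if not hasattr(mod, op_name):
        # Mirrors the C++ side: MS_LOG(EXCEPTION) << "Failure:" << op_name << " CreateOpInstance failed"
        raise RuntimeError("Failure: " + op_name + " CreateOpInstance failed")
    return getattr(mod, op_name)(**(attrs or {}))

# Example, mirroring the AllReduce unit test near the end of this diff:
# allreduce = create_op_instance("AllReduce", {"op": "sum", "group": "0-1-2"})
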
@@ -167,7 +167,7 @@ void InsertVirtualAssignAdd(const std::pair<AnfNodePtr, int> &node_user, const F
   args1 = MakeValue(param_ptr->user_data<TensorLayout>()->opt_shard_group());
   args2 = MakeValue(param_ptr->param_info()->comm_fusion() + step * PIPELINE_FUSTION_OFFSET);
   OperatorAttrs attrs = {};
-  auto py_instance = CreatOpInstance(attrs, VIRTUAL_ASSIGN_ADD, VIRTUAL_ASSIGN_ADD);
+  auto py_instance = CreateOpInstance(attrs, VIRTUAL_ASSIGN_ADD, VIRTUAL_ASSIGN_ADD);
   auto value_node = NewValueNode(py_instance);
   // Set the attribute of the reduce scatter
   auto new_prim = GetValueNode<PrimitivePtr>(value_node);

@@ -187,7 +187,7 @@ void InsertVirtualAccuGrad(const AnfNodePtr &recv, const FuncGraphManagerPtr &ma
   auto cnode = recv->cast<CNodePtr>();
   MS_EXCEPTION_IF_NULL(cnode);
   OperatorAttrs attrs;
-  auto py_instance = CreatOpInstance(attrs, VIRTUAL_ACCU_GRAD, VIRTUAL_ACCU_GRAD);
+  auto py_instance = CreateOpInstance(attrs, VIRTUAL_ACCU_GRAD, VIRTUAL_ACCU_GRAD);
   auto value_node = NewValueNode(py_instance);
   std::vector<AnfNodePtr> virtual_node_input = {value_node, recv, param};
   auto graph = cnode->func_graph();

@@ -584,7 +584,7 @@ void LastStageEndNode(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphM
   MS_EXCEPTION_IF_NULL(end_cnode);
   auto end_prim = GetCNodePrimitive(end_node);
   OperatorAttrs attrs_;
-  auto op = CreatOpInstance(attrs_, end_prim->name(), "");
+  auto op = CreateOpInstance(attrs_, end_prim->name(), "");
   auto value_node = NewValueNode(op);
   auto new_prim = GetValueNode(value_node)->cast<PrimitivePtr>();
   (void)new_prim->SetAttrs(end_prim->attrs());

@@ -407,7 +407,7 @@ static void InsertFullySplitParamGradAccu(const std::pair<AnfNodePtr, int> &node
     return;
   }
   OperatorAttrs attrs;
-  auto py_instance = CreatOpInstance(attrs, "_VirtualAdd", "grad_accu");
+  auto py_instance = CreateOpInstance(attrs, "_VirtualAdd", "grad_accu");
   auto value_node = NewValueNode(py_instance);
   std::vector<AnfNodePtr> virtual_node_input = {value_node, cnode->input(IntToSize(node_user.second)), accu_parameter};
   auto graph = cnode->func_graph();

@@ -578,7 +578,7 @@ SendAttr PipelineTransformer::InsertSend(const AnfNodePtr &parameter, int64_t us
   Attr attr_group = std::make_pair(GROUP, MakeValue(group_[0]));
   Attr attr_group_back = std::make_pair(GROUP_BACK, MakeValue(group_[1]));
   OperatorAttrs attrs = {attr_tag, attr_rank, attr_group, attr_group_back};
-  auto send_op = CreatOpInstance(attrs, SEND, SEND);
+  auto send_op = CreateOpInstance(attrs, SEND, SEND);
   auto send_node = NewValueNode(send_op);
   auto prim = GetValueNode<PrimitivePtr>(send_node);
   std::pair<OperatorInfoPtr, int> op_info_pair;

@@ -614,7 +614,7 @@ SendAttr PipelineTransformer::InsertSend(const AnfNodePtr &parameter, int64_t us
   }
   send->AddPrimalAttr(MICRO, value);
   OperatorAttrs depend_attrs;
-  auto depend_op = CreatOpInstance(depend_attrs, DEPEND, DEPEND);
+  auto depend_op = CreateOpInstance(depend_attrs, DEPEND, DEPEND);
   std::vector<AnfNodePtr> depend_input = {NewValueNode(depend_op), parameter, send};
   auto depend = main_graph_->NewCNode(depend_input);
   auto abstract = parameter->abstract();

@@ -662,7 +662,7 @@ AnfNodePtr PipelineTransformer::InsertReceive(const FuncGraphPtr &graph, const A
   Attr attr_group = std::make_pair(GROUP, MakeValue(group_[0]));
   Attr attr_group_back = std::make_pair(GROUP_BACK, MakeValue(group_[1]));
   OperatorAttrs attrs = {attr_tag, attr_rank, attr_shape, attr_dtype, attr_group, attr_group_back};
-  auto recv_op = CreatOpInstance(attrs, RECEIVE, RECEIVE);
+  auto recv_op = CreateOpInstance(attrs, RECEIVE, RECEIVE);
   std::vector<AnfNodePtr> recv_input;
   if (node->isa<Parameter>()) {
     recv_input = {NewValueNode(recv_op), node};

@@ -990,7 +990,7 @@ void PipelineTransformer::CoverSensShape() {
   auto sens_cnode = sens_graph_pair.first;
   MS_EXCEPTION_IF_NULL(sens_cnode);
   OperatorAttrs attrs;
-  auto fill_op = CreatOpInstance(attrs, "Fill", "");
+  auto fill_op = CreateOpInstance(attrs, "Fill", "");
   MS_EXCEPTION_IF_NULL(type_ptr_);
   MS_EXCEPTION_IF_NULL(shape_);
   std::vector<AnfNodePtr> fill_input = {NewValueNode(fill_op), NewValueNode(type_ptr_),

@@ -102,7 +102,7 @@ void SetAllReduceRecomputeFlag(const std::vector<AnfNodePtr> &new_node_input, co
 std::vector<AnfNodePtr> CreateInput(const Operator &op, const AnfNodePtr &node, const std::string &instance_name) {
   MS_EXCEPTION_IF_NULL(node);
   OperatorArgs arg_forward = op.second;
-  ValuePtr pyop_instance = CreatOpInstance(arg_forward.first, op.first, instance_name);
+  ValuePtr pyop_instance = CreateOpInstance(arg_forward.first, op.first, instance_name);
   MS_EXCEPTION_IF_NULL(pyop_instance);
   OperatorParams params = arg_forward.second;
 
@@ -165,7 +165,7 @@ std::vector<AnfNodePtr> CreateMirrorInput(const FuncGraphPtr &root, const Operat
     }
   }
 
-  ValuePtr pyop_instance = CreatOpInstance(arg_forward.first, op_name, instance_name);
+  ValuePtr pyop_instance = CreateOpInstance(arg_forward.first, op_name, instance_name);
   MS_EXCEPTION_IF_NULL(pyop_instance);
   OperatorParams params = arg_forward.second;
 
@@ -243,9 +243,9 @@ void SetCommunicationOpGroupLabel(std::vector<AnfNodePtr> new_node_input) {
 std::vector<AnfNodePtr> ReplaceOpInput(const Operator &replace_op, const std::string &instance_name,
                                        const CNodePtr &node) {
   OperatorArgs arg_replace_op = replace_op.second;
-  ValuePtr pyop_instance = CreatOpInstance(arg_replace_op.first, replace_op.first, instance_name);
+  ValuePtr pyop_instance = CreateOpInstance(arg_replace_op.first, replace_op.first, instance_name);
   if (pyop_instance == nullptr) {
-    MS_LOG(EXCEPTION) << "Failure: " << replace_op.first << " CreatOpInstance failed";
+    MS_LOG(EXCEPTION) << "Failure: " << replace_op.first << " CreateOpInstance failed";
   }
   OperatorParams params = arg_replace_op.second;
   if (node->inputs().size() < 2) {

@@ -14,6 +14,9 @@
  * limitations under the License.
  */
 
+#include <set>
+#include <map>
+#include <vector>
 #include <string>
 #include <memory>
 #include "pipeline/jit/pipeline_split.h"

@@ -67,6 +70,121 @@ static int64_t InferStage(int64_t rank_id, int64_t stage_num, int64_t device_num
   return rank_id / per_stage_rank_num;
 }
 
+static bool HasVirtualDataset(const std::vector<AnfNodePtr> &all_nodes) {
+  for (auto &node : all_nodes) {
+    auto cnode = node->cast<CNodePtr>();
+    if (IsPrimitiveCNode(cnode, prim::kPrimVirtualDataset)) {
+      return true;
+    }
+  }
+  return false;
+}
+
+static CNodePtr CreateTupleGetItem(const AnfNodePtr &node, size_t index, const FuncGraphPtr &func_graph) {
+  MS_EXCEPTION_IF_NULL(node);
+  MS_EXCEPTION_IF_NULL(func_graph);
+  auto idx = NewValueNode(SizeToLong(index));
+  MS_EXCEPTION_IF_NULL(idx);
+  auto imm = std::make_shared<Int64Imm>(SizeToLong(index));
+  auto abstract_scalar = std::make_shared<abstract::AbstractScalar>(imm);
+  idx->set_abstract(abstract_scalar);
+  CNodePtr tuple_get_item = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), node, idx});
+  MS_EXCEPTION_IF_NULL(tuple_get_item);
+  tuple_get_item->set_scope(node->scope());
+  auto input_abstract_tuple = node->abstract()->cast<abstract::AbstractTuplePtr>();
+  MS_EXCEPTION_IF_NULL(input_abstract_tuple);
+  auto tuple_get_item_abstract = input_abstract_tuple->elements()[index];
+  MS_EXCEPTION_IF_NULL(tuple_get_item_abstract);
+  tuple_get_item->set_abstract(tuple_get_item_abstract);
+  return tuple_get_item;
+}
+
+static CNodePtr CreateVirtualDataset(const FuncGraphPtr &func_graph) {
+  mindspore::parallel::OperatorAttrs attrs;
+  ValuePtr pyop_instance = mindspore::parallel::CreateOpInstance(attrs, mindspore::parallel::VIRTUAL_DATA_SET,
+                                                                 mindspore::parallel::VIRTUAL_DATA_SET);
+  auto value_node = NewValueNode(pyop_instance);
+  std::vector<AbstractBasePtr> abstract_list;
+  std::vector<AnfNodePtr> virtual_dataset_node_inputs = {value_node};
+  for (size_t index = 0; index < func_graph->get_inputs().size(); index++) {
+    if (!HasAbstractMonad(func_graph->get_inputs()[index])) {
+      auto graph_input_index = func_graph->get_inputs()[index];
+      auto virtual_dataset_abstract = graph_input_index->abstract()->Clone();
+      MS_EXCEPTION_IF_NULL(virtual_dataset_abstract);
+      abstract_list.emplace_back(virtual_dataset_abstract);
+      virtual_dataset_node_inputs.push_back(func_graph->get_inputs()[index]);
+    }
+  }
+  CNodePtr virtual_dataset_node = func_graph->NewCNode(virtual_dataset_node_inputs);
+  MS_EXCEPTION_IF_NULL(virtual_dataset_node);
+  virtual_dataset_node->set_in_forward_flag(true);
+  virtual_dataset_node->set_abstract(std::make_shared<abstract::AbstractTuple>(abstract_list));
+  return virtual_dataset_node;
+}
+
+static std::set<FuncGraphPtr> FindForwardGraph(const std::vector<AnfNodePtr> &all_nodes) {
+  std::set<FuncGraphPtr> graph_sets;
+  for (auto &node : all_nodes) {
+    MS_EXCEPTION_IF_NULL(node);
+    if (!node->isa<CNode>()) {
+      continue;
+    }
+    auto cnode = node->cast<CNodePtr>();
+    if ((cnode->size() < NODE_INPUT_NUM) || !IsValueNode<Primitive>(cnode->input(0))) {
+      continue;
+    }
+    auto expect_j_prim = GetValueNode<PrimitivePtr>(cnode->input(0));
+    FuncGraphPtr fun_graph = nullptr;
+    if (expect_j_prim->name() == mindspore::parallel::J) {
+      if (IsValueNode<FuncGraph>(cnode->inputs()[1])) {
+        fun_graph = GetValueNode<FuncGraphPtr>(cnode->inputs()[1]);
+      } else {
+        fun_graph = node->func_graph();
+      }
+      graph_sets.insert(fun_graph);
+    }
+  }
+  return graph_sets;
+}
+
+static void InsertVirtualDataset(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes) {
+  MS_EXCEPTION_IF_NULL(root);
+  std::set<FuncGraphPtr> forward_graph_set = FindForwardGraph(all_nodes);
+  for (auto forward_graph : forward_graph_set) {
+    FuncGraphManagerPtr manager = forward_graph->manager();
+    MS_EXCEPTION_IF_NULL(manager);
+    std::vector<AnfNodePtr> graph_inputs = forward_graph->get_inputs();
+    auto node_user_map = manager->node_users();
+    auto virtual_dataset_node = CreateVirtualDataset(forward_graph);
+    std::map<size_t, CNodePtr> parameter_index_map;
+    for (size_t index = 0; index < graph_inputs.size(); index++) {
+      if (HasAbstractMonad(graph_inputs[index])) {
+        continue;
+      }
+      auto node_users = node_user_map[graph_inputs[index]];
+      for (auto node_user : node_users) {
+        auto cnode = node_user.first->cast<CNodePtr>();
+        for (size_t input_index = 1; input_index < cnode->inputs().size(); input_index++) {
+          bool is_node_input_flag = !(IsValueNode<mindspore::tensor::Tensor>(cnode->inputs()[input_index]) ||
+                                      IsValueNode<ValueList>(cnode->inputs()[input_index]) ||
+                                      IsValueNode<ValueTuple>(cnode->inputs()[input_index]));
+          if (find(graph_inputs.begin(), graph_inputs.end(), cnode->inputs()[input_index]) != graph_inputs.end() &&
+              is_node_input_flag && !HasAbstractMonad(cnode->inputs()[input_index])) {
+            auto node_input_iter = find(graph_inputs.begin(), graph_inputs.end(), cnode->inputs()[input_index]);
+            size_t node_input_index = node_input_iter - graph_inputs.begin();
+            if (parameter_index_map.empty() || parameter_index_map.count(node_input_index) == 0) {
+              parameter_index_map[node_input_index] =
+                CreateTupleGetItem(virtual_dataset_node, node_input_index, forward_graph);
+            }
+            manager->SetEdge(cnode, input_index, parameter_index_map[node_input_index]);
+            manager->SetEdge(parameter_index_map[node_input_index], 1, virtual_dataset_node);
+          }
+        }
+      }
+    }
+  }
+}
+
 // Only auto_parallel and semi_auto_parallel support PipelineSplit
 bool PipelineSplit(const ResourcePtr &res) {
   MS_EXCEPTION_IF_NULL(res);

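The core bookkeeping in InsertVirtualDataset above is the parameter_index_map memoization: every consumer of a forward-graph input gets rewired to a single shared TupleGetItem(virtual_dataset, index) node, created lazily per input position. A minimal pure-Python sketch of that pattern (plain tuples stand in for ANF nodes; no MindSpore APIs):

def insert_virtual_dataset(graph_inputs, edges):
    """edges: (consumer, slot, argument) triples taken from the node-user map."""
    virtual_dataset = ("VirtualDataset", tuple(graph_inputs))
    parameter_index_map = {}  # input position -> shared TupleGetItem node
    rewired = []
    for consumer, slot, arg in edges:
        if arg not in graph_inputs:        # constants etc. are left alone
            rewired.append((consumer, slot, arg))
            continue
        idx = graph_inputs.index(arg)
        if idx not in parameter_index_map:  # create each getitem only once
            parameter_index_map[idx] = ("TupleGetItem", virtual_dataset, idx)
        rewired.append((consumer, slot, parameter_index_map[idx]))
    return virtual_dataset, rewired

# Two users of "x" end up sharing one TupleGetItem node:
vd, new_edges = insert_virtual_dataset(
    ["x", "y"],
    [("matmul", 1, "x"), ("relu", 1, "x"), ("matmul", 2, "y")])
assert new_edges[0][2] is new_edges[1][2]
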
@@ -75,13 +193,19 @@ bool PipelineSplit(const ResourcePtr &res) {
     MS_LOG(INFO) << "Only auto_parallel and semi_auto_parallel support pipeline split.";
     return true;
   }
+  auto manager = res->manager();
+  auto root = res->func_graph();
+  AnfNodePtr ret = root->get_return();
+  MS_EXCEPTION_IF_NULL(ret);
+  std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);
+  if (!HasVirtualDataset(all_nodes)) {
+    InsertVirtualDataset(root, all_nodes);
+  }
   auto stage_num = parallel::ParallelContext::GetInstance()->pipeline_stage_split_num();
   if (stage_num <= 1) {
     MS_LOG(INFO) << "The parameter 'stage_num' is: " << stage_num << ". No need Pipeline split.";
     return true;
   }
-  auto manager = res->manager();
-  auto root = res->func_graph();
   auto global_rank = GetRank();
   auto world_group = GetWorldGroup();
   uint32_t world_rank_size = 0;

@@ -22,6 +22,7 @@
 
 namespace mindspore {
 namespace pipeline {
+constexpr size_t NODE_INPUT_NUM = 2;
 bool PipelineSplit(const ResourcePtr &res);
 std::string GetWorldGroup();
 }  // namespace pipeline

@@ -17,12 +17,11 @@ from .. import nn
 from .._checkparam import Validator as validator
 from .._checkparam import Rel
 from ..common import dtype as mstype
-from ..nn.wrap.cell_wrapper import _VirtualDatasetCell, _TrainPipelineAccuStepCell
+from ..nn.wrap.cell_wrapper import _TrainPipelineAccuStepCell
 from ..nn.wrap.loss_scale import _TrainPipelineWithLossScaleCell
 from ..ops import functional as F
-from ..parallel._utils import _get_parallel_mode, _get_pipeline_stages
+from ..parallel._utils import _get_pipeline_stages
 from .loss_scale_manager import DynamicLossScaleManager, LossScaleManager
 from ..context import ParallelMode
 from .. import boost
 from .. import context
 
@@ -197,9 +196,6 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', boost_leve
     if loss_fn:
         network = _add_loss_network(network, loss_fn, config["cast_model_type"])
 
-    if _get_parallel_mode() in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
-        network = _VirtualDatasetCell(network)
-
     loss_scale = 1.0
     if config["loss_scale_manager"] is not None:
         loss_scale_manager = config["loss_scale_manager"]

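User-visible effect of the amp.py change: build_train_network keeps its signature but no longer wraps the net in _VirtualDatasetCell. A hedged usage sketch under the new behavior; the Dense net, optimizer, device numbers, and the import path are illustrative assumptions and may vary by version:

import mindspore as ms
from mindspore import context, nn
from mindspore.train.amp import build_train_network  # import path: assumption

context.set_context(mode=context.GRAPH_MODE)
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel",
                                  device_num=8, global_rank=0,
                                  dataset_strategy="full_batch")
net = nn.Dense(32, 16)
opt = nn.Momentum(net.trainable_params(), learning_rate=0.01, momentum=0.9)
# No _VirtualDatasetCell(net) here: the new InsertVirtualDataset compiler pass
# (pipeline_split.cc above) marks the dataset inputs during compilation.
train_net = build_train_network(net, opt, loss_fn=nn.MSELoss(), level="O0")
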
@@ -34,7 +34,6 @@ from ..parallel._utils import _get_parallel_mode, _get_device_num, _get_global_r
 from ..parallel._ps_context import _is_role_pserver, _is_role_sched
 from ..nn.metrics import Loss
 from .. import nn
-from ..nn.wrap.cell_wrapper import _VirtualDatasetCell
 from ..boost import AutoBoost
 from ..context import ParallelMode
 from ..parallel._cost_model_context import _set_multi_subgraphs

@@ -304,8 +303,6 @@ class Model:
                                              boost_level=self._boost_level,
                                              keep_batchnorm_fp32=self._keep_bn_fp32)
         elif self._loss_fn:
-            if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
-                network = _VirtualDatasetCell(network)
             network = nn.WithLossCell(network, self._loss_fn)
         # If need to check if loss_fn is not None, but optimizer is None
 
@@ -344,8 +341,6 @@
                 self._eval_indexes = [0, 1, 2]
 
         if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
-            if self._optimizer:
-                self._eval_network = _VirtualDatasetCell(self._eval_network)
             if self._optimizer is None:
                 # In this case, multiple optimizer(s) is supposed to be included in 'self._network'
                 _set_multi_subgraphs()

@@ -355,7 +350,6 @@
         """Build the network for prediction."""
         self._predict_network = self._network
         if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
-            self._predict_network = _VirtualDatasetCell(self._network)
             # Unlike the cases in build_train_network() and build_eval_network(), 'multi_subgraphs' is not set
             self._predict_network.set_auto_parallel()

@@ -273,7 +273,10 @@ TEST_F(TestStepParallel, ExtractShape3) {
   ASSERT_EQ(shape_test, shape_expect);
 }
 
-TEST_F(TestStepParallel, CreatOpInstance) {
+/// Feature: test CreateOpInstance in auto parallel.
+/// Description: net with MicroBatchInterleaved in semi auto parallel.
+/// Expectation: success.
+TEST_F(TestStepParallel, CreateOpInstance) {
   ValuePtr attr0_value = MakeValue(REDUCE_OP_SUM);
   ValuePtr attr1_value = MakeValue("0-1-2");
   Attr attr0 = std::make_pair("op", attr0_value);

@@ -282,7 +285,7 @@ TEST_F(TestStepParallel, CreatOpInstance) {
   OperatorName op_name = "AllReduce";
   OperatorParams operator_param;
   OperatorArgs args = std::make_pair(attrs, operator_param);
-  auto op_instance = CreatOpInstance(args.first, op_name, "test");
+  auto op_instance = CreateOpInstance(args.first, op_name, "test");
   ASSERT_TRUE(op_instance);
   PrimitivePyPtr allreduce_ptr = dyn_cast<PrimitivePy>(op_instance);
   ASSERT_TRUE(allreduce_ptr);

@@ -332,12 +335,15 @@ TEST_F(TestStepParallel, CreatOpInstance) {
   }
 }
 
-TEST_F(TestStepParallel, CreatOpInstance1) {
+/// Feature: test CreateOpInstance in auto parallel.
+/// Description: net with MicroBatchInterleaved in semi auto parallel.
+/// Expectation: success.
+TEST_F(TestStepParallel, CreateOpInstance1) {
   OperatorAttrs attrs;
   OperatorName op_name = "ABC";
   OperatorParams operator_param;
   OperatorArgs args = std::make_pair(attrs, operator_param);
-  EXPECT_THROW({ CreatOpInstance(args.first, op_name, "test"); }, std::runtime_error);
+  EXPECT_THROW({ CreateOpInstance(args.first, op_name, "test"); }, std::runtime_error);
 }
 
 TEST_F(TestStepParallel, OperatorInstance) {

@@ -131,5 +131,5 @@ def test_double_subgraphs():
     net.set_train()
     _cell_graph_executor.compile(net, x, phase='train')
     num_ops = _cell_graph_executor._get_num_parallel_ops(net)
-    expected_num = 7
+    expected_num = 9
     assert expected_num == num_ops

@@ -66,7 +66,7 @@ def test_four_matmul_linear():
         return out
 
     size = 64
-    context.set_auto_parallel_context(device_num=size, global_rank=0)
+    context.set_auto_parallel_context(dataset_strategy="full_batch", device_num=size, global_rank=0)
     strategy1 = ((2, 4), (4, 8))
     x = Tensor(np.ones([128, 32]), dtype=ms.float32)
     y = Tensor(np.ones([32, 64]), dtype=ms.float32)

@@ -261,6 +261,7 @@ def test_reshape_auto_5():
         return out
 
     size = 8
+    context.set_auto_parallel_context(dataset_strategy="full_batch")
     x = Tensor(np.ones([4, 1024 * size, 1]), dtype=ms.float32)
     y = Tensor(np.ones([4, 1024 * size,]), dtype=ms.float32)
     net = GradWrapTwoInput(NetWithLossTwoInput(Net()))

@@ -291,6 +292,7 @@ def test_reshape_auto_6():
         return out
 
     size = 8
+    context.set_auto_parallel_context(dataset_strategy="full_batch")
     x = Tensor(np.ones([4, 1024, 1]), dtype=ms.float32)
     y = Tensor(np.ones([4, 1024,]), dtype=ms.float32)
     net = GradWrapTwoInput(NetWithLossTwoInput(Net()))

@@ -392,6 +392,7 @@ def test_matmul_minimum_auto_parallel():
             out = self.minimum(out, b)
             return out
 
+    context.set_auto_parallel_context(dataset_strategy="full_batch")
     context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="auto_parallel")
     net = GradWrap(NetWithLoss(Net()))
 
@@ -119,18 +119,3 @@ def test_embeddinglookup_semi_auto1():
     y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
     net.set_train()
     _cell_graph_executor.compile(net, x, y)
-
-
-def test_embeddinglookup_semi_auto2():
-    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
-    shape = [64, 32]
-    offset = 0
-    strategy1 = ((1, 8), (1, 1))
-    strategy2 = ((4, 1, 2), (4, 2, 1))
-    net = GradWrap(NetWithLoss(Net(shape, offset, strategy1, strategy2, "CPU")))
-
-    net.set_auto_parallel()
-    x = Tensor(np.ones([64, 64]), dtype=ms.float32)
-    y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
-    net.set_train()
-    _cell_graph_executor.compile(net, x, y)

@@ -196,6 +196,7 @@ def test_gatherv2_forward_all_reduce():
     """
     strategy1 = ((8, 1), (1, 1))
     strategy2 = ((2, 4, 1), (2, 4, 1))
+    context.set_auto_parallel_context(dataset_strategy="full_batch")
     net = GradWrap(NetWithLoss(Net(0, strategy1, strategy2, shape=[2, 64])))
     x = Tensor(np.ones([64, 64]), dtype=ms.float32)
     y = Tensor(np.ones([2, 64, 64]), dtype=ms.float32)

@@ -210,6 +211,7 @@ def test_gatherv2_shard_batch_and_axis():
     """
     strategy1 = ((4, 1), (2, 1))
     strategy2 = ((2, 4, 1), (2, 4, 1))
+    context.set_auto_parallel_context(dataset_strategy="full_batch")
     net = GradWrap(NetWithLoss(Net(0, strategy1, strategy2, shape=[2, 64])))
     x = Tensor(np.ones([64, 64]), dtype=ms.float32)
     y = Tensor(np.ones([2, 64, 64]), dtype=ms.float32)

@@ -224,6 +226,7 @@ def test_gatherv2_split_axis_0_repeat_calc():
     """
     strategy1 = ((4, 1), (1, 1))
     strategy2 = ((2, 4, 1), (2, 4, 1))
+    context.set_auto_parallel_context(dataset_strategy="full_batch")
     net = GradWrap(NetWithLoss(Net(0, strategy1, strategy2, shape=[2, 64])))
     x = Tensor(np.ones([64, 64]), dtype=ms.float32)
     y = Tensor(np.ones([2, 64, 64]), dtype=ms.float32)

@@ -73,12 +73,6 @@ def test_gathernd_dim2_default_batch_parallel():
     compile_net(net)
 
 
-def test_gathernd_auto_parallel():
-    context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=16, global_rank=0)
-    net = Net(1, _w1)
-    compile_net(net)
-
-
 def test_gathernd_repeat_calc():
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
     strategy1 = ((1, 2, 4), (1, 2, 4))

@@ -87,6 +87,7 @@ def test_normal_split2():
 
 
 def test_normal_split3():
+    context.set_auto_parallel_context(dataset_strategy="full_batch")
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=32, global_rank=17)
     strategy1 = ((4, 8), (1, 4))
     strategy2 = ((1, 4, 8), (1, 4, 8))

@@ -96,6 +97,7 @@ def test_normal_split3():
 
 
 def test_normal_split_with_offset():
+    context.set_auto_parallel_context(dataset_strategy="full_batch")
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=2, global_rank=0)
     strategy1 = ((2, 1), (1, 2))
     strategy2 = ((1, 2, 1), (1, 2, 1))

@@ -251,6 +251,7 @@ def test_pack_auto_parallel_axis1():
 
 
 def test_pack_auto_parallel_3_tensor():
+    context.set_auto_parallel_context(dataset_strategy="full_batch")
     context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
     net = Net2(_w1, _w2, _w3)
     compile_net2(net)

@@ -99,15 +99,6 @@ def test_input_same_split():
     compile_net(net)
 
 
-def test_input_different_split():
-    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
-    strategy1 = ((16, 1), (16, 1))
-    strategy2 = ((4, 4), (4, 4))
-    net = Net2(_w, strategy1, strategy2)
-    with pytest.raises(RuntimeError):
-        compile_net(net)
-
-
 def test_parameter_different_group():
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
     strategy1 = ((1, 2), (2, 1))

@@ -65,7 +65,7 @@ class Net(nn.Cell):
 
 def test_gatherv2_semi_auto0():
     context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
-    strategy1 = ((1, 8), (1, 1))
+    strategy1 = ((8, 1), (1, 1))
     strategy2 = ((4, 2, 1), (4, 2, 1))
     net = GradWrap(NetWithLoss(Net(0, strategy1, strategy2)))
     net.set_auto_parallel()

@@ -77,32 +77,11 @@ def test_gatherv2_semi_auto0():
 
 
-def test_gatherv2_semi_auto1():
-    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
-    strategy1 = ((8, 1), (1, 1))
-    strategy2 = ((4, 2, 1), (4, 2, 1))
-    net = GradWrap(NetWithLoss(Net(0, strategy1, strategy2)))
-    net.set_auto_parallel()
-
-    x = Tensor(np.ones([64, 64]), dtype=ms.float32)
-    y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
-    net.set_train()
-    _cell_graph_executor.compile(net, x, y)
-
-
-def test_gatherv2_semi_auto2():
-    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
-    strategy1 = ((2, 4), (1, 1))
-    strategy2 = ((4, 2, 1), (4, 2, 1))
-    net = GradWrap(NetWithLoss(Net(0, strategy1, strategy2)))
-    net.set_auto_parallel()
-
-    x = Tensor(np.ones([64, 64]), dtype=ms.float32)
-    y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
-    net.set_train()
-    _cell_graph_executor.compile(net, x, y)
-
-
 def test_gatherv2_semi_auto3():
+    """
+    Feature: distribute operator SparseGatherV2 in auto parallel.
+    Description: gather net with strategy in semi auto parallel, gather axis is 1.
+    Expectation: compile done without error.
+    """
     context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
     strategy1 = ((1, 8), (1, 1))
     strategy2 = ((4, 2, 1), (4, 2, 1))

@@ -115,7 +94,12 @@ def test_gatherv2_semi_auto3():
     _cell_graph_executor.compile(net, x, y)
 
 
-def test_gatherv2_semi_auto4():
+def test_gatherv2_semi_auto2():
+    """
+    Feature: distribute operator SparseGatherV2 in auto parallel.
+    Description: gather net with strategy in semi auto parallel, gather axis is 1.
+    Expectation: compile done without error.
+    """
     context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
     strategy1 = ((8, 1), (1, 1))
     strategy2 = ((4, 2, 1), (4, 2, 1))

@@ -128,7 +112,12 @@ def test_gatherv2_semi_auto4():
     _cell_graph_executor.compile(net, x, y)
 
 
-def test_gatherv2_semi_auto5():
+def test_gatherv2_semi_auto3():
+    """
+    Feature: distribute operator SparseGatherV2 in auto parallel.
+    Description: gather net with strategy in semi auto parallel, gather axis is 1.
+    Expectation: compile done without error.
+    """
     context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
     strategy1 = ((2, 4), (1, 1))
     strategy2 = ((4, 2, 1), (4, 2, 1))

@@ -141,7 +130,13 @@ def test_gatherv2_semi_auto5():
     _cell_graph_executor.compile(net, x, y)
 
 
-def test_gatherv2_semi_auto6():
+def test_gatherv2_semi_auto4():
+    """
+    Feature: distribute operator SparseGatherV2 in auto parallel.
+    Description: gather net with strategy in semi auto parallel, gather axis is 0.
+    Expectation: compile done without error.
+    """
+    context.set_auto_parallel_context(dataset_strategy="full_batch")
     context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
     strategy2 = ((4, 2, 1), (4, 2, 1))
     net = GradWrap(NetWithLoss(Net(0, None, strategy2)))

@@ -153,7 +148,12 @@ def test_gatherv2_semi_auto6():
     _cell_graph_executor.compile(net, x, y)
 
 
-def test_gatherv2_semi_auto7():
+def test_gatherv2_semi_auto5():
+    """
+    Feature: distribute operator SparseGatherV2 in auto parallel.
+    Description: gather net with strategy in semi auto parallel, gather axis is 1.
+    Expectation: compile done without error.
+    """
     context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
     strategy2 = ((4, 2, 1), (4, 2, 1))
     net = GradWrap(NetWithLoss(Net(1, None, strategy2)))

@@ -166,6 +166,7 @@ def test_gatherv2_semi_auto7():
 
 
 def test_gatherv2_auto0():
+    context.set_auto_parallel_context(dataset_strategy="full_batch")
     context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="auto_parallel")
     net = GradWrap(NetWithLoss(Net(0)))
     net.set_auto_parallel()

@@ -124,9 +124,9 @@ def test_grad_sens_parameter_type():
     net.set_auto_parallel()
     net.set_train()
     _cell_graph_executor.compile(net, x, y, b, sens, phase='train', auto_parallel_mode=True)
-    x_layout = ([8, 8], [1, -1], [16, 32], 0, True, '')
-    y_layout = ([8, 8], [-1, 0], [32, 8], 0, True, '')
-    b_layout = ([8, 8], [0, -1], [8, 64], 0, True, '')
+    x_layout = ([64], [0, -1], [2, 32], 0, True, '')
+    y_layout = ([64], [-1, -1], [32, 64], 0, True, '')
+    b_layout = ([64], [0, -1], [1, 64], 0, True, '')
     sens_layout = ([8, 8], [1, -1], [16, 64], 0, True, '')
     expect_dict = {'x': x_layout, 'y': y_layout, 'b': b_layout, 'sens': sens_layout}
     assert net.parameter_layout_dict == expect_dict