forked from mindspore-Ecosystem/mindspore
!698 [Auto parallel] Support multi-subgraphs in auto-parallel
Merge pull request !698 from Xiaoda/support-wide-deep-in-auto-parallel
This commit is contained in:
commit
ef71ae941f
|
@ -44,6 +44,7 @@ namespace parallel {
|
|||
#define DEFAULT_TENSOR_SLICE_ALIGNMENT_SIZE 16
|
||||
#define DEFAULT_FULLY_USE_DEVICES true
|
||||
#define DEFAULT_ELEMENTWISE_OP_STRA_FOLLOW false
|
||||
#define DEFAULT_IS_MULTI_SUBGRAPHS false
|
||||
|
||||
class CostGraph;
|
||||
using CostGraphPtr = std::shared_ptr<CostGraph>;
|
||||
|
|
|
@ -46,6 +46,7 @@ void CostModelContext::ResetCostModel() {
|
|||
costmodel_communi_threshold_ = DEFAULT_COST_MODEL_COMMUNI_THRESHOLD;
|
||||
costmodel_communi_const_ = DEFAULT_COST_MODEL_COMMUNI_CONST;
|
||||
costmodel_communi_bias_ = DEFAULT_COST_MODEL_COMMUNI_BIAS;
|
||||
is_multi_subgraphs_ = DEFAULT_IS_MULTI_SUBGRAPHS;
|
||||
costmodel_allreduce_fusion_algorithm_ = DEFAULT_COST_MODEL_ALLREDUCE_FUSION_ALGORITHM;
|
||||
costmodel_allreduce_fusion_times_ = DEFAULT_COST_MODEL_ALLREDUCE_FUSION_TIMES;
|
||||
costmodel_allreduce_fusion_tail_percent_ = DEFAULT_COST_MODEL_ALLREDUCE_FUSION_TAIL_PERCENT;
|
||||
|
@ -84,6 +85,7 @@ void CostModelContext::set_costmodel_communi_const(double cm_communi_const) {
|
|||
|
||||
void CostModelContext::set_costmodel_communi_bias(double cm_communi_bias) { costmodel_communi_bias_ = cm_communi_bias; }
|
||||
|
||||
void CostModelContext::set_multi_subgraphs(bool multi_graphs) { is_multi_subgraphs_ = multi_graphs; }
|
||||
void CostModelContext::set_costmodel_allreduce_fusion_algorithm(int32_t algorithm) {
|
||||
costmodel_allreduce_fusion_algorithm_ = algorithm;
|
||||
}
|
||||
|
|
|
@ -67,6 +67,9 @@ class CostModelContext {
|
|||
void set_costmodel_communi_bias(double);
|
||||
double costmodel_communi_bias() const { return costmodel_communi_bias_; }
|
||||
|
||||
void set_multi_subgraphs(bool);
|
||||
bool is_multi_subgraphs() const { return is_multi_subgraphs_; }
|
||||
|
||||
void set_costmodel_allreduce_fusion_algorithm(int32_t);
|
||||
int32_t costmodel_allreduce_fusion_algorithm() const { return costmodel_allreduce_fusion_algorithm_; }
|
||||
|
||||
|
@ -138,6 +141,8 @@ class CostModelContext {
|
|||
// COST_MODEL_COMMUNI_BIAS
|
||||
double costmodel_communi_bias_;
|
||||
|
||||
bool is_multi_subgraphs_;
|
||||
|
||||
int32_t costmodel_allreduce_fusion_algorithm_;
|
||||
|
||||
int32_t costmodel_allreduce_fusion_times_;
|
||||
|
|
|
@ -426,13 +426,13 @@ OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &
|
|||
return operator_info;
|
||||
}
|
||||
|
||||
Status ConstructCostGraphNodes(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &) {
|
||||
// Using CNode's UniqueIds to construct nodes
|
||||
Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &) {
|
||||
MS_LOG(INFO) << "Constructing nodes for cost graph begins.";
|
||||
entire_costgraph = std::make_shared<CostGraph>();
|
||||
entire_costgraph->SetDeviceMemoryAndCostParameter();
|
||||
bool new_operator = true, first_operator = true;
|
||||
std::string first_operator_cnode;
|
||||
size_t current_op_index = 0;
|
||||
// The map from CNode's UniqueId to its operatorInfo
|
||||
std::map<std::string, OperatorInfoPtr> from_cnode_to_info;
|
||||
|
||||
// Step 1
|
||||
for (auto &node : all_nodes) {
|
||||
|
@ -449,12 +449,8 @@ Status ConstructCostGraphNodes(const std::vector<AnfNodePtr> &all_nodes, const F
|
|||
PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
|
||||
// When visiting the second subgraph, use the corresponding operatorInfo which already created
|
||||
bool modify_new_operator = (new_operator) && (!first_operator) && (cnode->UniqueId() == first_operator_cnode);
|
||||
if (modify_new_operator) {
|
||||
new_operator = false;
|
||||
}
|
||||
if (new_operator) {
|
||||
auto search_cnode = from_cnode_to_info.find(cnode->UniqueId());
|
||||
if (search_cnode == from_cnode_to_info.end()) {
|
||||
auto operator_info = CreateTheOperatorInfo(prim, cnode);
|
||||
if (operator_info == nullptr) {
|
||||
return FAILED;
|
||||
|
@ -465,14 +461,67 @@ Status ConstructCostGraphNodes(const std::vector<AnfNodePtr> &all_nodes, const F
|
|||
|
||||
entire_costgraph->AddOperator(operator_info);
|
||||
(void)cnode->set_operator_info(operator_info);
|
||||
if (first_operator) {
|
||||
first_operator_cnode = cnode->UniqueId();
|
||||
first_operator = false;
|
||||
}
|
||||
MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId()
|
||||
<< " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy()
|
||||
<< " is set OperatorInfo: " << operator_info->name() << ", Primitive: " << prim->name();
|
||||
(void)from_cnode_to_info.emplace(std::make_pair(cnode->UniqueIdThroughCopy(), operator_info));
|
||||
// Needed by rec_parser
|
||||
entire_costgraph->add_inputs_tensor_name(inputs_tensor_name);
|
||||
} else {
|
||||
auto current_op_ptr = entire_costgraph->FindOperatorByIndex(current_op_index);
|
||||
// Two CNODEs' UniqueIds should not be equal
|
||||
MS_LOG(EXCEPTION) << "The CNode with UniqueId: " << cnode->UniqueId()
|
||||
<< " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy()
|
||||
<< " is set OperatorInfo: " << search_cnode->second->name() << ", Primitive: " << prim->name();
|
||||
}
|
||||
}
|
||||
|
||||
MS_LOG(INFO) << "Constructing nodes for cost graph ends.";
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
// Using CNode's UniqueIdThroughCopys to construct nodes
|
||||
Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &) {
|
||||
MS_LOG(INFO) << "Constructing nodes for cost graph begins.";
|
||||
entire_costgraph = std::make_shared<CostGraph>();
|
||||
entire_costgraph->SetDeviceMemoryAndCostParameter();
|
||||
// The map from CNode's UniqueIdThroughCopy to its operatorInfo
|
||||
std::map<std::string, OperatorInfoPtr> from_cnode_to_info;
|
||||
|
||||
for (auto &node : all_nodes) {
|
||||
// NOTE: we only care about splittable Primitive operators
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
bool bool_result = (cnode == nullptr) || (!IsValueNode<Primitive>(cnode->input(0)));
|
||||
if (bool_result) {
|
||||
continue;
|
||||
}
|
||||
ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
|
||||
if (!IsAutoParallelCareNode(cnode)) {
|
||||
continue;
|
||||
}
|
||||
PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
|
||||
|
||||
// Find the operatorInfo if it exists
|
||||
auto search_cnode = from_cnode_to_info.find(cnode->UniqueIdThroughCopy());
|
||||
if (search_cnode == from_cnode_to_info.end()) {
|
||||
// In this case, the corresponding OperatorInfo is not created, create the new one.
|
||||
auto operator_info = CreateTheOperatorInfo(prim, cnode);
|
||||
if (operator_info == nullptr) {
|
||||
return FAILED;
|
||||
}
|
||||
// Needed by rec_parser
|
||||
operator_info->set_type(prim->name());
|
||||
std::vector<std::string> inputs_tensor_name = ExtractInputsTensorName(cnode);
|
||||
|
||||
entire_costgraph->AddOperator(operator_info);
|
||||
(void)cnode->set_operator_info(operator_info);
|
||||
MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId()
|
||||
<< " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy()
|
||||
<< " is set OperatorInfo: " << operator_info->name() << ", Primitive: " << prim->name();
|
||||
(void)from_cnode_to_info.emplace(std::make_pair(cnode->UniqueIdThroughCopy(), operator_info));
|
||||
// Needed by rec_parser
|
||||
entire_costgraph->add_inputs_tensor_name(inputs_tensor_name);
|
||||
} else {
|
||||
auto current_op_ptr = search_cnode->second;
|
||||
if (current_op_ptr == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Find " << prim->name() << " from CostGraph failed.";
|
||||
} else {
|
||||
|
@ -484,14 +533,12 @@ Status ConstructCostGraphNodes(const std::vector<AnfNodePtr> &all_nodes, const F
|
|||
<< " does not match the Prim: " << prim->name();
|
||||
}
|
||||
(void)cnode->set_operator_info(current_op_ptr);
|
||||
current_op_index++;
|
||||
MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId()
|
||||
<< " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy()
|
||||
<< " is set OperatorInfo: " << current_op_ptr->name() << ", Primitive: " << prim->name();
|
||||
}
|
||||
}
|
||||
}
|
||||
if ((!new_operator) && (current_op_index != entire_costgraph->GetOperators().size())) {
|
||||
MS_LOG(EXCEPTION) << "The second subgraph's operator number: " << current_op_index
|
||||
<< " does not match the first ones: " << entire_costgraph->GetOperators().size();
|
||||
}
|
||||
|
||||
MS_LOG(INFO) << "Constructing nodes for cost graph ends.";
|
||||
return SUCCESS;
|
||||
|
@ -844,11 +891,20 @@ Status ParallelStrategySearch(const std::vector<AnfNodePtr> &all_nodes, const Fu
|
|||
// OUTPUT: the determined strategy for each operator.
|
||||
|
||||
// Step 1
|
||||
if (ConstructCostGraphNodes(all_nodes, root) == SUCCESS) {
|
||||
MS_LOG(INFO) << "Constructing nodes for cost graph succeeded. There are " << entire_costgraph->GetOperators().size()
|
||||
<< " operators.";
|
||||
if (CostModelContext::GetInstance()->is_multi_subgraphs()) {
|
||||
if (ConstructCostGraphNodesByUniqueIdTC(all_nodes, root) == SUCCESS) {
|
||||
MS_LOG(INFO) << "Constructing nodes for cost graph succeeded. There are "
|
||||
<< entire_costgraph->GetOperators().size() << " operators.";
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Constructing nodes for cost graph failed.";
|
||||
}
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Constructing nodes for cost graph failed.";
|
||||
if (ConstructCostGraphNodesByUniqueId(all_nodes, root) == SUCCESS) {
|
||||
MS_LOG(INFO) << "Constructing nodes for cost graph succeeded. There are "
|
||||
<< entire_costgraph->GetOperators().size() << " operators.";
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Constructing nodes for cost graph failed.";
|
||||
}
|
||||
}
|
||||
|
||||
// Step 2
|
||||
|
@ -916,7 +972,7 @@ std::vector<std::vector<std::string>> RecInputTensorNames(const std::map<std::st
|
|||
}
|
||||
|
||||
Status ParallelStrategyRecSearch(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root) {
|
||||
if (ConstructCostGraphNodes(all_nodes, root) == SUCCESS) {
|
||||
if (ConstructCostGraphNodesByUniqueId(all_nodes, root) == SUCCESS) {
|
||||
MS_LOG(INFO) << "Constructing nodes for cost graph succeeded. There are " << entire_costgraph->GetOperators().size()
|
||||
<< " operators.";
|
||||
} else {
|
||||
|
|
|
@ -43,7 +43,9 @@ std::vector<size_t> ExtractInputTypeLengthByNode(const CNodePtr &node);
|
|||
|
||||
std::vector<TypePtr> ExtractOutputTypeByNode(const CNodePtr &node);
|
||||
|
||||
Status ConstructCostGraphNodes(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root);
|
||||
Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root);
|
||||
|
||||
Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root);
|
||||
|
||||
void ConstructCostGraphEdges(const std::vector<AnfNodePtr> &all_nodes);
|
||||
|
||||
|
|
|
@ -24,6 +24,7 @@
|
|||
#include <functional>
|
||||
|
||||
#include "ir/func_graph_cloner.h"
|
||||
#include "parallel/costmodel_context.h"
|
||||
#include "pipeline/pass.h"
|
||||
#include "pipeline/parse/parse_base.h"
|
||||
#include "pipeline/parse/data_converter.h"
|
||||
|
@ -341,7 +342,10 @@ static std::vector<ActionItem> CommonPipeline() {
|
|||
|
||||
// Resolve the python func
|
||||
actions.emplace_back(std::make_pair("symbol_resolve", SymbolResolveAction));
|
||||
actions.emplace_back(std::make_pair("combine_like_graphs", CombineLikeGraphs));
|
||||
auto multi_graphs = parallel::CostModelContext::GetInstance()->is_multi_subgraphs();
|
||||
if (!multi_graphs) {
|
||||
actions.emplace_back(std::make_pair("combine_like_graphs", CombineLikeGraphs));
|
||||
}
|
||||
actions.emplace_back(std::make_pair("inference_opt_prepare", InferenceOptPrepareAction));
|
||||
// Evaluate type and shape, and specialize
|
||||
actions.emplace_back(std::make_pair("abstract_specialize", AbstractSpecializeAction));
|
||||
|
|
|
@ -222,6 +222,8 @@ PYBIND11_MODULE(_c_expression, m) {
|
|||
"Set the parameter cost_model_communi_bias of the DP algorithm.")
|
||||
.def("get_costmodel_communi_bias", &CostModelContext::costmodel_communi_bias,
|
||||
"Get the parameter cost_model_communi_bias of the DP algorithm.")
|
||||
.def("set_multi_subgraphs", &CostModelContext::set_multi_subgraphs, "Set the parameter is_multi_subgraphs.")
|
||||
.def("get_multi_subgraphs", &CostModelContext::is_multi_subgraphs, "Get the parameter is_multi_subgraphs.")
|
||||
.def("set_costmodel_allreduce_fusion_algorithm", &CostModelContext::set_costmodel_allreduce_fusion_algorithm,
|
||||
"Set the parameter gradient AllReduce fusion algorithm.")
|
||||
.def("get_costmodel_allreduce_fusion_algorithm", &CostModelContext::costmodel_allreduce_fusion_algorithm,
|
||||
|
|
|
@ -214,6 +214,31 @@ class _CostModelContext:
|
|||
raise ValueError("Context handle is none in context!!!")
|
||||
return self._context_handle.get_costmodel_communi_bias()
|
||||
|
||||
def set_multi_subgraphs(self, multi_subgraph):
|
||||
"""
|
||||
Set the flag of ANF graph containing multiple subgraphs.
|
||||
|
||||
Args:
|
||||
multi_subgraph (bool): A parameter used in marking the multi-subgraphs flag.
|
||||
|
||||
Raises:
|
||||
ValueError: If context handle is none.
|
||||
"""
|
||||
if self._context_handle is None:
|
||||
raise ValueError("Context handle is none in context!!!")
|
||||
self._context_handle.set_multi_subgraphs(multi_subgraph)
|
||||
|
||||
def get_multi_subgraphs(self):
|
||||
"""
|
||||
Get the flag of ANF graph containing multiple subgraphs.
|
||||
|
||||
Raises:
|
||||
ValueError: If context handle is none.
|
||||
"""
|
||||
if self._context_handle is None:
|
||||
raise ValueError("Context handle is none in context!!!")
|
||||
return self._context_handle.get_multi_subgraphs()
|
||||
|
||||
def set_costmodel_allreduce_fusion_algorithm(self, algorithm):
|
||||
"""
|
||||
Set costmodel allreduce fusion algorithm.
|
||||
|
@ -427,6 +452,7 @@ set_cost_model_context_func_map = {
|
|||
"costmodel_communi_threshold": cost_model_context().set_costmodel_communi_threshold,
|
||||
"costmodel_communi_const": cost_model_context().set_costmodel_communi_const,
|
||||
"costmodel_communi_bias": cost_model_context().set_costmodel_communi_bias,
|
||||
"multi_subgraphs": cost_model_context().set_multi_subgraphs,
|
||||
"costmodel_allreduce_fusion_algorithm": cost_model_context().set_costmodel_allreduce_fusion_algorithm,
|
||||
"costmodel_allreduce_fusion_times": cost_model_context().set_costmodel_allreduce_fusion_times,
|
||||
"costmodel_allreduce_fusion_tail_percent": cost_model_context().set_costmodel_allreduce_fusion_tail_percent,
|
||||
|
@ -447,6 +473,7 @@ get_cost_model_context_func_map = {
|
|||
"costmodel_communi_threshold": cost_model_context().get_costmodel_communi_threshold,
|
||||
"costmodel_communi_const": cost_model_context().get_costmodel_communi_const,
|
||||
"costmodel_communi_bias": cost_model_context().get_costmodel_communi_bias,
|
||||
"multi_subgraphs": cost_model_context().get_multi_subgraphs(),
|
||||
"costmodel_allreduce_fusion_algorithm": cost_model_context().get_costmodel_allreduce_fusion_algorithm,
|
||||
"costmodel_allreduce_fusion_times": cost_model_context().get_costmodel_allreduce_fusion_times,
|
||||
"costmodel_allreduce_fusion_tail_percent": cost_model_context().get_costmodel_allreduce_fusion_tail_percent,
|
||||
|
@ -461,6 +488,7 @@ get_cost_model_context_func_map = {
|
|||
|
||||
@args_type_check(device_memory_capacity=float, costmodel_alpha=float, costmodel_beta=float, costmodel_gamma=float,
|
||||
costmodel_communi_threshold=float, costmodel_communi_const=float, costmodel_communi_bias=float,
|
||||
multi_subgraphs=bool,
|
||||
costmodel_allreduce_fusion_algorithm=int, costmodel_allreduce_fusion_times=int,
|
||||
costmodel_allreduce_fusion_tail_percent=float, costmodel_allreduce_fusion_tail_time=float,
|
||||
costmodel_allreduce_fusion_allreduce_inherent_time=float,
|
||||
|
@ -481,6 +509,7 @@ def set_cost_model_context(**kwargs):
|
|||
costmodel_communi_threshold (float): A parameter used in adjusting communication calculation for practice.
|
||||
costmodel_communi_const (float): A parameter used in adjusting communication calculation for practice.
|
||||
costmodel_communi_bias (float): A parameter used in adjusting communication calculation for practice.
|
||||
multi_subgraphs (bool): A parameter used in marking the flag of ANF graph containing multiple subgraphs.
|
||||
costmodel_allreduce_fusion_algorithm (int): The allreduce fusion algorithm.
|
||||
0: bypass allreduce fusion;
|
||||
1: only use backward computation time to group allreduce;
|
||||
|
|
|
@ -0,0 +1,101 @@
|
|||
import numpy as np
|
||||
from mindspore import context
|
||||
import mindspore as ms
|
||||
import mindspore.nn as nn
|
||||
from mindspore.nn.optim import Adam, FTRL
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore import Tensor, Parameter, ParameterTuple
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.parallel import _cost_model_context as cost_model_context
|
||||
from mindspore.common.api import _executor
|
||||
from mindspore.parallel import set_algo_parameters, get_algo_parameters, reset_algo_parameters
|
||||
from mindspore.parallel._utils import _reset_op_id as reset_op_id
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.mul = P.Mul()
|
||||
self.relu = P.ReLU()
|
||||
self.wd = Parameter(Tensor(np.ones([8, 8, 8, 8]).astype(np.float32)), name="wide")
|
||||
self.wt = Parameter(Tensor(np.ones([8, 8, 8, 8]).astype(np.float32)), name="l")
|
||||
def construct(self, x):
|
||||
out = self.mul(x, self.wd)
|
||||
out = self.mul(out, self.wt)
|
||||
out = self.relu(out)
|
||||
return out
|
||||
|
||||
class NetWithLoss(nn.Cell):
|
||||
def __init__(self, network):
|
||||
super(NetWithLoss, self).__init__()
|
||||
self.sum = P.ReduceSum()
|
||||
self.mean = P.ReduceMean()
|
||||
self.net = network
|
||||
|
||||
def construct(self, x):
|
||||
predict = self.net(x)
|
||||
loss1 = self.sum(predict, -1)
|
||||
loss2 = self.mean(predict, -1)
|
||||
return loss1, loss2
|
||||
|
||||
class IthOutputCell(nn.Cell):
|
||||
def __init__(self, network, output_index):
|
||||
super(IthOutputCell, self).__init__()
|
||||
self.network = network
|
||||
self.output_index = output_index
|
||||
|
||||
def construct(self, x):
|
||||
predict = self.network(x)[self.output_index]
|
||||
return predict
|
||||
|
||||
class TrainStepWarp(nn.Cell):
|
||||
def __init__(self, network, sens=1000.0):
|
||||
super(TrainStepWarp, self).__init__()
|
||||
self.network = network
|
||||
self.network.set_train()
|
||||
self.trainable_params = network.trainable_params()
|
||||
weights_w = []
|
||||
weights_d = []
|
||||
for params in self.trainable_params:
|
||||
weights_w.append(params)
|
||||
weights_d.append(params)
|
||||
self.weights_w = ParameterTuple(weights_w)
|
||||
self.weights_d = ParameterTuple(weights_d)
|
||||
self.optimizer_w = FTRL(learning_rate=1e-2, params=self.weights_w, l1=1e-8,
|
||||
l2=1e-8, initial_accum=1.0)
|
||||
self.optimizer_d = Adam(self.weights_d, learning_rate=3.5e-4, eps=1e-8,
|
||||
loss_scale=sens)
|
||||
self.hyper_map = C.HyperMap()
|
||||
self.grad_w = C.GradOperation('grad_w', get_by_list=True, sens_param=True)
|
||||
self.grad_d = C.GradOperation('grad_d', get_by_list=True, sens_param=True)
|
||||
self.sens = sens
|
||||
self.loss_net_w = IthOutputCell(network, output_index=0)
|
||||
self.loss_net_d = IthOutputCell(network, output_index=1)
|
||||
|
||||
def construct(self, x):
|
||||
weights_w = self.weights_w
|
||||
weights_d = self.weights_d
|
||||
loss_w, loss_d = self.network(x)
|
||||
sens_w = P.Fill()(P.DType()(loss_w), P.Shape()(loss_w), self.sens)
|
||||
sens_d = P.Fill()(P.DType()(loss_d), P.Shape()(loss_d), self.sens)
|
||||
grads_w = self.grad_w(self.loss_net_w, weights_w)(x, sens_w)
|
||||
grads_d = self.grad_d(self.loss_net_d, weights_d)(x, sens_d)
|
||||
return F.depend(loss_w, self.optimizer_w(grads_w)), F.depend(loss_d, self.optimizer_d(grads_d))
|
||||
|
||||
def test_double_subgraphs():
|
||||
cost_model_context.set_cost_model_context(multi_subgraphs=True)
|
||||
context.set_context(save_graphs=True)
|
||||
context.set_auto_parallel_context(device_num=8, global_rank=0)
|
||||
net = TrainStepWarp(NetWithLoss(Net()))
|
||||
context.set_auto_parallel_context(parallel_mode="auto_parallel")
|
||||
|
||||
x = Tensor(np.ones([8, 8, 8, 8]), dtype=ms.float32)
|
||||
reset_op_id()
|
||||
_executor.compile(net, x, phase='train')
|
||||
strategies = _executor._get_strategy(net)
|
||||
expected_strategies = {'Default/network-NetWithLoss/ReduceMean-op0': [[8, 1, 1, 1]],
|
||||
'Default/network-NetWithLoss/net-Net/ReLU-op1': [[8, 1, 1, 1]],
|
||||
'Default/network-NetWithLoss/net-Net/Mul-op2': [[8, 1, 1, 1], [8, 1, 1, 1]],
|
||||
'Default/network-NetWithLoss/net-Net/Mul-op3': [[8, 1, 1, 1], [8, 1, 1, 1]],
|
||||
'Default/network-NetWithLoss/ReduceSum-op4': [[8, 1, 1, 1]]}
|
||||
assert strategies == expected_strategies
|
|
@ -0,0 +1,70 @@
|
|||
import numpy as np
|
||||
from mindspore import context
|
||||
import mindspore as ms
|
||||
import mindspore.nn as nn
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore import Tensor
|
||||
from mindspore.common.api import _executor
|
||||
from tests.ut.python.ops.test_math_ops import VirtualLoss
|
||||
from mindspore.parallel import set_algo_parameters
|
||||
from mindspore.parallel._utils import _reset_op_id as reset_op_id
|
||||
import re
|
||||
|
||||
class NetWithLoss(nn.Cell):
|
||||
def __init__(self, network):
|
||||
super(NetWithLoss, self).__init__()
|
||||
self.loss = VirtualLoss()
|
||||
self.network = network
|
||||
|
||||
def construct(self, x):
|
||||
predict = self.network(x)
|
||||
return self.loss(predict)
|
||||
|
||||
class Blockcell(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Blockcell, self).__init__()
|
||||
self.bn = nn.BatchNorm2d(64, momentum=0.9)
|
||||
|
||||
def construct(self, x):
|
||||
out = self.bn(x)
|
||||
return out
|
||||
|
||||
def getBlock():
|
||||
return Blockcell()
|
||||
|
||||
def test_two_bn():
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.block1 = getBlock()
|
||||
self.block2 = getBlock()
|
||||
self.relu = P.ReLU()
|
||||
self.add = P.TensorAdd()
|
||||
self.bias = Tensor(np.ones([64, 64]), dtype=ms.float32)
|
||||
|
||||
def construct(self, x):
|
||||
out = self.block1(x)
|
||||
out = self.relu(out)
|
||||
out = self.add(out, self.bias)
|
||||
out = self.block2(out)
|
||||
return out
|
||||
|
||||
net = NetWithLoss(Net())
|
||||
x = Tensor(np.ones([64, 64]), dtype=ms.float32)
|
||||
context.set_context(save_graphs=True)
|
||||
context.set_auto_parallel_context(device_num=8, global_rank=0)
|
||||
context.set_auto_parallel_context(parallel_mode="auto_parallel")
|
||||
set_algo_parameters(elementwise_op_strategy_follow=True)
|
||||
reset_op_id()
|
||||
|
||||
_executor.compile(net, x, phase='train')
|
||||
strategies = _executor._get_strategy(net)
|
||||
assert len(strategies) == 4
|
||||
|
||||
for (k, v) in strategies.items():
|
||||
if re.search('BatchNorm-op', k) is not None:
|
||||
assert v == [[8, 1], [1], [1], [1], [1]]
|
||||
elif re.search('TensorAdd-op', k) is not None:
|
||||
assert v == [[8, 1], [8, 1]]
|
||||
elif re.search('ReLU-op', k) is not None:
|
||||
assert v == [[8, 1]]
|
Loading…
Reference in New Issue