!24568 Apply batch parallel in auto_parallel mode when strategies are not specified
Merge pull request !24568 from zhuyuxiao/master
This commit is contained in:
commit 3fd94000c5
@@ -92,6 +92,7 @@ constexpr char STRATEGY[] = "strategy";
 constexpr char STAGE_ATTR[] = "stage";
 constexpr char GEN_STRATEGY[] = "gen_strategy";
 constexpr char REDUCE_OP_SUM[] = "sum";
+constexpr char STRATEGY_GEN_MODE[] = "strategy_gen_mode";
 constexpr char REDUCE_OP_MAX[] = "max";
 constexpr char REDUCE_OP_MIN[] = "min";
 constexpr char REDUCE_OP_ANY[] = "any";
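Note: STRATEGY_GEN_MODE names the primitive attribute that the Python layer writes and this pass later reads. A minimal sketch of tagging a primitive by hand, using the same add_prim_attr call that Cell.set_strategy_gen_mode (later in this diff) uses; assumes a MindSpore install, and the read-back via Primitive.attrs is illustrative:

# Sketch: attach the "strategy_gen_mode" attribute to a single primitive.
from mindspore.ops import operations as P

matmul = P.MatMul()
matmul.add_prim_attr("strategy_gen_mode", "batch")  # what set_strategy_gen_mode does per primitive
print(matmul.attrs.get("strategy_gen_mode"))  # expected: batch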
@@ -257,35 +257,40 @@ void SetStrategyToOperator(const OperatorInfoPtr &operator_info, const Primitive
   } else {
     strategyPtr = (*stra_map)[strategy_key_name];
   }
-  if (strategyPtr != nullptr) {
-    if (prim->name() == RESHAPE) {
-      MS_LOG(EXCEPTION) << "Setting strategy for Reshape goes for nothing!";
-    }
-    const auto fully_use_devices = CostModelContext::GetInstance()->fully_use_device();
-    // Set the cost for this configured strategy
-    if (operator_info->SetCostUnderStrategy(strategyPtr) != SUCCESS) {
-      MS_LOG(EXCEPTION) << "Failure: operator " << prim->name() << " SetCostUnderStrategy failed";
-    } else if (fully_use_devices) {
-      // If configured to fully use devices, check the user-specified strategy
-      int64_t used_devices = operator_info->used_devices();
-      MS_EXCEPTION_IF_NULL(g_device_manager);
-      auto total_device_num = g_device_manager->GetDeviceListByStageId(0).size();
-      // 'used_devices == 1' means the ALL-1 strategy, which is valid in auto-parallel
-      if (used_devices == 1) {
-        (void)configured_stra_ops_.emplace(operator_info, strategyPtr);
-        return;
-      }
-      // 'used_devices == -1' means that 'used_devices_' is not set
-      if ((used_devices == -1) || LongToSize(used_devices) != total_device_num) {
-        MS_LOG(EXCEPTION) << "In current configuration 'fully_use_devices' = True, "
-                          << "but the specified strategy uses device: " << used_devices
-                          << ", total devices: " << total_device_num
-                          << ", try to set 'set_algo_parameters(fully_use_devices=False)' "
-                             "in package 'mindspore.parallel'.";
-      }
-    }
-    (void)configured_stra_ops_.emplace(operator_info, strategyPtr);
-  }
+  if (strategyPtr == nullptr) {
+    return;
+  }
+
+  if (prim->name() == RESHAPE) {
+    MS_LOG(EXCEPTION) << "Setting strategy for Reshape goes for nothing!";
+  }
+
+  // Set the cost for this configured strategy
+  if (operator_info->SetCostUnderStrategy(strategyPtr) != SUCCESS) {
+    MS_LOG(EXCEPTION) << "Failure: operator " << prim->name() << " SetCostUnderStrategy failed";
+  }
+
+  const auto fully_use_devices = CostModelContext::GetInstance()->fully_use_device();
+  if (fully_use_devices) {
+    // If configured to fully use devices, check the user-specified strategy
+    int64_t used_devices = operator_info->used_devices();
+    MS_EXCEPTION_IF_NULL(g_device_manager);
+    auto total_device_num = g_device_manager->GetDeviceListByStageId(0).size();
+
+    // 'used_devices == -1' means that 'used_devices_' is not set;
+    // 'used_devices == 1' means the ALL-1 strategy, which is valid in auto-parallel
+    if (used_devices == -1 || (used_devices != 1 && LongToSize(used_devices) != total_device_num)) {
+      MS_LOG(EXCEPTION) << "In current configuration 'fully_use_devices' = True, "
+                        << "but the specified strategy uses device: " << used_devices
+                        << ", total devices: " << total_device_num
+                        << ", try to set 'set_algo_parameters(fully_use_devices=False)' "
+                           "in package 'mindspore.parallel'.";
+    }
+  }
+  (void)configured_stra_ops_.emplace(operator_info, strategyPtr);
 }
 
 OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &cnode, bool is_last_nodes,
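Note: the exception text above points at the cost-model switch that relaxes the full-device check. A minimal sketch of the suggested workaround, for strategies that intentionally leave some devices idle:

# Allow a user-specified strategy to use fewer devices than the stage owns;
# without this, SetStrategyToOperator raises the exception shown above.
from mindspore.parallel import set_algo_parameters

set_algo_parameters(fully_use_devices=False)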
@@ -346,31 +351,43 @@ OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &
-  // If no strategy has been configured for this operator, then candidate strategies are generated for
-  // auto-strategy searching; if this primitive is CAST, we ignore the user-specified strategy.
-  // If the strategy is set to load from checkpoint, loading the strategy from checkpoint is preferred.
-  if ((!StrategyFound(attrs) || prim->name() == CAST) && !load_strategy_from_ckpt) {
-    // Compute split_flag_list_, indicating which input has the batch dimension. This is ONLY used as
-    // preparation for the BatchParallelInfo operator
-    operator_info->ComputeBatchSplitFlagList();
-    if (operator_info->GenerateStrategies(0) != SUCCESS) {
-      MS_LOG(ERROR) << "Strategy search for Operator " << operator_info->name() << " failed.";
-      return nullptr;
-    }
-    if (ParallelContext::GetInstance()->sharding_propagation() &&
-        (operator_info->name().find(VIRTUAL_DATA_SET_INFO) != std::string::npos)) {
-      const auto &swc_vec = operator_info->GetStrategyCost();
-      if (swc_vec.empty()) {
-        MS_LOG(EXCEPTION) << "No available strategy for: " << operator_info->name();
-      }
-      MS_EXCEPTION_IF_NULL(swc_vec[0]->strategy_ptr);
-      (void)configured_stra_ops_.emplace(operator_info, swc_vec[0]->strategy_ptr);
-    }
-    // If 'approximation' is enabled, the 'strategy_cost' of each operator is approximated
-    auto approximation = CostModelContext::GetInstance()->dp_algo_enable_approxi();
-    if (approximation) {
-      operator_info->ApproximateStrategies();
-      MS_LOG(INFO) << "Approximated StrategyCost for: " << operator_info->name();
-    }
-  } else {
+  // If a strategy is configured (and this primitive is not CAST), or the strategy should be loaded
+  // from checkpoint, apply it directly and return.
+  if ((StrategyFound(attrs) && prim->name() != CAST) || load_strategy_from_ckpt) {
+    SetStrategyToOperator(operator_info, prim, attrs, is_last_nodes, stra_map, strategy_key_name);
+    return operator_info;
+  }
+
+  // Compute split_flag_list_, indicating which input has the batch dimension. This is ONLY used as
+  // preparation for the BatchParallelInfo operator
+  operator_info->ComputeBatchSplitFlagList();
+  Status retGenStra;
+  if (AttrFound(attrs, STRATEGY_GEN_MODE) && GetValue<std::string>(attrs[STRATEGY_GEN_MODE]) == "batch") {
+    MS_LOG(INFO) << "generating batch parallel strategy...";
+    StrategyPtr strategyPtr = parallel::GenerateBatchParallelStrategy(operator_info, prim);
+    retGenStra = operator_info->SetCostUnderStrategy(strategyPtr);
+  } else {
+    MS_LOG(INFO) << "auto-searching strategy...";
+    retGenStra = operator_info->GenerateStrategies(0);
+  }
+
+  if (retGenStra != SUCCESS) {
+    MS_LOG(ERROR) << "Strategy search for Operator " << operator_info->name() << " failed.";
+    return nullptr;
+  }
+
+  if (ParallelContext::GetInstance()->sharding_propagation() &&
+      (operator_info->name().find(VIRTUAL_DATA_SET_INFO) != std::string::npos)) {
+    const auto &swc_vec = operator_info->GetStrategyCost();
+    if (swc_vec.empty()) {
+      MS_LOG(EXCEPTION) << "No available strategy for: " << operator_info->name();
+    }
+    MS_EXCEPTION_IF_NULL(swc_vec[0]->strategy_ptr);
+    (void)configured_stra_ops_.emplace(operator_info, swc_vec[0]->strategy_ptr);
+  }
+  // If 'approximation' is enabled, the 'strategy_cost' of each operator is approximated
+  auto approximation = CostModelContext::GetInstance()->dp_algo_enable_approxi();
+  if (approximation) {
+    operator_info->ApproximateStrategies();
+    MS_LOG(INFO) << "Approximated StrategyCost for: " << operator_info->name();
+  }
+  return operator_info;
+}
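Note: after this rewrite, an operator's strategy comes from exactly one of three sources, checked in order. A small, self-contained Python sketch of that priority (names are hypothetical, not MindSpore API):

from typing import Optional


def pick_strategy_source(strategy_found: bool, is_cast: bool,
                         load_from_ckpt: bool, gen_mode: Optional[str]) -> str:
    # Priority 1: a user-configured or checkpointed strategy (CAST ignores user strategies).
    if (strategy_found and not is_cast) or load_from_ckpt:
        return "configured"
    # Priority 2: batch parallel generation when strategy_gen_mode == "batch".
    if gen_mode == "batch":
        return "batch_parallel"
    # Priority 3: auto-search over generated candidate strategies.
    return "auto_search"


assert pick_strategy_source(True, False, False, None) == "configured"
assert pick_strategy_source(False, False, False, "batch") == "batch_parallel"
assert pick_strategy_source(False, False, False, None) == "auto_search"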
@@ -483,11 +483,16 @@ void Redistribution(const std::pair<AnfNodePtr, int64_t> &node_pair, const Opera
   }
 }
 
-bool StrategyFound(std::unordered_map<std::string, ValuePtr> attrs) {
+bool StrategyFound(const std::unordered_map<std::string, ValuePtr> &attrs) {
   auto iter = attrs.find(STRATEGY);
   return !((iter == attrs.end()) || (iter->second->type_name() == NONE));
 }
 
+bool AttrFound(const std::unordered_map<std::string, ValuePtr> &attrs, const std::string &target) {
+  auto iter = attrs.find(target);
+  return !((iter == attrs.end()) || (iter->second->type_name() == NONE));
+}
+
 bool HasStrategy(const FuncGraphPtr &root) {
   AnfNodePtr ret = root->get_return();
   MS_EXCEPTION_IF_NULL(ret);
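Note: both helpers treat an attribute as found only when the key exists and its value is not None. A hypothetical Python analogue of the same predicate:

def attr_found(attrs: dict, target: str) -> bool:
    # Found means: key present and value not None, mirroring the C++ check.
    return attrs.get(target) is not None


assert attr_found({"strategy_gen_mode": "batch"}, "strategy_gen_mode")
assert not attr_found({"strategy_gen_mode": None}, "strategy_gen_mode")
assert not attr_found({}, "strategy_gen_mode")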
@@ -69,7 +69,9 @@ void Redistribution(const std::pair<AnfNodePtr, int64_t> &node_pair, const Opera
                     const CNodePtr &middle_node, int64_t index, TensorRedistribution tensor_redistribution,
                     const CNodePtr &pre_node);
 
-bool StrategyFound(std::unordered_map<std::string, ValuePtr> attrs);
+bool StrategyFound(const std::unordered_map<std::string, ValuePtr> &attrs);
+
+bool AttrFound(const std::unordered_map<std::string, ValuePtr> &attrs, const std::string &target);
 
 void MarkForwardCNode(const FuncGraphPtr &root);
@@ -377,6 +377,33 @@ class Cell(Cell_):
                 f"The function construct needs {positional_args} positional argument and {default_args} default "
                 f"argument, but provided {len(inputs)}")
 
+    def _get_prims_recursively(self):
+        all_prims = list()
+        for _, value in self._primitives.items():
+            if value:
+                all_prims.append(value)
+
+        for cell in self.cells():
+            all_prims.extend(cell._get_prims_recursively())
+
+        return all_prims
+
+    def set_strategy_gen_mode(self, mode):
+        """
+        When auto_parallel_context uses ParallelMode.AUTO_PARALLEL, applying this method means:
+
+        1. mode = "batch":
+            for all primitive ops in this cell (including ops of cells wrapped by this cell),
+            if no parallel strategy is specified, a batch parallel strategy is generated for
+            those primitive ops instead of auto-searching one.
+        """
+        strategy_gen_modes = ["batch"]
+        if mode not in strategy_gen_modes:
+            raise ValueError(f"unexpected input {mode}, must be one of {strategy_gen_modes}")
+
+        all_prims = self._get_prims_recursively()
+        for prim in all_prims:
+            prim.add_prim_attr("strategy_gen_mode", mode)
+
 
 class CellGuard:
     def __enter__(self):
         _pynative_executor.set_lazy_build(True)
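Note: a minimal usage sketch of the new method, mirroring the tests below; MyNet is a hypothetical example cell, and compilation still requires the auto_parallel context shown:

import mindspore.nn as nn
from mindspore import context
from mindspore.ops import operations as P


class MyNet(nn.Cell):  # hypothetical example cell
    def __init__(self):
        super().__init__()
        self.relu = P.ReLU()  # no strategy specified on purpose

    def construct(self, x):
        return self.relu(x)


context.set_auto_parallel_context(device_num=8, global_rank=0,
                                  parallel_mode="auto_parallel")
net = MyNet()
net.set_strategy_gen_mode("batch")  # unspecified strategies now fall back to batch parallel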
@@ -0,0 +1,123 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import numpy as np
+
+import mindspore as ms
+from mindspore import nn, context, Tensor
+from mindspore.common.api import _cell_graph_executor
+from mindspore.ops import operations as P
+
+
+class NetMul(nn.Cell):
+    def __init__(self, strategy=None):
+        super().__init__()
+        self.mul = P.Mul().shard(strategy)
+
+    def construct(self, x, y):
+        return self.mul(x, y)
+
+
+class NetMatMul(nn.Cell):
+    def __init__(self, strategy=None):
+        super().__init__()
+        self.matmul = P.MatMul().shard(strategy)
+
+    def construct(self, x, y):
+        return self.matmul(x, y)
+
+
+class NetRecursive(nn.Cell):
+    def __init__(self):
+        super().__init__()
+        self.mul_net = NetMul()
+        self.matmul_net = NetMatMul()
+
+    def construct(self, x, y):
+        out1 = self.matmul_net(x, y)
+        out2 = self.matmul_net(x, y)
+        return self.mul_net(out1, out2)
+
+
+def compile_net(net, x, y):
+    net.set_auto_parallel()
+    net.set_train()
+    _cell_graph_executor.compile(net, x, y)
+
+
+def test_batch_parallel_matmul():
+    """
+    Feature: strategy gen mode
+    Description: test batch matmul
+    Expectation: batch parallel mode generates the unspecified strategies for primitive ops
+    """
+    context.set_auto_parallel_context(device_num=8, global_rank=0)
+    context.set_auto_parallel_context(parallel_mode="auto_parallel")
+    net = NetMatMul()
+    net.set_strategy_gen_mode("batch")
+
+    x = Tensor(np.ones([128, 32]), dtype=ms.float32)
+    y = Tensor(np.ones([32, 128]), dtype=ms.float32)
+
+    compile_net(net, x, y)
+
+
+def test_batch_parallel_mul():
+    """
+    Feature: strategy gen mode
+    Description: test mul
+    Expectation: batch parallel mode generates the unspecified strategies for primitive ops
+    """
+    context.set_auto_parallel_context(device_num=8, global_rank=0)
+    context.set_auto_parallel_context(parallel_mode="auto_parallel")
+    net = NetMul()
+    net.set_strategy_gen_mode("batch")
+
+    x = Tensor(np.ones([128, 128]), dtype=ms.float32)
+    y = Tensor(np.ones([128, 128]), dtype=ms.float32)
+
+    compile_net(net, x, y)
+
+
+def test_batch_parallel_recursive():
+    """
+    Feature: strategy gen mode
+    Description: test primitive ops in cells wrapped by other cells
+    Expectation: batch parallel mode generates the unspecified strategies for primitive ops
+    """
+    context.set_auto_parallel_context(device_num=8, global_rank=0)
+    context.set_auto_parallel_context(parallel_mode="auto_parallel")
+    net = NetRecursive()
+    net.set_strategy_gen_mode("batch")
+
+    x = Tensor(np.ones([128, 32]), dtype=ms.float32)
+    y = Tensor(np.ones([32, 128]), dtype=ms.float32)
+
+    compile_net(net, x, y)
+
+
+def test_batch_parallel_with_user_strategy():
+    """
+    Feature: strategy gen mode
+    Description: test strategy gen mode while users have specified strategies
+    Expectation: primitive ops with user-specified strategies use them; the rest
+        fall back to batch parallel strategy generation
+    """
+    context.set_auto_parallel_context(device_num=8, global_rank=0)
+    context.set_auto_parallel_context(parallel_mode="auto_parallel")
+    net = NetMatMul(strategy=((1, 8), (8, 1)))
+    net.set_strategy_gen_mode("batch")
+
+    x = Tensor(np.ones([128, 32]), dtype=ms.float32)
+    y = Tensor(np.ones([32, 128]), dtype=ms.float32)
+
+    compile_net(net, x, y)
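Note: in the last test, the shard strategy ((1, 8), (8, 1)) gives one tuple per MatMul input: the first input is split 8-way along its second axis and the second input 8-way along its first axis, while ops without a strategy fall back to batch parallel. A quick sanity check of the per-device shapes that strategy implies (plain arithmetic, no MindSpore required):

# Each input dimension is divided by its strategy entry.
x_shape, y_shape = (128, 32), (32, 128)
x_stra, y_stra = (1, 8), (8, 1)

x_local = tuple(dim // cut for dim, cut in zip(x_shape, x_stra))
y_local = tuple(dim // cut for dim, cut in zip(y_shape, y_stra))
assert x_local == (128, 4) and y_local == (4, 128)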