!24568 Apply batch parallel in auto_parallel mode when strategies are not specified

Merge pull request !24568 from zhuyuxiao/master
i-robot 2021-10-13 01:20:15 +00:00 committed by Gitee
commit 3fd94000c5
6 changed files with 229 additions and 54 deletions
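A minimal usage sketch of the new interface (distilled from the Cell.set_strategy_gen_mode method and the test file added below, not an excerpt from the diff); a simulated 8-device auto_parallel context is assumed:

from mindspore import nn, context
from mindspore.ops import operations as P

class Net(nn.Cell):
    def __init__(self):
        super().__init__()
        self.matmul = P.MatMul()  # no strategy specified for this primitive

    def construct(self, x, y):
        return self.matmul(x, y)

context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="auto_parallel")
net = Net()
net.set_strategy_gen_mode("batch")  # unspecified strategies now fall back to batch parallel
net.set_auto_parallel()
# Compiling the cell (as compile_net does in the new test file) then generates batch parallel
# strategies for the unspecified ops instead of auto-searching them.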


@ -92,6 +92,7 @@ constexpr char STRATEGY[] = "strategy";
constexpr char STAGE_ATTR[] = "stage";
constexpr char GEN_STRATEGY[] = "gen_strategy";
constexpr char REDUCE_OP_SUM[] = "sum";
constexpr char STRATEGY_GEN_MODE[] = "strategy_gen_mode";
constexpr char REDUCE_OP_MAX[] = "max";
constexpr char REDUCE_OP_MIN[] = "min";
constexpr char REDUCE_OP_ANY[] = "any";


@ -257,35 +257,40 @@ void SetStrategyToOperator(const OperatorInfoPtr &operator_info, const Primitive
  } else {
    strategyPtr = (*stra_map)[strategy_key_name];
  }
  if (strategyPtr != nullptr) {
    if (prim->name() == RESHAPE) {
      MS_LOG(EXCEPTION) << "Setting strategy for Reshape goes for nothing!";
    }
    const auto fully_use_devices = CostModelContext::GetInstance()->fully_use_device();
    // Set cost for this configured strategy
    if (operator_info->SetCostUnderStrategy(strategyPtr) != SUCCESS) {
      MS_LOG(EXCEPTION) << "Failure: operator " << prim->name() << " SetCostUnderStrategy failed";
    } else if (fully_use_devices) {
      // If configured to fully use devices, then checking for the user-specified strategy
      int64_t used_devices = operator_info->used_devices();
      MS_EXCEPTION_IF_NULL(g_device_manager);
      auto total_device_num = g_device_manager->GetDeviceListByStageId(0).size();
      // 'used_devices == 1' means that ALL-1 strategy, which is valid in auto-parallel
      if (used_devices == 1) {
        (void)configured_stra_ops_.emplace(operator_info, strategyPtr);
        return;
      }
      // 'used_devices == -1' means that 'used_devices_' is not set
      if ((used_devices == -1) || LongToSize(used_devices) != total_device_num) {
        MS_LOG(EXCEPTION) << "In current configuration 'fully_use_devices' = True, "
                          << "but the specified strategy uses device: " << used_devices
                          << ", total devices: " << total_device_num
                          << ", try to set 'set_algo_parameters(fully_use_devices=False)' "
                             "in package 'mindspore.parallel'.";
      }
    }
    (void)configured_stra_ops_.emplace(operator_info, strategyPtr);
  if (strategyPtr == nullptr) {
    return;
  }
  if (prim->name() == RESHAPE) {
    MS_LOG(EXCEPTION) << "Setting strategy for Reshape goes for nothing!";
    return;
  }
  // Set cost for this configured strategy
  if (operator_info->SetCostUnderStrategy(strategyPtr) != SUCCESS) {
    MS_LOG(EXCEPTION) << "Failure: operator " << prim->name() << " SetCostUnderStrategy failed";
    return;
  }
  const auto fully_use_devices = CostModelContext::GetInstance()->fully_use_device();
  if (fully_use_devices) {
    // If configured to fully use devices, then checking for the user-specified strategy
    int64_t used_devices = operator_info->used_devices();
    MS_EXCEPTION_IF_NULL(g_device_manager);
    auto total_device_num = g_device_manager->GetDeviceListByStageId(0).size();
    // 'used_devices == -1' means that 'used_devices_' is not set
    // 'used_devices == 1' means that ALL-1 strategy, which is valid in auto-parallel
    if (used_devices == -1 || (used_devices != 1 && LongToSize(used_devices) != total_device_num)) {
      MS_LOG(EXCEPTION) << "In current configuration 'fully_use_devices' = True, "
                        << "but the specified strategy uses device: " << used_devices
                        << ", total devices: " << total_device_num
                        << ", try to set 'set_algo_parameters(fully_use_devices=False)' "
                           "in package 'mindspore.parallel'.";
    }
  }
  (void)configured_stra_ops_.emplace(operator_info, strategyPtr);
}
OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &cnode, bool is_last_nodes,
@ -346,31 +351,43 @@ OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &
  // If no strategy has been configured for this operator, then candidate strategies are generated for
  // auto-strategy searching; if this primitive is CAST, we ignore the user-specified strategy.
  // If strategy is set to load from checkpoint, loading the strategy from checkpoint is preferred.
  if ((!StrategyFound(attrs) || prim->name() == CAST) && !load_strategy_from_ckpt) {
    // Compute split_flag_list_, indicating which input has batch dimension. This is ONLY used for preparation for
    // BatchParallelInfo operator
    operator_info->ComputeBatchSplitFlagList();
    if (operator_info->GenerateStrategies(0) != SUCCESS) {
      MS_LOG(ERROR) << "Strategy search for Operator " << operator_info->name() << " failed.";
      return nullptr;
    }
    if (ParallelContext::GetInstance()->sharding_propagation() &&
        (operator_info->name().find(VIRTUAL_DATA_SET_INFO) != std::string::npos)) {
      const auto &swc_vec = operator_info->GetStrategyCost();
      if (swc_vec.empty()) {
        MS_LOG(EXCEPTION) << "No available strategy for: " << operator_info->name();
      }
      MS_EXCEPTION_IF_NULL(swc_vec[0]->strategy_ptr);
      (void)configured_stra_ops_.emplace(operator_info, swc_vec[0]->strategy_ptr);
    }
    // If 'approximation' is enabled, the 'strategy_cost' of each operator is approximated
    auto approximation = CostModelContext::GetInstance()->dp_algo_enable_approxi();
    if (approximation) {
      operator_info->ApproximateStrategies();
      MS_LOG(INFO) << "Approximated StrategyCost for: " << operator_info->name();
    }
  } else {
  if ((StrategyFound(attrs) && prim->name() != CAST) || load_strategy_from_ckpt) {
    SetStrategyToOperator(operator_info, prim, attrs, is_last_nodes, stra_map, strategy_key_name);
    return operator_info;
  }
  // Compute split_flag_list_, indicating which input has batch dimension. This is ONLY used for preparation for
  // BatchParallelInfo operator
  operator_info->ComputeBatchSplitFlagList();
  Status retGenStra;
  if (AttrFound(attrs, STRATEGY_GEN_MODE) && GetValue<std::string>(attrs[STRATEGY_GEN_MODE]) == "batch") {
    MS_LOG(INFO) << "generating batch parallel strategy...";
    StrategyPtr strategyPtr = parallel::GenerateBatchParallelStrategy(operator_info, prim);
    retGenStra = operator_info->SetCostUnderStrategy(strategyPtr);
  } else {
    MS_LOG(INFO) << "auto-searching strategy...";
    retGenStra = operator_info->GenerateStrategies(0);
  }
  if (retGenStra != SUCCESS) {
    MS_LOG(ERROR) << "Strategy search for Operator " << operator_info->name() << " failed.";
    return nullptr;
  }
  if (ParallelContext::GetInstance()->sharding_propagation() &&
      (operator_info->name().find(VIRTUAL_DATA_SET_INFO) != std::string::npos)) {
    const auto &swc_vec = operator_info->GetStrategyCost();
    if (swc_vec.empty()) {
      MS_LOG(EXCEPTION) << "No available strategy for: " << operator_info->name();
    }
    MS_EXCEPTION_IF_NULL(swc_vec[0]->strategy_ptr);
    (void)configured_stra_ops_.emplace(operator_info, swc_vec[0]->strategy_ptr);
  }
  // If 'approximation' is enabled, the 'strategy_cost' of each operator is approximated
  auto approximation = CostModelContext::GetInstance()->dp_algo_enable_approxi();
  if (approximation) {
    operator_info->ApproximateStrategies();
    MS_LOG(INFO) << "Approximated StrategyCost for: " << operator_info->name();
  }
  return operator_info;
}
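For orientation (not part of the diff): when the GenerateBatchParallelStrategy branch above is taken, the generated strategy is assumed to split only the batch dimension of the inputs flagged by ComputeBatchSplitFlagList. For a MatMul on 8 devices this is roughly equivalent to the manual configuration below; the exact tuple is illustrative, not taken from this change:

from mindspore.ops import operations as P

# Assumed batch-parallel layout for MatMul on 8 devices: split the row (batch) dimension
# of the first input, replicate the second input.
matmul = P.MatMul().shard(((8, 1), (1, 1)))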


@ -483,11 +483,16 @@ void Redistribution(const std::pair<AnfNodePtr, int64_t> &node_pair, const Opera
  }
}
bool StrategyFound(std::unordered_map<std::string, ValuePtr> attrs) {
bool StrategyFound(const std::unordered_map<std::string, ValuePtr> &attrs) {
  auto iter = attrs.find(STRATEGY);
  return !((iter == attrs.end()) || (iter->second->type_name() == NONE));
}
bool AttrFound(const std::unordered_map<std::string, ValuePtr> &attrs, const std::string &target) {
  auto iter = attrs.find(target);
  return !((iter == attrs.end()) || (iter->second->type_name() == NONE));
}
bool HasStrategy(const FuncGraphPtr &root) {
  AnfNodePtr ret = root->get_return();
  MS_EXCEPTION_IF_NULL(ret);


@ -69,7 +69,9 @@ void Redistribution(const std::pair<AnfNodePtr, int64_t> &node_pair, const Opera
const CNodePtr &middle_node, int64_t index, TensorRedistribution tensor_redistribution,
const CNodePtr &pre_node);
bool StrategyFound(std::unordered_map<std::string, ValuePtr> attrs);
bool StrategyFound(const std::unordered_map<std::string, ValuePtr> &attrs);
bool AttrFound(const std::unordered_map<std::string, ValuePtr> &attrs, const std::string &target);
void MarkForwardCNode(const FuncGraphPtr &root);


@ -377,6 +377,33 @@ class Cell(Cell_):
f"The function construct needs {positional_args} positional argument and {default_args} default "
f"argument, but provided {len(inputs)}")
    def _get_prims_recursively(self):
        all_prims = list()
        for _, value in self._primitives.items():
            if value:
                all_prims.append(value)
        for cell in self.cells():
            all_prims.extend(cell._get_prims_recursively())
        return all_prims
    def set_strategy_gen_mode(self, mode):
        """
        When auto_parallel_context uses ParallelMode.AUTO_PARALLEL, applying this method changes how
        strategies are generated:

        1. mode = "batch":
            For all primitive ops in this cell (including ops of cells wrapped by this cell) whose
            parallel strategy is not specified, a batch parallel strategy is generated instead of
            auto-searching one.
        """
        strategy_gen_modes = ["batch"]
        if mode not in strategy_gen_modes:
            raise ValueError(f"Unexpected input {mode}, must be one of {strategy_gen_modes}")
        all_prims = self._get_prims_recursively()
        for prim in all_prims:
            prim.add_prim_attr("strategy_gen_mode", mode)
    class CellGuard:
        def __enter__(self):
            _pynative_executor.set_lazy_build(True)
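A sketch of mixing explicitly sharded ops with the new mode (not part of the diff; it mirrors test_batch_parallel_with_user_strategy in the new test file below): an op with a user-specified strategy keeps it via the SetStrategyToOperator path above, while ops without one fall back to batch parallel:

from mindspore import nn
from mindspore.ops import operations as P

class MixedNet(nn.Cell):
    def __init__(self):
        super().__init__()
        self.matmul = P.MatMul().shard(((1, 8), (8, 1)))  # user-specified strategy is kept
        self.mul = P.Mul()                                # no strategy: batch parallel is generated

    def construct(self, x, y):
        out = self.matmul(x, y)
        return self.mul(out, out)

net = MixedNet()
net.set_strategy_gen_mode("batch")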


@ -0,0 +1,123 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import mindspore as ms
from mindspore import nn, context, Tensor
from mindspore.common.api import _cell_graph_executor
from mindspore.ops import operations as P


class NetMul(nn.Cell):
    def __init__(self, strategy=None):
        super().__init__()
        self.mul = P.Mul().shard(strategy)

    def construct(self, x, y):
        return self.mul(x, y)


class NetMatMul(nn.Cell):
    def __init__(self, strategy=None):
        super().__init__()
        self.matmul = P.MatMul().shard(strategy)

    def construct(self, x, y):
        return self.matmul(x, y)


class NetRecursive(nn.Cell):
    def __init__(self):
        super().__init__()
        self.mul_net = NetMul()
        self.matmul_net = NetMatMul()

    def construct(self, x, y):
        out1 = self.matmul_net(x, y)
        out2 = self.matmul_net(x, y)
        return self.mul_net(out1, out2)


def compile_net(net, x, y):
    net.set_auto_parallel()
    net.set_train()
    _cell_graph_executor.compile(net, x, y)


def test_batch_parallel_matmul():
    """
    Feature: strategy gen mode
    Description: test batch-parallel strategy generation for MatMul
    Expectation: batch parallel strategies are generated for primitive ops without specified strategies
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    net = NetMatMul()
    net.set_strategy_gen_mode("batch")
    x = Tensor(np.ones([128, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 128]), dtype=ms.float32)
    compile_net(net, x, y)


def test_batch_parallel_mul():
    """
    Feature: strategy gen mode
    Description: test batch-parallel strategy generation for Mul
    Expectation: batch parallel strategies are generated for primitive ops without specified strategies
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    net = NetMul()
    net.set_strategy_gen_mode("batch")
    x = Tensor(np.ones([128, 128]), dtype=ms.float32)
    y = Tensor(np.ones([128, 128]), dtype=ms.float32)
    compile_net(net, x, y)


def test_batch_parallel_recursive():
    """
    Feature: strategy gen mode
    Description: test primitive ops in cells wrapped by other cells
    Expectation: batch parallel strategies are generated for primitive ops without specified strategies
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    net = NetRecursive()
    net.set_strategy_gen_mode("batch")
    x = Tensor(np.ones([128, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 128]), dtype=ms.float32)
    compile_net(net, x, y)


def test_batch_parallel_with_user_strategy():
    """
    Feature: strategy gen mode
    Description: test strategy gen mode when users have specified some strategies
    Expectation: primitive ops with user-specified strategies keep them;
        batch parallel strategies are generated for the rest
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    net = NetMatMul(strategy=((1, 8), (8, 1)))
    net.set_strategy_gen_mode("batch")
    x = Tensor(np.ones([128, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 128]), dtype=ms.float32)
    compile_net(net, x, y)