forked from mindspore-Ecosystem/mindspore
!87 Take AllToAll as a virtual operator in cost model
Merge pull request !87 from gziyan/dev_alltoall
commit 0cdb1171d5
@@ -34,7 +34,7 @@ namespace parallel {
 #define OPERATOR_TO_OPERATOR_CONNECTOR "-"
 #define DEFAULT_DEVICE_MEMORY_CAPACITY (1024.0 * 1024.0 * 1024.0 * 16.0)
 #define DEFAULT_COST_MODEL_ALPHA 1.0
-#define DEFAULT_COST_MODEL_BETA 260.0
+#define DEFAULT_COST_MODEL_BETA 400.0
 #define DEFAULT_COST_MODEL_GAMMA 0.001
 #define DEFAULT_COST_MODEL_SIMPLIFY_CALCULATION true
 #define DEFAULT_COST_MODEL_COMMUNI_THRESHOLD 2048.0
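Note (reviewer sketch, not part of this diff): DEFAULT_COST_MODEL_BETA is the default behind the costmodel_beta value exposed through the Python-side cost_model_context, which is why the tests further down now expect 400.0. A minimal sketch, reusing the cost_model_context alias exactly as those tests do (its import is not shown in this diff):

    # Assumes `cost_model_context` is imported the same way as in the tests below.
    assert cost_model_context.get_cost_model_context("costmodel_beta") == 400.0  # new default
    cost_model_context.set_cost_model_context(costmodel_beta=260.0)  # override back to the old value if desired
    assert cost_model_context.get_cost_model_context("costmodel_beta") == 260.0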
@@ -23,7 +23,7 @@
 namespace mindspore {
 namespace parallel {
 Status RedistributionOperatorInfer::Init(const TensorLayout& tensor_layout, const Map& out_tensor_map,
-                                         RankList dev_list) {
+                                         RankList dev_list, bool is_cost_model) {
   in_tensor_map_ = tensor_layout.tensor_map();
   dev_mat_ = tensor_layout.device_arrangement();
 
@@ -51,6 +51,8 @@ Status RedistributionOperatorInfer::Init(const TensorLayout& tensor_layout, cons
   for (int32_t item : map) {
     map_[key++] = item;
   }
+
+  is_cost_model_ = is_cost_model;
   return Status::SUCCESS;
 }
 
@@ -130,15 +132,26 @@ Status RedistributionOperatorInfer::InferPermuteByAxis() {
         std::any_of(map_.begin(), map_.end(),
                     [out_dim](const RedistributionOperatorMap::value_type& a) { return a.second == out_dim; })) {
       int32_t cat_dim = in_tensor_map_.GetIndexByValue(out_dim);
-      Args args_allconcat = {cat_dim, out_dim, dev_mat_.GetDimByReverseIdx(IntToUint(out_dim))};
-      Args args_allsplit = {dev_mat_.GetDimByReverseIdx(IntToUint(out_dim)), UintToInt(index), out_dim};
-      if (InsertOperator(CONCAT_BY_AXIS, args_allconcat) == Status::FAILED) {
-        MS_LOG(ERROR) << "Insert ConcatByAxis Error!";
-        return Status::FAILED;
-      }
-      if (InsertOperator(SPLIT_BY_AXIS, args_allsplit) == Status::FAILED) {
-        MS_LOG(ERROR) << "Insert SplitByAxis Error!";
-        return Status::FAILED;
-      }
+      int32_t dev_num = dev_mat_.GetDimByReverseIdx(IntToUint(out_dim));
+      if (is_cost_model_) {
+        int32_t dev_dim = in_tensor_map_.GetDimByIdx(IntToUint(cat_dim));
+        Args args_alltoall = {dev_mat_.GetDimByReverseIdx(IntToUint(dev_dim)), UintToInt(index), cat_dim, dev_dim,
+                              dev_num};
+        if (InsertOperator(PERMUTE_BY_AXIS, args_alltoall) == Status::FAILED) {
+          MS_LOG(ERROR) << "Insert PermuteByAxis Error!";
+          return Status::FAILED;
+        }
+      } else {
+        Args args_allconcat = {cat_dim, out_dim, dev_num};
+        Args args_allsplit = {dev_num, UintToInt(index), out_dim};
+        if (InsertOperator(CONCAT_BY_AXIS, args_allconcat) == Status::FAILED) {
+          MS_LOG(ERROR) << "Insert ConcatByAxis Error!";
+          return Status::FAILED;
+        }
+        if (InsertOperator(SPLIT_BY_AXIS, args_allsplit) == Status::FAILED) {
+          MS_LOG(ERROR) << "Insert SplitByAxis Error!";
+          return Status::FAILED;
+        }
+      }
       (void)map_.erase(iter++);
       map_[IntToSize(cat_dim)] = NONE;
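Note (reviewer sketch, not part of this diff): in the cost-model path a single virtual PERMUTE_BY_AXIS (AlltoAll) operator is recorded instead of the concat/split pair, and the argument order of args_alltoall is what ComputeCost() later depends on: op.second[2] is read as the concat dimension and op.second[4] as the device number. A minimal Python sketch of that positional contract (function name and values are illustrative only):

    # Illustrative only: mirrors the layout of args_alltoall in the hunk above:
    # [device-matrix dim size for dev_dim, tensor index, cat_dim, dev_dim, dev_num]
    def permute_by_axis_args(dev_dim_size, index, cat_dim, dev_dim, dev_num):
        return [dev_dim_size, index, cat_dim, dev_dim, dev_num]

    args = permute_by_axis_args(dev_dim_size=8, index=1, cat_dim=2, dev_dim=0, dev_num=8)
    concat_dim = args[2]  # what ComputeCost() reads as op.second[2]
    device_num = args[4]  # what ComputeCost() reads as op.second[4]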
@@ -40,7 +40,8 @@ class RedistributionOperatorInfer {
  public:
   const int NONE = -1;
   explicit RedistributionOperatorInfer(bool construct_op_flag = true) : construct_op_flag_(construct_op_flag) {}
-  Status Init(const TensorLayout& tensor_layout, const Map& out_tensor_map, RankList dev_list);
+  Status Init(const TensorLayout& tensor_layout, const Map& out_tensor_map, RankList dev_list,
+              bool is_cost_model = false);
   ~RedistributionOperatorInfer() = default;
   OperatorList operator_list() const { return operator_list_; }
   OperatorVector operator_vector() const { return operator_vector_; }
@@ -67,6 +68,7 @@ class RedistributionOperatorInfer {
   ConstructOperator constructor_;
   RankList dev_list_;
   bool construct_op_flag_;
+  bool is_cost_model_;
 };
 } // namespace parallel
 } // namespace mindspore
@@ -40,7 +40,7 @@ Status TensorRedistribution::Init(const TensorLayout& from, const TensorLayout&
   return Status::SUCCESS;
 }
 
-RedistributionOpListPtr TensorRedistribution::InferTensorRedistributionOperatorList() {
+RedistributionOpListPtr TensorRedistribution::InferTensorRedistributionOperatorList(bool is_cost_model) {
   // Step 1: Match device arrangement between from_ and to_
   RedistributionLayoutTransfer layout_transfer;
   Status status = layout_transfer.Init(from_, to_);
@@ -62,7 +62,7 @@ RedistributionOpListPtr TensorRedistribution::InferTensorRedistributionOperatorL
   MS_LOG(DEBUG) << "reshape to_ " << to_.ToString();
   // Step 2: Infer redistribution and insert operators
   RedistributionOperatorInfer operator_infer(construct_op_flag_);
-  if (operator_infer.Init(from_layout, to_layout.tensor_map(), dev_list_) == Status::FAILED) {
+  if (operator_infer.Init(from_layout, to_layout.tensor_map(), dev_list_, is_cost_model) == Status::FAILED) {
     MS_LOG(ERROR) << "Init operatorInfer failed!";
     return nullptr;
   }
@@ -138,7 +138,7 @@ Status TensorRedistribution::InferReshape(const TensorLayout& from_layout, const
 }
 
 Status TensorRedistribution::ComputeCost() {
-  RedistributionOpListPtr redistribution_oplist_ptr = InferTensorRedistributionOperatorList();
+  RedistributionOpListPtr redistribution_oplist_ptr = InferTensorRedistributionOperatorList(true);
   if (redistribution_oplist_ptr == nullptr) {
     MS_LOG(ERROR) << "Failure: InferTensorRedistribution failed";
     return Status::FAILED;
@@ -151,14 +151,22 @@ Status TensorRedistribution::ComputeCost() {
       std::accumulate(slice_shape.begin(), slice_shape.end(), static_cast<double>(1.0), std::multiplies<double>());
     std::string str = op.first;
     if (str == PERMUTE_BY_AXIS) {
-      // The shape does not change after PermuteByAxis operation.
-      // communication cost = all_to_all + all_to_all = 2 * slice_shape
-      // computation cost = slice_shape
-      forward_comm_cost_ += prod;
-      backward_comm_cost_ += prod;
-      comm_cost_ += 2.0 * prod;
-      computation_cost_ += prod;
-      memory_cost_ += prod;
+      // Since AlltoAll is a virtual operator, the expanded operators are used here to compute cost.
+      // communication cost = all_gather + reduce_scatter = before_slice_shape + after_slice_shape
+      forward_comm_cost_ += prod * ALLTOALL_SCALE_FACTOR;
+      backward_comm_cost_ += prod * ALLTOALL_SCALE_FACTOR;
+      comm_cost_ += 2.0 * prod * ALLTOALL_SCALE_FACTOR;
+      int32_t concat_dim = op.second[2];
+      if (concat_dim == 0) {
+        // memory cost = all_gather
+        computation_cost_ += prod;
+        memory_cost_ += prod;
+      } else {
+        // memory cost = all_gather + split + concat
+        int32_t dev_num = op.second[4];
+        computation_cost_ += (prod + prod * dev_num + prod * dev_num);
+        memory_cost_ += (prod * dev_num + prod * dev_num + prod);
+      }
     } else if (str == CONCAT_BY_AXIS) {
       // communication cost = all_gather + reduce_scatter = before_slice_shape + after_slice_shape
       // computation cost = before_slice_shape
@@ -168,9 +176,9 @@ Status TensorRedistribution::ComputeCost() {
       }
       double dev_num = op.second[2];
       // here, communication cost = all_gather + reduce_scatter
-      forward_comm_cost_ += prod * dev_num;
-      backward_comm_cost_ += prod;
-      comm_cost_ += prod * (dev_num + 1.0);
+      forward_comm_cost_ += prod * dev_num * ALLGATHER_REDUCESCATTER_SCALE_FACTOR;
+      backward_comm_cost_ += prod * ALLGATHER_REDUCESCATTER_SCALE_FACTOR;
+      comm_cost_ += prod * (dev_num + 1.0) * ALLGATHER_REDUCESCATTER_SCALE_FACTOR;
       int32_t concat_dim = op.second[0];
       if (concat_dim == 0) {
         // computation cost = all_gather
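Note (reviewer sketch, not part of this diff): a small worked example of the updated ComputeCost() formulas, assuming a slice of 1024 elements (prod = 1024.0), dev_num = 8 and a non-zero concat dimension; the numbers are illustrative only. PERMUTE_BY_AXIS, the virtual AlltoAll, is charged 2 * prod * ALLTOALL_SCALE_FACTOR of communication plus the computation/memory of its expanded all_gather + split + concat, while CONCAT_BY_AXIS keeps the all_gather + reduce_scatter form scaled by ALLGATHER_REDUCESCATTER_SCALE_FACTOR:

    # Mirrors the arithmetic in the hunks above; sample values only.
    ALLTOALL_SCALE_FACTOR = 2.0
    ALLGATHER_REDUCESCATTER_SCALE_FACTOR = 0.5
    prod, dev_num = 1024.0, 8

    # PERMUTE_BY_AXIS (virtual AlltoAll), concat_dim != 0
    permute_comm = 2.0 * prod * ALLTOALL_SCALE_FACTOR              # forward + backward
    permute_computation = prod + prod * dev_num + prod * dev_num   # all_gather + split + concat
    permute_memory = prod * dev_num + prod * dev_num + prod

    # CONCAT_BY_AXIS: all_gather (forward) + reduce_scatter (backward)
    concat_comm = prod * (dev_num + 1.0) * ALLGATHER_REDUCESCATTER_SCALE_FACTOR

    print(permute_comm, permute_computation, permute_memory, concat_comm)
    # 4096.0 17408.0 17408.0 4608.0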
@@ -33,6 +33,8 @@
 
 namespace mindspore {
 namespace parallel {
+constexpr double ALLTOALL_SCALE_FACTOR = 2.0;
+constexpr double ALLGATHER_REDUCESCATTER_SCALE_FACTOR = 0.5;
 class TensorRedistribution {
  public:
   explicit TensorRedistribution(bool construct_op_flag = true, bool keep_reshape = false)
@@ -46,7 +48,7 @@ class TensorRedistribution {
         keep_reshape_(keep_reshape) {}
   Status Init(const TensorLayout& from, const TensorLayout& to, const RankList& dev_list);
   ~TensorRedistribution() = default;
-  RedistributionOpListPtr InferTensorRedistributionOperatorList();
+  RedistributionOpListPtr InferTensorRedistributionOperatorList(bool is_cost_model = false);
   OperatorList operator_list() const { return operator_list_; }
   bool reshape_flag() const { return reshape_flag_; }
   Status ComputeCost();
@@ -304,7 +304,7 @@ def train_32k_8p(epoch_size=3, batch_size=32, num_classes=32768):
 
 
 def test_train_32k_8p_fusion1(epoch_size=3, batch_size=32, num_classes=32768): #1048576 #131072 #32768 #8192
-    cost_model_context.set_cost_model_context(costmodel_gamma=0.001, costmodel_beta=260.0)
+    cost_model_context.set_cost_model_context(costmodel_gamma=0.001, costmodel_beta=400.0)
     cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_algorithm=1)
     cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_times=2)
     cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_tail_percent=0.5)
@@ -651,7 +651,7 @@ def test_train_32k_8p_fusion2(epoch_size=3, batch_size=32, num_classes=32768): #
 def test_train_64k_8p(epoch_size=3, batch_size=32, num_classes=65536): #1048576 #131072 #32768 #8192
     dev_num = 8
     context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL, device_num=dev_num)
-    cost_model_context.set_cost_model_context(costmodel_gamma=0.001, costmodel_beta=260.0)
+    cost_model_context.set_cost_model_context(costmodel_gamma=0.001, costmodel_beta=400.0)
     set_algo_parameters(elementwise_op_strategy_follow=True)
     resset_op_id()
     np.random.seed(6)
@@ -86,7 +86,7 @@ def test_two_matmul():
     costmodel_alpha = cost_model_context.get_cost_model_context("costmodel_alpha")
     assert costmodel_alpha == 1.0
     costmodel_beta = cost_model_context.get_cost_model_context("costmodel_beta")
-    assert costmodel_beta == 260.0
+    assert costmodel_beta == 400.0
     costmodel_gamma = cost_model_context.get_cost_model_context("costmodel_gamma")
     assert costmodel_gamma == 0.001
     costmodel_communi_threshold = cost_model_context.get_cost_model_context("costmodel_communi_threshold")