forked from mindspore-Ecosystem/mindspore
fix codex in tensor_layout.cc
This commit is contained in:
parent
965f46e481
commit
2a0a549ae8
@@ -37,6 +37,7 @@ constexpr size_t SLICE_SIZE_INDEX = 2;
 constexpr size_t SLICE_INPUTS_SIZE = 3;
 constexpr size_t STRIDED_SLICE_ATTRS_SIZE = 5;
 constexpr size_t STRIDED_SLICE_INPUTS_SIZE = 4;
+constexpr size_t STRIDED_SLICE_ARGS_SIZE = 3;
 constexpr size_t STRIDED_SLICE_BEGIN_INDEX = 1;
 constexpr size_t STRIDED_SLICE_END_INDEX = 2;
 constexpr size_t STRIDED_SLICE_STRIDES_INDEX = 3;
@@ -57,8 +58,20 @@ constexpr size_t SoftmaxCrossEntropyWithLogitsOutputsSize = 2;
 constexpr size_t UNIQUE_INPUTS_SIZE = 1;
 constexpr size_t UNIQUE_INPUT_SIZE = 1;
 constexpr size_t UNIQUE_OUTPUTS_SIZE = 2;
+constexpr size_t TRANSFER_PERMUTE_ARGS_SIZE = 5;
+constexpr size_t TRANSFER_PERMUTE_SPLIT_COUNT_INDEX = 0;
+constexpr size_t TRANSFER_PERMUTE_SPLIT_DIM_INDEX = 1;
+constexpr size_t TRANSFER_PERMUTE_CONCAT_DIM_INDEX = 2;
+constexpr size_t TRANSFER_PERMUTE_DEV_DIM_INDEX = 3;
+constexpr size_t TRANSFER_PERMUTE_DEV_NUM_INDEX = 4;
+constexpr size_t TRANSFER_CONCAT_ARGS_SIZE = 3;
+constexpr size_t TRANSFER_CONCAT_TENSOR_DIM_INDEX = 0;
+constexpr size_t TRANSFER_CONCAT_DEV_DIM_INDEX = 1;
+constexpr size_t TRANSFER_CONCAT_SPLIT_COUNT_INDEX = 2;
+constexpr size_t TRANSFER_SPLIT_ARGS_SIZE = 3;
 constexpr double EPS = 1e-6;
 constexpr double INF = 1e20;
+constexpr double COST_FACTOR = 2.0;
 
 constexpr char AUTO_PARALLEL_RUN_ONCE_ONLY[] = "auto_parallel_run_once_only";
 constexpr char SEMI_AUTO_PARALLEL_RUN_ONCE_ONLY[] = "semi_auto_parallel_run_once_only";
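For orientation: the new TRANSFER_PERMUTE_* constants name positions in the Args vector that the redistribution pass hands to ConstructOperator::AlltoAllOP. A minimal sketch of that layout, assuming Args is a std::vector<int64_t> as the call sites below suggest (the values are illustrative only):

    // Hypothetical example; real values come from RedistributionOperatorInfer.
    Args args = {
        8,  // TRANSFER_PERMUTE_SPLIT_COUNT_INDEX: number of slices produced
        0,  // TRANSFER_PERMUTE_SPLIT_DIM_INDEX:   tensor dim being split
        1,  // TRANSFER_PERMUTE_CONCAT_DIM_INDEX:  tensor dim the slices rejoin on
        3,  // TRANSFER_PERMUTE_DEV_DIM_INDEX:     device-matrix dim for the group
        8,  // TRANSFER_PERMUTE_DEV_NUM_INDEX:     participating device count
    };

The TRANSFER_CONCAT_* trio describes the analogous {tensor_dim, dev_dim, split_count} layout consumed by TransferConcatByAxis.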
@@ -25,7 +25,7 @@ std::string Array::ToString() const {
   std::ostringstream buffer;
   buffer << "[ ";
   for (auto &element : array_) {
-    buffer << std::to_string(element) + " ";
+    buffer << (std::to_string(element) + " ");
   }
   buffer << "]";
   return buffer.str();
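A side note on the ToString() change: because + binds tighter than <<, the new parentheses do not alter evaluation; they only make the grouping explicit for readers and static analyzers. A self-contained check of that claim:

    #include <sstream>
    #include <string>

    // Both statements append "42 " to the stream; '+' already groups before '<<'.
    std::string Demo() {
      std::ostringstream buffer;
      int element = 42;
      buffer << std::to_string(element) + " ";
      buffer << (std::to_string(element) + " ");
      return buffer.str();  // "42 42 "
    }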
@@ -70,12 +70,12 @@ Operator CreateStridedSliceOp(int64_t value, const Shape &begin, const Shape &en
   OperatorAttrs attrs = {attr_begin_mask, attr_end_mask, attr_ellipsis_mask, attr_new_axis_mask, attr_shrink_axis_mask};
 
   ValuePtr param_begin_value = MakeValue(begin);
-  Param param_begin = std::make_pair(std::make_pair(BEGIN, param_begin_value), 2);
+  Param param_begin = std::make_pair(std::make_pair(BEGIN, param_begin_value), STRIDED_SLICE_BEGIN_INDEX + 1);
   ValuePtr param_end_value = MakeValue(end);
-  Param param_end = std::make_pair(std::make_pair(END, param_end_value), 3);
+  Param param_end = std::make_pair(std::make_pair(END, param_end_value), STRIDED_SLICE_END_INDEX + 1);
 
   ValuePtr param_strides_value = MakeValue(strides);
-  Param param_strides = std::make_pair(std::make_pair(STRIDES, param_strides_value), 4);
+  Param param_strides = std::make_pair(std::make_pair(STRIDES, param_strides_value), STRIDED_SLICE_STRIDES_INDEX + 1);
   OperatorParams params = {param_begin, param_end, param_strides};
   OperatorArgs op_args = std::make_pair(attrs, params);
 
@@ -83,17 +83,17 @@ Operator CreateStridedSliceOp(int64_t value, const Shape &begin, const Shape &en
 }
 
 Status ConstructOperator::StridedSliceOP(Args args) {
-  if (args.size() < 3) {
+  if (args.size() < STRIDED_SLICE_ARGS_SIZE) {
     MS_LOG(ERROR) << "args size should not be less than 3!";
     return Status::FAILED;
   }
-  int64_t split_count = args[0];
+  int64_t split_count = args[TRANSFER_PERMUTE_SPLIT_COUNT_INDEX];
   if (split_count <= 0) {
     MS_LOG(ERROR) << "split_count should not be less than 0!";
     return Status::FAILED;
   }
-  int64_t split_dim = args[1];
-  int64_t dev_dim = args[2];
+  int64_t split_dim = args[TRANSFER_PERMUTE_SPLIT_DIM_INDEX];
+  int64_t dev_dim = args[TRANSFER_PERMUTE_CONCAT_DIM_INDEX];
   std::vector<Group> group_list;
 
   if (CreateGroupByDim(dev_size_ - LongToSize(dev_dim) - 1, &group_list) != SUCCESS) {
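Why the + 1 in each Param above: the STRIDED_SLICE_*_INDEX constants are 0-based argument indices, while a Param's second member records the operator's 1-based input position (the sliced tensor occupies position 1, as the replaced literals 2/3/4 indicate). A sketch of the assumed mapping:

    // Assumed 1-based input positions of StridedSlice:
    //   [1] = input tensor
    //   [2] = begin   == STRIDED_SLICE_BEGIN_INDEX   (1) + 1
    //   [3] = end     == STRIDED_SLICE_END_INDEX     (2) + 1
    //   [4] = strides == STRIDED_SLICE_STRIDES_INDEX (3) + 1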
@@ -193,14 +193,14 @@ Status ConstructOperator::SplitOP(int64_t split_count) {
 }
 
 Status ConstructOperator::AlltoAllOP(Args args) {
-  if (args.size() < 4) {
-    MS_LOG(ERROR) << "args size should not be less than 4!";
+  if (args.size() < TRANSFER_PERMUTE_ARGS_SIZE) {
+    MS_LOG(ERROR) << "args size should not be less than 5!";
     return Status::FAILED;
   }
-  int64_t split_count = args[0];
-  int64_t split_dim = args[1];
-  int64_t concat_dim = args[2];
-  int64_t dev_dim = args[3];
+  int64_t split_count = args[TRANSFER_PERMUTE_SPLIT_COUNT_INDEX];
+  int64_t split_dim = args[TRANSFER_PERMUTE_SPLIT_DIM_INDEX];
+  int64_t concat_dim = args[TRANSFER_PERMUTE_CONCAT_DIM_INDEX];
+  int64_t dev_dim = args[TRANSFER_PERMUTE_DEV_DIM_INDEX];
   if (split_count <= 0) {
     MS_LOG(ERROR) << "Invalid split count when construct AlltoAll operator!";
     return Status::FAILED;
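With the stricter check, a caller must now supply all five entries; a hedged usage sketch, mirroring the updated unit test further down:

    // Sketch only: assumes a ConstructOperator set up as in the test fixture.
    // Order is fixed by the TRANSFER_PERMUTE_* index constants:
    // {split_count, split_dim, concat_dim, dev_dim, dev_num}.
    Args args = {8, 0, 1, 3, 8};
    if (constructor.AlltoAllOP(args) != Status::SUCCESS) {
      MS_LOG(ERROR) << "AlltoAllOP failed";
    }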
@@ -211,26 +211,26 @@ Status RedistributionOperatorInfer::InsertOperator(OperatorName name, Args args)
 }
 
 Status RedistributionOperatorInfer::TransferSplitByAxis(Args args) {
-  if (args.size() < 3) {
+  if (args.size() < TRANSFER_SPLIT_ARGS_SIZE) {
     MS_LOG(ERROR) << "args size should not be less than 3!";
     return Status::FAILED;
   }
-  size_t index = LongToSize(args[1]);
+  size_t index = LongToSize(args[TRANSFER_PERMUTE_SPLIT_DIM_INDEX]);
   if (constructor_.StridedSliceOP(args) != Status::SUCCESS) {
     return Status::FAILED;
   } else {
     operator_vector_.push_back(constructor_.GetOperator());
     output_info_vector_.push_back(std::make_pair(false, 0));
   }
-  if (cur_tensor_layout_.UpdateTensorMap(index, args[2]) == Status::FAILED) {
+  if (cur_tensor_layout_.UpdateTensorMap(index, args[TRANSFER_PERMUTE_CONCAT_DIM_INDEX]) == Status::FAILED) {
     return Status::FAILED;
   }
   return Status::SUCCESS;
 }
 
 Status RedistributionOperatorInfer::TransferPermuteByAxis(Args args) {
-  if (args.size() < 3) {
-    MS_LOG(ERROR) << "args size should not be less than 3!";
+  if (args.size() < TRANSFER_PERMUTE_ARGS_SIZE) {
+    MS_LOG(ERROR) << "args size should not be less than 5!";
     return Status::FAILED;
   }
   if (constructor_.AlltoAllOP(args) != Status::SUCCESS) {
@@ -239,8 +239,8 @@ Status RedistributionOperatorInfer::TransferPermuteByAxis(Args args) {
     operator_vector_.push_back(constructor_.GetOperator());
     output_info_vector_.push_back(std::make_pair(false, 0));
   }
-  size_t index = LongToSize(args[1]);
-  int64_t val = args[2];
+  size_t index = LongToSize(args[TRANSFER_PERMUTE_SPLIT_DIM_INDEX]);
+  int64_t val = args[TRANSFER_PERMUTE_CONCAT_DIM_INDEX];
   int64_t out_dim = out_tensor_map_.GetDimByIdx(index);
 
   if (cur_tensor_layout_.UpdateTensorMap(LongToSize(val), NONE) == Status::FAILED) {
@@ -253,13 +253,13 @@ Status RedistributionOperatorInfer::TransferPermuteByAxis(Args args) {
 }
 
 Status RedistributionOperatorInfer::TransferConcatByAxis(Args args) {
-  if (args.size() < 3) {
+  if (args.size() < TRANSFER_CONCAT_ARGS_SIZE) {
     MS_LOG(ERROR) << "args size should not be less than 3!";
     return Status::FAILED;
   }
-  int64_t tensor_dim = args[0];
-  int64_t dev_dim = args[1];
-  int64_t split_count = args[2];
+  int64_t tensor_dim = args[TRANSFER_CONCAT_TENSOR_DIM_INDEX];
+  int64_t dev_dim = args[TRANSFER_CONCAT_DEV_DIM_INDEX];
+  int64_t split_count = args[TRANSFER_CONCAT_SPLIT_COUNT_INDEX];
   if (constructor_.AllGatherOP(dev_dim) != Status::SUCCESS) {
     return Status::FAILED;
   } else {
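One subtlety worth flagging in TransferSplitByAxis: the size check uses TRANSFER_SPLIT_ARGS_SIZE, yet the element accesses borrow the TRANSFER_PERMUTE_* names. On my reading of this diff that is sound only because the two layouts share their first slots, so the permute names are used purely positionally:

    // Assumed arg layouts (leading positions coincide):
    //   permute: {split_count, split_dim, concat_dim, dev_dim, dev_num}
    //   split:   {split_count, split_dim, dev_dim}
    // args[TRANSFER_PERMUTE_CONCAT_DIM_INDEX] therefore reads the split layout's
    // dev_dim, not a concat dimension; the name is borrowed for its index value (2).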
@@ -365,23 +365,24 @@ bool TensorLayout::operator!=(const TensorLayout &t1) const {
 /*
  * remove elements equal to 1 in tensor_shape, if all elements are 1, squeeze the tensor_shape to [ 1 ]
  * example 1:
  * original tensor layout:
  * device arrangement = [ 8 ]
  * tensor map = [ 0 -1 -1 -1 ]
  * tensor shape = [ 128 64 1 1 ]
  * return tensor layout:
  * device arrangement = [ 8 ]
  * tensor map = [ 0 -1 ]
  * tensor shape = [ 128 64 ]
  *
  * example 2:
+ * original tensor layout:
  * device arrangement = [ 8 ]
  * tensor map = [ -1 -1 -1 -1 ]
  * tensor shape = [ 1 1 1 1 ]
  * return tensor layout:
  * device arrangement = [ 8 ]
  * tensor map = [ -1 ]
  * tensor shape = [ 1 ]
  */
 TensorLayout TensorLayout::SqueezeShape() const {
   TensorLayout out;

@@ -219,7 +219,7 @@ Status TensorRedistribution::ComputeCost() {
       // There is only computation cost in SplitByAxis.
       // computation cost = before_slice_shape
       computation_cost_ += prod;
-      // This addtion may be erroneous
+      // This addition may be erroneous
      memory_cost_ += prod;
     }
   }
@@ -232,8 +232,8 @@ Status TensorRedistribution::ComputeCost() {
     }
     double prev_prod =
       std::accumulate(prev_shape.begin(), prev_shape.end(), static_cast<double>(1.0), std::multiplies<double>());
-    computation_cost_ += 2.0 * prev_prod;
-    memory_cost_ += 2.0 * prev_prod;
+    computation_cost_ += COST_FACTOR * prev_prod;
+    memory_cost_ += COST_FACTOR * prev_prod;
   }
   return Status::SUCCESS;
 }
@@ -241,21 +241,21 @@ Status TensorRedistribution::ComputeCost() {
 Status TensorRedistribution::ComputePermuteCost(double input_size, Shape attrs) {
   // Since AlltoAll is a virtual operator, the expanded operators are used here to compute cost.
   // communication cost = all_gather + reduce_scatter = before_slice_shape + after_slice_shape
-  if (attrs.size() < 4) {
-    MS_LOG(ERROR) << "attrs size should not be less than 4!";
+  if (attrs.size() < TRANSFER_PERMUTE_ARGS_SIZE) {
+    MS_LOG(ERROR) << "attrs size should not be less than 5!";
     return Status::FAILED;
   }
   forward_comm_cost_ += input_size * ALLTOALL_SCALE_FACTOR;
   backward_comm_cost_ += input_size * ALLTOALL_SCALE_FACTOR;
-  comm_cost_ += 2.0 * input_size * ALLTOALL_SCALE_FACTOR;
-  int32_t concat_dim = attrs[2];
+  comm_cost_ += COST_FACTOR * input_size * ALLTOALL_SCALE_FACTOR;
+  int32_t concat_dim = attrs[TRANSFER_PERMUTE_CONCAT_DIM_INDEX];
   if (concat_dim == 0) {
     // memory cost = all_gather
     computation_cost_ += input_size;
     memory_cost_ += input_size;
   } else {
     // memory cost = all_gather + split + concat
-    int32_t dev_num = attrs[4];
+    int32_t dev_num = attrs[TRANSFER_PERMUTE_DEV_NUM_INDEX];
     computation_cost_ += (input_size + input_size * dev_num + input_size * dev_num);
     memory_cost_ += (input_size * dev_num + input_size * dev_num + input_size);
   }
@@ -265,16 +265,16 @@ Status TensorRedistribution::ComputePermuteCost(double input_size, Shape attrs)
 Status TensorRedistribution::ComputeConcatCost(double input_size, Shape attrs) {
   // communication cost = all_gather + reduce_scatter = before_slice_shape + after_slice_shape
   // computation cost = before_slice_shape
-  if (attrs.size() < 3) {
+  if (attrs.size() < TRANSFER_CONCAT_ARGS_SIZE) {
     MS_LOG(ERROR) << "op.second size should not be less than 3!";
     return Status::FAILED;
   }
-  double dev_num = attrs[2];
+  double dev_num = attrs[TRANSFER_CONCAT_SPLIT_COUNT_INDEX];
   // here, communication cost = all_gather + reduce_scatter
   forward_comm_cost_ += input_size * dev_num * ALLGATHER_REDUCESCATTER_SCALE_FACTOR;
   backward_comm_cost_ += input_size * ALLGATHER_REDUCESCATTER_SCALE_FACTOR;
   comm_cost_ += input_size * (dev_num + 1.0) * ALLGATHER_REDUCESCATTER_SCALE_FACTOR;
-  int32_t concat_dim = attrs[0];
+  int32_t concat_dim = attrs[TRANSFER_CONCAT_TENSOR_DIM_INDEX];
   if (concat_dim == 0) {
     // computation cost = all_gather
     computation_cost_ += input_size;
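To make the concat communication term concrete: the forward all_gather moves dev_num slices while the backward reduce_scatter moves one, which is exactly the (dev_num + 1.0) factor above. A small worked check with made-up numbers:

    // Illustrative only: input_size = 1024.0, dev_num = 8.0,
    // factor = ALLGATHER_REDUCESCATTER_SCALE_FACTOR.
    // forward_comm_cost_  += 1024.0 * 8.0 * factor;          // all_gather
    // backward_comm_cost_ += 1024.0 * factor;                // reduce_scatter
    // comm_cost_          += 1024.0 * (8.0 + 1.0) * factor;  // == forward + backward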
@@ -131,7 +131,8 @@ TEST_F(TestConstructOperator, TestAlltoAllOP) {
   int64_t split_dim = 0;
   int64_t concat_dim = 1;
   int64_t dev_dim = 3;
-  Args args = {split_count, split_dim, concat_dim, dev_dim};
+  int64_t dev_num = 8;
+  Args args = {split_count, split_dim, concat_dim, dev_dim, dev_num};
   ASSERT_EQ(constructor.AlltoAllOP(args), Status::SUCCESS);
 }
 
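The added dev_num entry keeps the test in step with AlltoAllOP's stricter precondition: args must now carry all TRANSFER_PERMUTE_ARGS_SIZE (5) values, dev_num included.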