!16876 code clean for master

From: @alouhahahahaha
Reviewed-by: @zhoufeng54,@jjfeing,@kisnwang
Signed-off-by: @jjfeing,@kisnwang
This commit is contained in:
mindspore-ci-bot 2021-05-27 09:00:58 +08:00 committed by Gitee
commit 3fde1b74ec
5 changed files with 18 additions and 14 deletions

View File

@ -169,7 +169,7 @@ const std::vector<size_t> &HcclKernel::GetOutputSizeList() const {
}
ulong loop_size = hccl_data_type_list_.size();
if (AnfAlgo::GetInputTensorNum(anf_node) > 1 && op_name == kAllGatherOpName && fusion >= 1) {
loop_size *= rank_size;
loop_size *= static_cast<ulong>(rank_size);
}
if (op_name == kReduceScatterOpName && fusion >= 1) {
loop_size = AnfAlgo::GetOutputTensorNum(anf_node);

View File

@ -137,11 +137,11 @@ bool HcomUtil::GetHcomCount(const AnfNodePtr &anf_node, const vector<HcclDataTyp
MS_LOG(ERROR) << "Get rank size failed";
return false;
}
int64_t actual_input_size = input_size;
size_t actual_input_size = input_size;
if (AnfAlgo::HasNodeAttr(kAttrFusion, cnode) && AnfAlgo::GetNodeAttr<int64_t>(anf_node, kAttrFusion)) {
actual_input_size = (input_size + align_size - 1 + filled_size) / align_size * align_size;
}
block_size = actual_input_size / LongToSize(rank_size);
block_size = static_cast<uint64_t>(actual_input_size / LongToSize(rank_size));
total_size = total_size + block_size;
} else {
if (AnfAlgo::GetCNodeName(anf_node) == kAllGatherOpName) {

View File

@ -95,7 +95,7 @@ AnfNodePtr InsertConcatForOutput(const FuncGraphPtr &func_graph, const AnfNodePt
size_t inputs_size = AnfAlgo::GetInputTensorNum(node);
for (size_t i = 0; i < inputs_size; ++i) {
std::vector<AnfNodePtr> concat_inputs{NewValueNode(std::make_shared<Primitive>(prim::kPrimConcat->name()))};
for (size_t j = 0, idx = i; j < IntToSize(rank_size); ++j, idx += inputs_size) {
for (size_t j = 0, idx = i; j < LongToSize(rank_size); ++j, idx += inputs_size) {
concat_inputs.push_back(new_tuple_getitems[idx]);
}
auto concat = func_graph->NewCNode(concat_inputs);

View File

@ -26,6 +26,7 @@ std::vector<AnfNodePtr> SplitInputsForReduceScatter::InsertSplitForInput(const F
MS_EXCEPTION_IF_NULL(func_graph);
size_t inputs_size = AnfAlgo::GetInputTensorNum(node);
std::vector<AnfNodePtr> split_outputs;
size_t rank_size_t = LongToSize(rank_size);
for (size_t i = 0; i < inputs_size; i++) {
std::vector<AnfNodePtr> split_inputs{NewValueNode(std::make_shared<Primitive>(prim::kPrimSplitV->name()))};
split_inputs.push_back(AnfAlgo::GetInputNode(node, i));
@ -34,16 +35,16 @@ std::vector<AnfNodePtr> SplitInputsForReduceScatter::InsertSplitForInput(const F
std::vector<TypeId> dtypes(rank_size, AnfAlgo::GetPrevNodeOutputInferDataType(node, i));
std::vector<std::vector<size_t>> shapes;
std::vector<int> size_splits;
for (size_t j = 0; j < IntToSize(rank_size); j++) {
for (size_t j = 0; j < rank_size_t; j++) {
std::vector<size_t> output_node_shape = AnfAlgo::GetPrevNodeOutputInferShape(node, i);
output_node_shape[0] /= rank_size;
output_node_shape[0] /= rank_size_t;
shapes.push_back(output_node_shape);
size_splits.push_back(output_node_shape[0]);
}
AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, split.get());
AnfAlgo::SetNodeAttr("split_dim", MakeValue(0L), split);
AnfAlgo::SetNodeAttr("num_split", MakeValue(SizeToInt(rank_size)), split);
AnfAlgo::SetNodeAttr("num_split", MakeValue(rank_size_t), split);
AnfAlgo::SetNodeAttr("size_splits", MakeValue(size_splits), split);
kernel_select_->SelectKernel(split);
std::vector<AnfNodePtr> new_outputs;
@ -63,8 +64,9 @@ AnfNodePtr SplitInputsForReduceScatter::RearrangeInputsForReduceScatter(const Fu
size_t inputs_size = AnfAlgo::GetInputTensorNum(node);
std::vector<AnfNodePtr> reduce_scatter_inputs{
NewValueNode(std::make_shared<Primitive>(prim::kPrimReduceScatter->name()))};
for (size_t i = 0; i < IntToSize(rank_size); i++) {
for (size_t j = 0, idx = i; j < inputs_size; j++, idx += IntToSize(rank_size)) {
size_t rank_size_t = LongToSize(rank_size);
for (size_t i = 0; i < rank_size_t; i++) {
for (size_t j = 0, idx = i; j < inputs_size; j++, idx += rank_size_t) {
reduce_scatter_inputs.push_back(inputs[idx]);
}
}

View File

@ -50,20 +50,21 @@ kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(const CommunicationOpInfo &co
if (AnfAlgo::HasNodeAttr(kAttrRankSize, cnode) && AnfAlgo::GetCNodeName(cnode) == kAllGatherOpName) {
rank_size = AnfAlgo::GetNodeAttr<int64_t>(cnode, kAttrRankSize);
}
size_t rank_size_t = LongToSize(rank_size);
MS_EXCEPTION_IF_NULL(cnode);
size_t input_num = AnfAlgo::GetInputTensorNum(cnode);
for (size_t input_index = 0; input_index < input_num; ++input_index) {
inputs_device_format.push_back(AnfAlgo::GetInputFormat(cnode, input_index));
inputs_device_type.push_back(AnfAlgo::GetInputDeviceDataType(cnode, input_index));
}
for (size_t rank_index = 0; rank_index < IntToSize(rank_size); ++rank_index) {
for (size_t rank_index = 0; rank_index < rank_size_t; ++rank_index) {
size_t output_num = AnfAlgo::GetOutputTensorNum(cnode);
for (size_t output_index = 0; output_index < output_num; ++output_index) {
outputs_device_format.push_back(AnfAlgo::GetOutputFormat(cnode, output_index));
outputs_device_type.push_back(AnfAlgo::GetOutputDeviceDataType(cnode, output_index));
std::vector<size_t> shape = AnfAlgo::GetOutputInferShape(cnode, output_index);
if (!shape.empty()) {
shape[0] /= rank_size;
shape[0] /= rank_size_t;
}
outputs_shape.push_back(AnfAlgo::GetOutputInferShape(cnode, output_index));
}
@ -315,16 +316,17 @@ AnfNodePtr CommunicationOpFusion::CreateFusedCommunicationOp(const FuncGraphPtr
if (AnfAlgo::HasNodeAttr(kAttrRankSize, final_node) && AnfAlgo::GetCNodeName(final_node) == kAllGatherOpName) {
rank_size = AnfAlgo::GetNodeAttr<int64_t>(final_node, kAttrRankSize);
}
size_t output_num = node_num * rank_size;
size_t rank_size_t = LongToSize(rank_size);
size_t output_num = node_num * rank_size_t;
std::vector<TypeId> dtypes(output_num, AnfAlgo::GetOutputInferDataType(final_node, 0));
std::vector<std::vector<size_t>> shapes;
for (size_t i = 0; i < IntToSize(rank_size); ++i) {
for (size_t i = 0; i < rank_size_t; ++i) {
for (size_t idx = start_index; idx <= end_index; ++idx) {
auto input_cnode = communication_op_info.communication_op_nodes[idx];
MS_EXCEPTION_IF_NULL(input_cnode);
std::vector<size_t> shape = AnfAlgo::GetOutputInferShape(input_cnode, 0);
if (!shape.empty()) {
shape[0] /= rank_size;
shape[0] /= rank_size_t;
}
shapes.push_back(shape);
}