!16876 code clean for master

From: @alouhahahahaha
Reviewed-by: @zhoufeng54, @jjfeing, @kisnwang
Signed-off-by: @jjfeing, @kisnwang

commit 3fde1b74ec
@@ -169,7 +169,7 @@ const std::vector<size_t> &HcclKernel::GetOutputSizeList() const {
   }
   ulong loop_size = hccl_data_type_list_.size();
   if (AnfAlgo::GetInputTensorNum(anf_node) > 1 && op_name == kAllGatherOpName && fusion >= 1) {
-    loop_size *= rank_size;
+    loop_size *= static_cast<ulong>(rank_size);
   }
   if (op_name == kReduceScatterOpName && fusion >= 1) {
     loop_size = AnfAlgo::GetOutputTensorNum(anf_node);
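Note: the change above replaces an implicit sign conversion in loop_size *= rank_size with an explicit cast. A minimal standalone sketch of the pattern (the values and the main() harness are illustrative, not from the commit):

    #include <cstdint>

    int main() {
      unsigned long loop_size = 4;
      int64_t rank_size = 8;  // assumed validated elsewhere to be >= 1
      // A bare `loop_size *= rank_size` converts the signed operand
      // implicitly; the explicit cast states that intent.
      loop_size *= static_cast<unsigned long>(rank_size);
      return loop_size == 32 ? 0 : 1;
    }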
@@ -137,11 +137,11 @@ bool HcomUtil::GetHcomCount(const AnfNodePtr &anf_node, const vector<HcclDataTyp
         MS_LOG(ERROR) << "Get rank size failed";
         return false;
       }
-      int64_t actual_input_size = input_size;
+      size_t actual_input_size = input_size;
       if (AnfAlgo::HasNodeAttr(kAttrFusion, cnode) && AnfAlgo::GetNodeAttr<int64_t>(anf_node, kAttrFusion)) {
         actual_input_size = (input_size + align_size - 1 + filled_size) / align_size * align_size;
       }
-      block_size = actual_input_size / LongToSize(rank_size);
+      block_size = static_cast<uint64_t>(actual_input_size / LongToSize(rank_size));
       total_size = total_size + block_size;
     } else {
       if (AnfAlgo::GetCNodeName(anf_node) == kAllGatherOpName) {
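Note: the fusion branch rounds the input size up to an alignment boundary; (input_size + align_size - 1 + filled_size) / align_size * align_size is the standard integer round-up idiom. A self-contained sketch with assumed example values (AlignUp is a made-up name):

    #include <cassert>
    #include <cstddef>

    // Round n (plus any already-filled bytes) up to a multiple of align,
    // mirroring the arithmetic in the fused-input branch above.
    size_t AlignUp(size_t n, size_t align, size_t filled = 0) {
      return (n + align - 1 + filled) / align * align;
    }

    int main() {
      assert(AlignUp(100, 32) == 128);  // (100 + 31) / 32 * 32
      assert(AlignUp(128, 32) == 128);  // already aligned, unchanged
      return 0;
    }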
@@ -95,7 +95,7 @@ AnfNodePtr InsertConcatForOutput(const FuncGraphPtr &func_graph, const AnfNodePt
   size_t inputs_size = AnfAlgo::GetInputTensorNum(node);
   for (size_t i = 0; i < inputs_size; ++i) {
     std::vector<AnfNodePtr> concat_inputs{NewValueNode(std::make_shared<Primitive>(prim::kPrimConcat->name()))};
-    for (size_t j = 0, idx = i; j < IntToSize(rank_size); ++j, idx += inputs_size) {
+    for (size_t j = 0, idx = i; j < LongToSize(rank_size); ++j, idx += inputs_size) {
       concat_inputs.push_back(new_tuple_getitems[idx]);
     }
     auto concat = func_graph->NewCNode(concat_inputs);
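Note: rank_size is read via AnfAlgo::GetNodeAttr<int64_t>, so LongToSize is the conversion helper that matches its width, where IntToSize did not. A sketch of what such a checked narrowing typically does (assumed behavior; the real MindSpore helper may differ in details):

    #include <cstddef>
    #include <cstdint>
    #include <stdexcept>

    // Hypothetical stand-in for a LongToSize-style helper: reject
    // negative values, then narrow explicitly.
    size_t CheckedLongToSize(int64_t v) {
      if (v < 0) {
        throw std::out_of_range("negative value is not a valid size");
      }
      return static_cast<size_t>(v);
    }

    int main() { return CheckedLongToSize(8) == 8 ? 0 : 1; }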
@@ -26,6 +26,7 @@ std::vector<AnfNodePtr> SplitInputsForReduceScatter::InsertSplitForInput(const F
   MS_EXCEPTION_IF_NULL(func_graph);
   size_t inputs_size = AnfAlgo::GetInputTensorNum(node);
   std::vector<AnfNodePtr> split_outputs;
+  size_t rank_size_t = LongToSize(rank_size);
   for (size_t i = 0; i < inputs_size; i++) {
     std::vector<AnfNodePtr> split_inputs{NewValueNode(std::make_shared<Primitive>(prim::kPrimSplitV->name()))};
     split_inputs.push_back(AnfAlgo::GetInputNode(node, i));
@@ -34,16 +35,16 @@ std::vector<AnfNodePtr> SplitInputsForReduceScatter::InsertSplitForInput(const F
     std::vector<TypeId> dtypes(rank_size, AnfAlgo::GetPrevNodeOutputInferDataType(node, i));
     std::vector<std::vector<size_t>> shapes;
     std::vector<int> size_splits;
-    for (size_t j = 0; j < IntToSize(rank_size); j++) {
+    for (size_t j = 0; j < rank_size_t; j++) {
       std::vector<size_t> output_node_shape = AnfAlgo::GetPrevNodeOutputInferShape(node, i);
-      output_node_shape[0] /= rank_size;
+      output_node_shape[0] /= rank_size_t;
       shapes.push_back(output_node_shape);
       size_splits.push_back(output_node_shape[0]);
     }
     AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, split.get());

     AnfAlgo::SetNodeAttr("split_dim", MakeValue(0L), split);
-    AnfAlgo::SetNodeAttr("num_split", MakeValue(SizeToInt(rank_size)), split);
+    AnfAlgo::SetNodeAttr("num_split", MakeValue(rank_size_t), split);
     AnfAlgo::SetNodeAttr("size_splits", MakeValue(size_splits), split);
     kernel_select_->SelectKernel(split);
     std::vector<AnfNodePtr> new_outputs;
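Note: each SplitV output keeps the input shape except that dim 0 is divided by the rank count, and the same per-rank length feeds size_splits. Since size_t division truncates, dim 0 is implicitly assumed to be divisible by rank_size. An illustrative standalone computation (the shape and rank count are made up):

    #include <cassert>
    #include <cstddef>
    #include <vector>

    int main() {
      // An input of shape [8, 16] split for 4 ranks yields 4 outputs
      // of shape [2, 16]; each per-rank length also becomes a split size.
      std::vector<size_t> shape = {8, 16};
      size_t rank_size_t = 4;
      assert(shape[0] % rank_size_t == 0);  // truncation would be silent
      std::vector<std::vector<size_t>> shapes;
      std::vector<int> size_splits;
      for (size_t j = 0; j < rank_size_t; ++j) {
        std::vector<size_t> out = shape;
        out[0] /= rank_size_t;
        shapes.push_back(out);
        size_splits.push_back(static_cast<int>(out[0]));
      }
      assert(shapes.size() == 4 && shapes[0][0] == 2 && size_splits[0] == 2);
      return 0;
    }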
@@ -63,8 +64,9 @@ AnfNodePtr SplitInputsForReduceScatter::RearrangeInputsForReduceScatter(const Fu
   size_t inputs_size = AnfAlgo::GetInputTensorNum(node);
   std::vector<AnfNodePtr> reduce_scatter_inputs{
     NewValueNode(std::make_shared<Primitive>(prim::kPrimReduceScatter->name()))};
-  for (size_t i = 0; i < IntToSize(rank_size); i++) {
-    for (size_t j = 0, idx = i; j < inputs_size; j++, idx += IntToSize(rank_size)) {
+  size_t rank_size_t = LongToSize(rank_size);
+  for (size_t i = 0; i < rank_size_t; i++) {
+    for (size_t j = 0, idx = i; j < inputs_size; j++, idx += rank_size_t) {
       reduce_scatter_inputs.push_back(inputs[idx]);
     }
   }
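Note: the rewritten loops only hoist the conversion; the index pattern idx = i + j * rank_size_t is unchanged. It reorders the flat list of split chunks into rank-major order, as the standalone sketch below shows (the input layout is an assumption for illustration):

    #include <cstddef>
    #include <cstdio>
    #include <vector>

    int main() {
      // With 2 original inputs each split across 3 ranks, the flat list
      // is [in0_r0, in0_r1, in0_r2, in1_r0, in1_r1, in1_r2]; the nested
      // loop emits all rank-0 chunks, then rank-1, then rank-2.
      size_t inputs_size = 2, rank_size_t = 3;
      std::vector<int> inputs = {0, 1, 2, 3, 4, 5};
      for (size_t i = 0; i < rank_size_t; i++) {
        for (size_t j = 0, idx = i; j < inputs_size; j++, idx += rank_size_t) {
          std::printf("%d ", inputs[idx]);  // prints: 0 3 1 4 2 5
        }
      }
      std::printf("\n");
      return 0;
    }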
@@ -50,20 +50,21 @@ kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(const CommunicationOpInfo &co
   if (AnfAlgo::HasNodeAttr(kAttrRankSize, cnode) && AnfAlgo::GetCNodeName(cnode) == kAllGatherOpName) {
     rank_size = AnfAlgo::GetNodeAttr<int64_t>(cnode, kAttrRankSize);
   }
+  size_t rank_size_t = LongToSize(rank_size);
   MS_EXCEPTION_IF_NULL(cnode);
   size_t input_num = AnfAlgo::GetInputTensorNum(cnode);
   for (size_t input_index = 0; input_index < input_num; ++input_index) {
     inputs_device_format.push_back(AnfAlgo::GetInputFormat(cnode, input_index));
     inputs_device_type.push_back(AnfAlgo::GetInputDeviceDataType(cnode, input_index));
   }
-  for (size_t rank_index = 0; rank_index < IntToSize(rank_size); ++rank_index) {
+  for (size_t rank_index = 0; rank_index < rank_size_t; ++rank_index) {
     size_t output_num = AnfAlgo::GetOutputTensorNum(cnode);
     for (size_t output_index = 0; output_index < output_num; ++output_index) {
       outputs_device_format.push_back(AnfAlgo::GetOutputFormat(cnode, output_index));
       outputs_device_type.push_back(AnfAlgo::GetOutputDeviceDataType(cnode, output_index));
       std::vector<size_t> shape = AnfAlgo::GetOutputInferShape(cnode, output_index);
       if (!shape.empty()) {
-        shape[0] /= rank_size;
+        shape[0] /= rank_size_t;
       }
       outputs_shape.push_back(AnfAlgo::GetOutputInferShape(cnode, output_index));
     }
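Note: for a fused AllGather, the build info repeats every per-node output once per rank, so the output format/type lists grow to rank_size * output_num entries. A toy count with assumed values:

    #include <cassert>
    #include <cstddef>
    #include <vector>

    int main() {
      size_t rank_size_t = 4, output_num = 3;
      std::vector<size_t> outputs_device_type;  // stand-in for dtype ids
      for (size_t rank_index = 0; rank_index < rank_size_t; ++rank_index) {
        for (size_t output_index = 0; output_index < output_num; ++output_index) {
          outputs_device_type.push_back(output_index);
        }
      }
      assert(outputs_device_type.size() == rank_size_t * output_num);  // 12
      return 0;
    }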
@@ -315,16 +316,17 @@ AnfNodePtr CommunicationOpFusion::CreateFusedCommunicationOp(const FuncGraphPtr
   if (AnfAlgo::HasNodeAttr(kAttrRankSize, final_node) && AnfAlgo::GetCNodeName(final_node) == kAllGatherOpName) {
     rank_size = AnfAlgo::GetNodeAttr<int64_t>(final_node, kAttrRankSize);
   }
-  size_t output_num = node_num * rank_size;
+  size_t rank_size_t = LongToSize(rank_size);
+  size_t output_num = node_num * rank_size_t;
   std::vector<TypeId> dtypes(output_num, AnfAlgo::GetOutputInferDataType(final_node, 0));
   std::vector<std::vector<size_t>> shapes;
-  for (size_t i = 0; i < IntToSize(rank_size); ++i) {
+  for (size_t i = 0; i < rank_size_t; ++i) {
     for (size_t idx = start_index; idx <= end_index; ++idx) {
       auto input_cnode = communication_op_info.communication_op_nodes[idx];
       MS_EXCEPTION_IF_NULL(input_cnode);
       std::vector<size_t> shape = AnfAlgo::GetOutputInferShape(input_cnode, 0);
       if (!shape.empty()) {
-        shape[0] /= rank_size;
+        shape[0] /= rank_size_t;
       }
       shapes.push_back(shape);
     }
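Note: the fused node ends up with node_num * rank_size_t outputs, emitted rank-major to match the input rearrangement above, and dim 0 of each inferred shape is divided by the rank count. A toy check with assumed values:

    #include <cassert>
    #include <cstddef>
    #include <vector>

    int main() {
      // Fusing node_num = 3 communication ops over 4 ranks gives the
      // fused node 12 outputs; an [8, 16] input contributes per-rank
      // slices of shape [2, 16].
      size_t node_num = 3, rank_size_t = 4;
      size_t output_num = node_num * rank_size_t;
      assert(output_num == 12);
      std::vector<size_t> shape = {8, 16};
      shape[0] /= rank_size_t;
      assert(shape[0] == 2);
      return 0;
    }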