diff --git a/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel_metadata.cc b/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel_metadata.cc index 7b94ca5e659..f6e2ea35d62 100755 --- a/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel_metadata.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel_metadata.cc @@ -39,7 +39,7 @@ std::string GetKernelFormat(const CNodePtr &kernel_node, size_t index) { if (parallel_context_instance->enable_parallel_optimizer() && op_name == kBroadcast) { return kOpFormat_DEFAULT; } - if (op_name == kReceive || op_name == kHcomSend) { + if (op_name == kReceive || op_name == kHcomSend || op_name == kAllToAllv) { return kOpFormat_DEFAULT; } auto format = AnfAlgo::GetPrevNodeOutputFormat(kernel_node, index); diff --git a/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_to_all.cc b/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_to_all.cc index 4fec41ac07a..49b6b32c58e 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_to_all.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_to_all.cc @@ -14,6 +14,9 @@ * limitations under the License. */ #include "backend/kernel_compiler/hccl/hcom_all_to_all.h" +#include "runtime/hccl_adapter/hccl_adapter.h" +#include "runtime/device/ascend/ge_runtime/task_info.h" +#include "backend/session/anf_runtime_algorithm.h" namespace mindspore::kernel { HcomAllToAllKernel::HcomAllToAllKernel() {} @@ -25,5 +28,87 @@ bool HcomAllToAllKernel::Launch(const std::vector &, const std::vect return true; } +bool HcomAllToAllKernel::Init(const AnfNodePtr &anf_node) { + bool ret = HcclKernel::Init(anf_node); + if (!ret) { + return ret; + } + + if (hccl_data_type_list_.empty()) { + auto recv_type = AnfAlgo::GetNodeAttr(anf_node, kAttrRecvType); + MS_EXCEPTION_IF_NULL(recv_type); + data_type_ = HcomUtil::ConvertHcclType(recv_type->type_id()); + } else { + data_type_ = hccl_data_type_list_[0]; + } + + workspace_size_list_ = {LongToSize(hccl::HcclAdapter::GetInstance().CalcWorkspaceSize(anf_node, data_type_))}; + return true; +} + +const std::vector &HcomAllToAllKernel::GetOutputSizeList() const { + if (!output_size_list_.empty()) { + return output_size_list_; + } + for (size_t i = 0; i < hccl_kernel_output_shape_list_.size(); ++i) { + size_t size = 0; + if (!HcomUtil::GetHcclOpSize(data_type_, hccl_kernel_output_shape_list_[i], &size)) { + MS_LOG(EXCEPTION) << "AllToAllv get output size failed."; + } + output_size_list_.push_back(size); + } + return output_size_list_; +} + +std::vector HcomAllToAllKernel::GenTask(const std::vector &inputs, + const std::vector &workspace, + const std::vector &outputs, uint32_t stream_id) { + auto anf_node = anf_node_.lock(); + if (!anf_node) { + MS_LOG(EXCEPTION) << "anf_node pointer is expired."; + } + + stream_id_ = stream_id; + void *input_data_addr = inputs.empty() ? nullptr : inputs.at(0)->addr; + void *output_data_addr = outputs.empty() ? nullptr : outputs.at(0)->addr; + + std::vector private_def; + std::vector task_info; + bool ret = hccl::HcclAdapter::GetInstance().GenTask(anf_node, data_type_, &task_info); + if (!ret) { + MS_LOG(EXCEPTION) << "Gen Task for " << anf_node->DebugString() << " failed."; + } + + std::vector results; + for (auto &task : task_info) { + MS_LOG(INFO) << "AlltoAll Task : stream_id=" << stream_id << ", count=" << hccl_count_ << ", root_id=" << root_id_ + << ", op_type=" << static_cast(op_type_) << ", data_type=" << static_cast(data_type_) + << ", workspace_size=" << task.workspace_size << ", stream_num=" << task.stream_num + << ", private_def_size=" << task.private_def.size(); + + private_def.resize(task.private_def.size()); + auto sec_ret = memcpy_s(private_def.data(), private_def.size(), task.private_def.data(), task.private_def.size()); + if (sec_ret != 0) { + MS_LOG(EXCEPTION) << "Set data memcpy_s failed, ret = " << sec_ret; + } + + void *workspace_addr = nullptr; + if (task.workspace_size != 0) { + if (workspace.empty()) { + MS_LOG(EXCEPTION) << "Workspace size list of " << anf_node->DebugString() << " is empty"; + } + MS_EXCEPTION_IF_NULL(workspace.at(0)); + workspace_addr = workspace.at(0)->addr; + } + + results.emplace_back(std::make_shared( + unique_name_, stream_id, hccl::HcclAdapter::GetHcclType(anf_node), input_data_addr, output_data_addr, + workspace_addr, task.workspace_size, task.stream_num, private_def, + hccl::HcclAdapter::GetInstance().GetHcclOpsKernelInfoStore(), hccl_count_, root_id_, op_type_, data_type_, group_, + NeedDump())); + } + + return results; +} MS_HCCL_REG_KERNEL(AllToAllv, HcomAllToAllKernel); } // namespace mindspore::kernel diff --git a/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_to_all.h b/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_to_all.h index 70721c504d1..1b77f5989df 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_to_all.h +++ b/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_to_all.h @@ -26,8 +26,15 @@ class HcomAllToAllKernel : public HcclKernel { public: HcomAllToAllKernel(); ~HcomAllToAllKernel() override; + bool Init(const AnfNodePtr &anf_node) override; bool Launch(const std::vector &, const std::vector &, const std::vector &, void *) override; + const std::vector &GetOutputSizeList() const override; + std::vector GenTask(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, uint32_t stream_id) override; + + private: + HcclDataType data_type_; }; } // namespace mindspore::kernel #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HCCL_HCOM_ALL_TO_ALL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_util.cc b/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_util.cc index 033f20ee234..b5c8763b000 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_util.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_util.cc @@ -54,6 +54,14 @@ bool HcomUtil::GetKernelOutputShape(const AnfNodePtr &anf_node, vectorsecond; +} + bool HcomUtil::GetHcomDataType(const AnfNodePtr &anf_node, vector *data_type_list) { MS_EXCEPTION_IF_NULL(anf_node); MS_EXCEPTION_IF_NULL(data_type_list); @@ -69,17 +77,14 @@ bool HcomUtil::GetHcomDataType(const AnfNodePtr &anf_node, vector } else { type_ptr = AnfAlgo::GetInputDeviceDataType(anf_node, i); } - auto iter = kConstOpHcomDataTypeMap.find(type_ptr); - if (iter == kConstOpHcomDataTypeMap.end()) { - MS_LOG(EXCEPTION) << "HcomDataType can't support Current Ascend Data Type : " << type_ptr; - } - data_type_list->emplace_back(iter->second); + data_type_list->emplace_back(ConvertHcclType(type_ptr)); } - auto type_base = *(std::begin(*data_type_list)); - if (std::any_of(data_type_list->begin(), data_type_list->end(), - [&type_base](HcclDataType type) { return type != type_base; })) { - MS_LOG(ERROR) << "hccl have different data type"; - return false; + if (!data_type_list->empty()) { + if (std::any_of(data_type_list->begin(), data_type_list->end(), + [&data_type_list](HcclDataType type) { return type != *(data_type_list->begin()); })) { + MS_LOG(ERROR) << "hccl have different data type"; + return false; + } } return true; } diff --git a/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_util.h b/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_util.h index c08c6762386..915554e110d 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_util.h +++ b/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_util.h @@ -59,6 +59,7 @@ class HcomUtil { public: static bool GetKernelInputShape(const AnfNodePtr &anf_node, vector> *hccl_kernel_shape_list); static bool GetKernelOutputShape(const AnfNodePtr &anf_node, vector> *hccl_kernel_shape_list); + static ::HcclDataType ConvertHcclType(TypeId type_id); static bool GetHcomDataType(const AnfNodePtr &anf_node, vector *data_type_list); static bool GetHcclOpSize(const HcclDataType &data_type, const vector &shape, size_t *size); static bool GetHcomTypeSize(const HcclDataType &data_type, uint32_t *size); diff --git a/mindspore/core/ops/neighborexchange.cc b/mindspore/core/ops/neighborexchange.cc index dcc0dc5d8e6..efa6a6ab46e 100644 --- a/mindspore/core/ops/neighborexchange.cc +++ b/mindspore/core/ops/neighborexchange.cc @@ -131,7 +131,7 @@ void Check(const PrimitivePtr &primitive, const std::vector &in } } -abstract::TupleShapePtr InferShape(const PrimitivePtr &primitive, const std::vector &input_args) { +abstract::BaseShapePtr InferShape(const PrimitivePtr &primitive) { MS_EXCEPTION_IF_NULL(primitive); auto recv_shapes = primitive->GetAttr(kRecvShapes); MS_EXCEPTION_IF_NULL(recv_shapes); @@ -147,15 +147,14 @@ abstract::TupleShapePtr InferShape(const PrimitivePtr &primitive, const std::vec MS_EXCEPTION_IF_NULL(base_shape); base_shape_list.push_back(base_shape); } + if (base_shape_list.empty()) { + return std::make_shared(); + } return std::make_shared(base_shape_list); } -TypePtr InferType(const PrimitivePtr &primitive, const std::vector &input_args) { +TypePtr InferType(const PrimitivePtr &primitive) { MS_EXCEPTION_IF_NULL(primitive); - auto prim_name = primitive->name(); - (void)CheckAndConvertUtils::CheckInteger("NeighborExchange infer", SizeToLong(input_args.size()), kEqual, 1, - prim_name); - MS_EXCEPTION_IF_NULL(input_args[0]); auto recv_shapes = primitive->GetAttr(kRecvShapes); MS_EXCEPTION_IF_NULL(recv_shapes); auto shapes_seq = recv_shapes->cast(); @@ -165,14 +164,17 @@ TypePtr InferType(const PrimitivePtr &primitive, const std::vectorGetAttr(kRecvType)->cast(); MS_EXCEPTION_IF_NULL(recv_type); std::vector type_vec(out_num, recv_type); + if (type_vec.empty()) { + return std::make_shared(); + } return std::make_shared(type_vec); } } // namespace AbstractBasePtr NeighborExchangeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, const std::vector &input_args) { Check(primitive, input_args); - auto type = InferType(primitive, input_args); - auto shape = InferShape(primitive, input_args); + auto type = InferType(primitive); + auto shape = InferShape(primitive); return abstract::MakeAbstract(shape, type); } REGISTER_PRIMITIVE_EVAL_IMPL(NeighborExchange, prim::kPrimNeighborExchange, NeighborExchangeInfer, nullptr, true); diff --git a/tests/ut/python/parallel/test_neighborexchange.py b/tests/ut/python/parallel/test_neighborexchange.py index f1d0003f51e..a2963186506 100644 --- a/tests/ut/python/parallel/test_neighborexchange.py +++ b/tests/ut/python/parallel/test_neighborexchange.py @@ -91,7 +91,51 @@ def test_NeighborExchange_single_input_success(): compile_net(net) -def test_NeighborExchage_empty_send_empty_recv_success(): +def test_NeighborExchange_empty_send_success(): + """ + Feature: NeighborExchange + Description: empty inputs, with valid arguments + Expectation: success + """ + context.set_auto_parallel_context(device_num=8, global_rank=0) + + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.alltoallv = NeighborExchange(send_rank_ids=[], recv_rank_ids=[1], recv_shapes=([1],), + send_shapes=(), recv_type=ms.float32) + + def construct(self, x1): + self.alltoallv() + return x1 + + net = Net() + _executor.compile(net, _x1) + + +def test_NeighborExchange_empty_recv_success(): + """ + Feature: NeighborExchange + Description: empty outputs, with valid arguments + Expectation: success + """ + context.set_auto_parallel_context(device_num=8, global_rank=0) + + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.alltoallv = NeighborExchange(send_rank_ids=[0], recv_rank_ids=[], recv_shapes=(), + send_shapes=([32, 16],), recv_type=ms.float32) + + def construct(self, x1): + self.alltoallv((x1,)) + return x1 + + net = Net() + _executor.compile(net, _x1) + + +def test_NeighborExchange_empty_send_empty_recv_success(): """ Feature: NeighborExchange Description: empty inputs and empty outputs, with valid arguments @@ -102,20 +146,18 @@ def test_NeighborExchage_empty_send_empty_recv_success(): class Net(nn.Cell): def __init__(self): super(Net, self).__init__() - self.alltoallv = NeighborExchange(send_rank_ids=[], recv_rank_ids=[], - recv_shapes=(), - send_shapes=(), recv_type=ms.float32, group=("str",)) + self.alltoallv = NeighborExchange(send_rank_ids=[], recv_rank_ids=[], recv_shapes=(), + send_shapes=(), recv_type=ms.float32) def construct(self, x1): self.alltoallv() return x1 net = Net() - with pytest.raises(TypeError): - _executor.compile(net, _x1) + _executor.compile(net, _x1) -def test_NeighborExchage_recv_shape_num_diff_with_recv_rank_size_failed(): +def test_NeighborExchange_recv_shape_num_diff_with_recv_rank_size_failed(): """ Feature: NeighborExchange Description: send_rank_ids and send_shapes are set as 1 input, but gives 2 @@ -143,7 +185,7 @@ def test_NeighborExchage_recv_shape_num_diff_with_recv_rank_size_failed(): compile_net(net) -def test_NeighborExchage_send_shape_num_diff_with_send_rank_size_failed(): +def test_NeighborExchange_send_shape_num_diff_with_send_rank_size_failed(): """ Feature: NeighborExchange Description: send_rank_ids is set as 2 inputs, but send_shapes are set as 1 input @@ -172,7 +214,7 @@ def test_NeighborExchage_send_shape_num_diff_with_send_rank_size_failed(): compile_net(net) -def test_NeighborExchage_send_shape_num_diff_with_input_num_failed(): +def test_NeighborExchange_send_shape_num_diff_with_input_num_failed(): """ Feature: NeighborExchange Description: send_rank_ids and send_shapes are set as 2 inputs, but has only 1 input @@ -201,7 +243,7 @@ def test_NeighborExchage_send_shape_num_diff_with_input_num_failed(): compile_net(net) -def test_NeighborExchage_send_shape_diff_with_input_shape_failed(): +def test_NeighborExchange_send_shape_diff_with_input_shape_failed(): """ Feature: NeighborExchange Description: send_shapes is set as [16, 16], but input is [32, 32] @@ -229,7 +271,7 @@ def test_NeighborExchage_send_shape_diff_with_input_shape_failed(): compile_net(net) -def test_NeighborExchage_attr_check_send_rank_ids_is_tuple_failed(): +def test_NeighborExchange_attr_check_send_rank_ids_is_tuple_failed(): """ Feature: NeighborExchange Description: send_rank_ids should be list, but a tuple is given @@ -252,7 +294,7 @@ def test_NeighborExchage_attr_check_send_rank_ids_is_tuple_failed(): _executor.compile(net, _x1) -def test_NeighborExchage_attr_check_send_rank_ids_is_float_failed(): +def test_NeighborExchange_attr_check_send_rank_ids_is_float_failed(): """ Feature: NeighborExchange Description: send_rank_ids should be int, but a float is given @@ -276,7 +318,7 @@ def test_NeighborExchage_attr_check_send_rank_ids_is_float_failed(): _executor.compile(net, _x1) -def test_NeighborExchage_attr_check_recv_rank_ids_is_tuple_failed(): +def test_NeighborExchange_attr_check_recv_rank_ids_is_tuple_failed(): """ Feature: NeighborExchange Description: recv_rank_ids should be list, but a tuple is given @@ -300,7 +342,7 @@ def test_NeighborExchage_attr_check_recv_rank_ids_is_tuple_failed(): _executor.compile(net, _x1) -def test_NeighborExchage_attr_check_recv_rank_ids_is_float_failed(): +def test_NeighborExchange_attr_check_recv_rank_ids_is_float_failed(): """ Feature: NeighborExchange Description: recv_rank_ids should be int, but a float is given @@ -324,7 +366,7 @@ def test_NeighborExchage_attr_check_recv_rank_ids_is_float_failed(): _executor.compile(net, _x1) -def test_NeighborExchage_attr_check_send_shape_not_tuple_failed(): +def test_NeighborExchange_attr_check_send_shape_not_tuple_failed(): """ Feature: NeighborExchange Description: send_shapes should be tuple(list), but a list is given @@ -348,7 +390,7 @@ def test_NeighborExchage_attr_check_send_shape_not_tuple_failed(): _executor.compile(net, _x1) -def test_NeighborExchage_attr_check_recv_type_numpy_failed(): +def test_NeighborExchange_attr_check_recv_type_numpy_failed(): """ Feature: NeighborExchange Description: recv_type should be mindspore type, but a numpy type is given @@ -372,7 +414,7 @@ def test_NeighborExchage_attr_check_recv_type_numpy_failed(): _executor.compile(net, _x1) -def test_NeighborExchage_attr_invalid_grpup_failed(): +def test_NeighborExchange_attr_invalid_grpup_failed(): """ Feature: NeighborExchange Description: group should be str, but a tuple is given