forked from mindspore-Ecosystem/mindspore
fix neighborexchange empty input case
Signed-off-by: zhoufeng <zhoufeng54@huawei.com>
This commit is contained in:
parent
c24dc871e0
commit
e5a1582e4b
|
@ -39,7 +39,7 @@ std::string GetKernelFormat(const CNodePtr &kernel_node, size_t index) {
|
|||
if (parallel_context_instance->enable_parallel_optimizer() && op_name == kBroadcast) {
|
||||
return kOpFormat_DEFAULT;
|
||||
}
|
||||
if (op_name == kReceive || op_name == kHcomSend) {
|
||||
if (op_name == kReceive || op_name == kHcomSend || op_name == kAllToAllv) {
|
||||
return kOpFormat_DEFAULT;
|
||||
}
|
||||
auto format = AnfAlgo::GetPrevNodeOutputFormat(kernel_node, index);
|
||||
|
|
|
@ -14,6 +14,9 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
#include "backend/kernel_compiler/hccl/hcom_all_to_all.h"
|
||||
#include "runtime/hccl_adapter/hccl_adapter.h"
|
||||
#include "runtime/device/ascend/ge_runtime/task_info.h"
|
||||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
|
||||
namespace mindspore::kernel {
|
||||
// Nothing to set up at construction time; data_type_ is assigned later by Init().
HcomAllToAllKernel::HcomAllToAllKernel() = default;
|
||||
|
@ -25,5 +28,87 @@ bool HcomAllToAllKernel::Launch(const std::vector<AddressPtr> &, const std::vect
|
|||
return true;
|
||||
}
|
||||
|
||||
bool HcomAllToAllKernel::Init(const AnfNodePtr &anf_node) {
|
||||
bool ret = HcclKernel::Init(anf_node);
|
||||
if (!ret) {
|
||||
return ret;
|
||||
}
|
||||
|
||||
if (hccl_data_type_list_.empty()) {
|
||||
auto recv_type = AnfAlgo::GetNodeAttr<TypePtr>(anf_node, kAttrRecvType);
|
||||
MS_EXCEPTION_IF_NULL(recv_type);
|
||||
data_type_ = HcomUtil::ConvertHcclType(recv_type->type_id());
|
||||
} else {
|
||||
data_type_ = hccl_data_type_list_[0];
|
||||
}
|
||||
|
||||
workspace_size_list_ = {LongToSize(hccl::HcclAdapter::GetInstance().CalcWorkspaceSize(anf_node, data_type_))};
|
||||
return true;
|
||||
}
|
||||
|
||||
const std::vector<size_t> &HcomAllToAllKernel::GetOutputSizeList() const {
|
||||
if (!output_size_list_.empty()) {
|
||||
return output_size_list_;
|
||||
}
|
||||
for (size_t i = 0; i < hccl_kernel_output_shape_list_.size(); ++i) {
|
||||
size_t size = 0;
|
||||
if (!HcomUtil::GetHcclOpSize(data_type_, hccl_kernel_output_shape_list_[i], &size)) {
|
||||
MS_LOG(EXCEPTION) << "AllToAllv get output size failed.";
|
||||
}
|
||||
output_size_list_.push_back(size);
|
||||
}
|
||||
return output_size_list_;
|
||||
}
|
||||
|
||||
std::vector<TaskInfoPtr> HcomAllToAllKernel::GenTask(const std::vector<AddressPtr> &inputs,
|
||||
const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs, uint32_t stream_id) {
|
||||
auto anf_node = anf_node_.lock();
|
||||
if (!anf_node) {
|
||||
MS_LOG(EXCEPTION) << "anf_node pointer is expired.";
|
||||
}
|
||||
|
||||
stream_id_ = stream_id;
|
||||
void *input_data_addr = inputs.empty() ? nullptr : inputs.at(0)->addr;
|
||||
void *output_data_addr = outputs.empty() ? nullptr : outputs.at(0)->addr;
|
||||
|
||||
std::vector<uint8_t> private_def;
|
||||
std::vector<hccl::HcclTaskInfo> task_info;
|
||||
bool ret = hccl::HcclAdapter::GetInstance().GenTask(anf_node, data_type_, &task_info);
|
||||
if (!ret) {
|
||||
MS_LOG(EXCEPTION) << "Gen Task for " << anf_node->DebugString() << " failed.";
|
||||
}
|
||||
|
||||
std::vector<TaskInfoPtr> results;
|
||||
for (auto &task : task_info) {
|
||||
MS_LOG(INFO) << "AlltoAll Task : stream_id=" << stream_id << ", count=" << hccl_count_ << ", root_id=" << root_id_
|
||||
<< ", op_type=" << static_cast<int>(op_type_) << ", data_type=" << static_cast<int>(data_type_)
|
||||
<< ", workspace_size=" << task.workspace_size << ", stream_num=" << task.stream_num
|
||||
<< ", private_def_size=" << task.private_def.size();
|
||||
|
||||
private_def.resize(task.private_def.size());
|
||||
auto sec_ret = memcpy_s(private_def.data(), private_def.size(), task.private_def.data(), task.private_def.size());
|
||||
if (sec_ret != 0) {
|
||||
MS_LOG(EXCEPTION) << "Set data memcpy_s failed, ret = " << sec_ret;
|
||||
}
|
||||
|
||||
void *workspace_addr = nullptr;
|
||||
if (task.workspace_size != 0) {
|
||||
if (workspace.empty()) {
|
||||
MS_LOG(EXCEPTION) << "Workspace size list of " << anf_node->DebugString() << " is empty";
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(workspace.at(0));
|
||||
workspace_addr = workspace.at(0)->addr;
|
||||
}
|
||||
|
||||
results.emplace_back(std::make_shared<ge::model_runner::HcclTaskInfo>(
|
||||
unique_name_, stream_id, hccl::HcclAdapter::GetHcclType(anf_node), input_data_addr, output_data_addr,
|
||||
workspace_addr, task.workspace_size, task.stream_num, private_def,
|
||||
hccl::HcclAdapter::GetInstance().GetHcclOpsKernelInfoStore(), hccl_count_, root_id_, op_type_, data_type_, group_,
|
||||
NeedDump()));
|
||||
}
|
||||
|
||||
return results;
|
||||
}
|
||||
MS_HCCL_REG_KERNEL(AllToAllv, HcomAllToAllKernel);
|
||||
} // namespace mindspore::kernel
|
||||
|
|
|
@ -26,8 +26,15 @@ class HcomAllToAllKernel : public HcclKernel {
|
|||
public:
|
||||
HcomAllToAllKernel();
|
||||
~HcomAllToAllKernel() override;
|
||||
bool Init(const AnfNodePtr &anf_node) override;
|
||||
bool Launch(const std::vector<AddressPtr> &, const std::vector<AddressPtr> &, const std::vector<AddressPtr> &,
|
||||
void *) override;
|
||||
const std::vector<size_t> &GetOutputSizeList() const override;
|
||||
std::vector<TaskInfoPtr> GenTask(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs, uint32_t stream_id) override;
|
||||
|
||||
private:
|
||||
HcclDataType data_type_;
|
||||
};
|
||||
} // namespace mindspore::kernel
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HCCL_HCOM_ALL_TO_ALL_H_
|
||||
|
|
|
@ -54,6 +54,14 @@ bool HcomUtil::GetKernelOutputShape(const AnfNodePtr &anf_node, vector<vector<si
|
|||
return true;
|
||||
}
|
||||
|
||||
// Translate a MindSpore TypeId into the matching HcclDataType using kConstOpHcomDataTypeMap.
// Raises via MS_LOG(EXCEPTION) when the type id has no entry in the map.
::HcclDataType HcomUtil::ConvertHcclType(TypeId type_id) {
  const auto found = kConstOpHcomDataTypeMap.find(type_id);
  const bool supported = (found != kConstOpHcomDataTypeMap.end());
  if (!supported) {
    MS_LOG(EXCEPTION) << "HcomDataType can't support Current Ascend Data Type : " << type_id;
  }
  return found->second;
}
|
||||
|
||||
bool HcomUtil::GetHcomDataType(const AnfNodePtr &anf_node, vector<HcclDataType> *data_type_list) {
|
||||
MS_EXCEPTION_IF_NULL(anf_node);
|
||||
MS_EXCEPTION_IF_NULL(data_type_list);
|
||||
|
@ -69,17 +77,14 @@ bool HcomUtil::GetHcomDataType(const AnfNodePtr &anf_node, vector<HcclDataType>
|
|||
} else {
|
||||
type_ptr = AnfAlgo::GetInputDeviceDataType(anf_node, i);
|
||||
}
|
||||
auto iter = kConstOpHcomDataTypeMap.find(type_ptr);
|
||||
if (iter == kConstOpHcomDataTypeMap.end()) {
|
||||
MS_LOG(EXCEPTION) << "HcomDataType can't support Current Ascend Data Type : " << type_ptr;
|
||||
}
|
||||
data_type_list->emplace_back(iter->second);
|
||||
data_type_list->emplace_back(ConvertHcclType(type_ptr));
|
||||
}
|
||||
auto type_base = *(std::begin(*data_type_list));
|
||||
if (std::any_of(data_type_list->begin(), data_type_list->end(),
|
||||
[&type_base](HcclDataType type) { return type != type_base; })) {
|
||||
MS_LOG(ERROR) << "hccl have different data type";
|
||||
return false;
|
||||
if (!data_type_list->empty()) {
|
||||
if (std::any_of(data_type_list->begin(), data_type_list->end(),
|
||||
[&data_type_list](HcclDataType type) { return type != *(data_type_list->begin()); })) {
|
||||
MS_LOG(ERROR) << "hccl have different data type";
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
|
|
@ -59,6 +59,7 @@ class HcomUtil {
|
|||
public:
|
||||
static bool GetKernelInputShape(const AnfNodePtr &anf_node, vector<vector<size_t>> *hccl_kernel_shape_list);
|
||||
static bool GetKernelOutputShape(const AnfNodePtr &anf_node, vector<vector<size_t>> *hccl_kernel_shape_list);
|
||||
static ::HcclDataType ConvertHcclType(TypeId type_id);
|
||||
static bool GetHcomDataType(const AnfNodePtr &anf_node, vector<HcclDataType> *data_type_list);
|
||||
static bool GetHcclOpSize(const HcclDataType &data_type, const vector<size_t> &shape, size_t *size);
|
||||
static bool GetHcomTypeSize(const HcclDataType &data_type, uint32_t *size);
|
||||
|
|
|
@ -131,7 +131,7 @@ void Check(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &in
|
|||
}
|
||||
}
|
||||
|
||||
abstract::TupleShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
abstract::BaseShapePtr InferShape(const PrimitivePtr &primitive) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto recv_shapes = primitive->GetAttr(kRecvShapes);
|
||||
MS_EXCEPTION_IF_NULL(recv_shapes);
|
||||
|
@ -147,15 +147,14 @@ abstract::TupleShapePtr InferShape(const PrimitivePtr &primitive, const std::vec
|
|||
MS_EXCEPTION_IF_NULL(base_shape);
|
||||
base_shape_list.push_back(base_shape);
|
||||
}
|
||||
if (base_shape_list.empty()) {
|
||||
return std::make_shared<abstract::Shape>();
|
||||
}
|
||||
return std::make_shared<abstract::TupleShape>(base_shape_list);
|
||||
}
|
||||
|
||||
TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
TypePtr InferType(const PrimitivePtr &primitive) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
(void)CheckAndConvertUtils::CheckInteger("NeighborExchange infer", SizeToLong(input_args.size()), kEqual, 1,
|
||||
prim_name);
|
||||
MS_EXCEPTION_IF_NULL(input_args[0]);
|
||||
auto recv_shapes = primitive->GetAttr(kRecvShapes);
|
||||
MS_EXCEPTION_IF_NULL(recv_shapes);
|
||||
auto shapes_seq = recv_shapes->cast<ValueSequeuePtr>();
|
||||
|
@ -165,14 +164,17 @@ TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBaseP
|
|||
auto recv_type = primitive->GetAttr(kRecvType)->cast<TypePtr>();
|
||||
MS_EXCEPTION_IF_NULL(recv_type);
|
||||
std::vector<TypePtr> type_vec(out_num, recv_type);
|
||||
if (type_vec.empty()) {
|
||||
return std::make_shared<TypeNone>();
|
||||
}
|
||||
return std::make_shared<Tuple>(type_vec);
|
||||
}
|
||||
} // namespace
|
||||
// Infer entry point for NeighborExchange.
// Validates the primitive and its inputs with Check(), then builds the output abstract.
// The output type and shape are derived from the primitive alone, so they do not
// depend on how many inputs the op receives.
AbstractBasePtr NeighborExchangeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                      const std::vector<AbstractBasePtr> &input_args) {
  Check(primitive, input_args);
  auto type = InferType(primitive);
  auto shape = InferShape(primitive);
  return abstract::MakeAbstract(shape, type);
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(NeighborExchange, prim::kPrimNeighborExchange, NeighborExchangeInfer, nullptr, true);
|
||||
|
|
|
@ -91,7 +91,51 @@ def test_NeighborExchange_single_input_success():
|
|||
compile_net(net)
|
||||
|
||||
|
||||
def test_NeighborExchage_empty_send_empty_recv_success():
|
||||
def test_NeighborExchange_empty_send_success():
    """
    Feature: NeighborExchange
    Description: empty inputs, with valid arguments
    Expectation: success
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)

    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            # Nothing is sent; a single tensor of shape [1] is received from rank 1.
            self.neighbor_exchange = NeighborExchange(send_rank_ids=[], send_shapes=(),
                                                      recv_rank_ids=[1], recv_shapes=([1],),
                                                      recv_type=ms.float32)

        def construct(self, x1):
            self.neighbor_exchange()
            return x1

    _executor.compile(Net(), _x1)
|
||||
|
||||
|
||||
def test_NeighborExchange_empty_recv_success():
    """
    Feature: NeighborExchange
    Description: empty outputs, with valid arguments
    Expectation: success
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)

    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            # One [32, 16] tensor is sent to rank 0; nothing is received.
            self.neighbor_exchange = NeighborExchange(send_rank_ids=[0], send_shapes=([32, 16],),
                                                      recv_rank_ids=[], recv_shapes=(),
                                                      recv_type=ms.float32)

        def construct(self, x1):
            self.neighbor_exchange((x1,))
            return x1

    _executor.compile(Net(), _x1)
|
||||
|
||||
|
||||
def test_NeighborExchange_empty_send_empty_recv_success():
|
||||
"""
|
||||
Feature: NeighborExchange
|
||||
Description: empty inputs and empty outputs, with valid arguments
|
||||
|
@ -102,20 +146,18 @@ def test_NeighborExchage_empty_send_empty_recv_success():
|
|||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.alltoallv = NeighborExchange(send_rank_ids=[], recv_rank_ids=[],
|
||||
recv_shapes=(),
|
||||
send_shapes=(), recv_type=ms.float32, group=("str",))
|
||||
self.alltoallv = NeighborExchange(send_rank_ids=[], recv_rank_ids=[], recv_shapes=(),
|
||||
send_shapes=(), recv_type=ms.float32)
|
||||
|
||||
def construct(self, x1):
|
||||
self.alltoallv()
|
||||
return x1
|
||||
|
||||
net = Net()
|
||||
with pytest.raises(TypeError):
|
||||
_executor.compile(net, _x1)
|
||||
_executor.compile(net, _x1)
|
||||
|
||||
|
||||
def test_NeighborExchage_recv_shape_num_diff_with_recv_rank_size_failed():
|
||||
def test_NeighborExchange_recv_shape_num_diff_with_recv_rank_size_failed():
|
||||
"""
|
||||
Feature: NeighborExchange
|
||||
Description: send_rank_ids and send_shapes are set as 1 input, but gives 2
|
||||
|
@ -143,7 +185,7 @@ def test_NeighborExchage_recv_shape_num_diff_with_recv_rank_size_failed():
|
|||
compile_net(net)
|
||||
|
||||
|
||||
def test_NeighborExchage_send_shape_num_diff_with_send_rank_size_failed():
|
||||
def test_NeighborExchange_send_shape_num_diff_with_send_rank_size_failed():
|
||||
"""
|
||||
Feature: NeighborExchange
|
||||
Description: send_rank_ids is set as 2 inputs, but send_shapes are set as 1 input
|
||||
|
@ -172,7 +214,7 @@ def test_NeighborExchage_send_shape_num_diff_with_send_rank_size_failed():
|
|||
compile_net(net)
|
||||
|
||||
|
||||
def test_NeighborExchage_send_shape_num_diff_with_input_num_failed():
|
||||
def test_NeighborExchange_send_shape_num_diff_with_input_num_failed():
|
||||
"""
|
||||
Feature: NeighborExchange
|
||||
Description: send_rank_ids and send_shapes are set as 2 inputs, but has only 1 input
|
||||
|
@ -201,7 +243,7 @@ def test_NeighborExchage_send_shape_num_diff_with_input_num_failed():
|
|||
compile_net(net)
|
||||
|
||||
|
||||
def test_NeighborExchage_send_shape_diff_with_input_shape_failed():
|
||||
def test_NeighborExchange_send_shape_diff_with_input_shape_failed():
|
||||
"""
|
||||
Feature: NeighborExchange
|
||||
Description: send_shapes is set as [16, 16], but input is [32, 32]
|
||||
|
@ -229,7 +271,7 @@ def test_NeighborExchage_send_shape_diff_with_input_shape_failed():
|
|||
compile_net(net)
|
||||
|
||||
|
||||
def test_NeighborExchage_attr_check_send_rank_ids_is_tuple_failed():
|
||||
def test_NeighborExchange_attr_check_send_rank_ids_is_tuple_failed():
|
||||
"""
|
||||
Feature: NeighborExchange
|
||||
Description: send_rank_ids should be list, but a tuple is given
|
||||
|
@ -252,7 +294,7 @@ def test_NeighborExchage_attr_check_send_rank_ids_is_tuple_failed():
|
|||
_executor.compile(net, _x1)
|
||||
|
||||
|
||||
def test_NeighborExchage_attr_check_send_rank_ids_is_float_failed():
|
||||
def test_NeighborExchange_attr_check_send_rank_ids_is_float_failed():
|
||||
"""
|
||||
Feature: NeighborExchange
|
||||
Description: send_rank_ids should be int, but a float is given
|
||||
|
@ -276,7 +318,7 @@ def test_NeighborExchage_attr_check_send_rank_ids_is_float_failed():
|
|||
_executor.compile(net, _x1)
|
||||
|
||||
|
||||
def test_NeighborExchage_attr_check_recv_rank_ids_is_tuple_failed():
|
||||
def test_NeighborExchange_attr_check_recv_rank_ids_is_tuple_failed():
|
||||
"""
|
||||
Feature: NeighborExchange
|
||||
Description: recv_rank_ids should be list, but a tuple is given
|
||||
|
@ -300,7 +342,7 @@ def test_NeighborExchage_attr_check_recv_rank_ids_is_tuple_failed():
|
|||
_executor.compile(net, _x1)
|
||||
|
||||
|
||||
def test_NeighborExchage_attr_check_recv_rank_ids_is_float_failed():
|
||||
def test_NeighborExchange_attr_check_recv_rank_ids_is_float_failed():
|
||||
"""
|
||||
Feature: NeighborExchange
|
||||
Description: recv_rank_ids should be int, but a float is given
|
||||
|
@ -324,7 +366,7 @@ def test_NeighborExchage_attr_check_recv_rank_ids_is_float_failed():
|
|||
_executor.compile(net, _x1)
|
||||
|
||||
|
||||
def test_NeighborExchage_attr_check_send_shape_not_tuple_failed():
|
||||
def test_NeighborExchange_attr_check_send_shape_not_tuple_failed():
|
||||
"""
|
||||
Feature: NeighborExchange
|
||||
Description: send_shapes should be tuple(list), but a list is given
|
||||
|
@ -348,7 +390,7 @@ def test_NeighborExchage_attr_check_send_shape_not_tuple_failed():
|
|||
_executor.compile(net, _x1)
|
||||
|
||||
|
||||
def test_NeighborExchage_attr_check_recv_type_numpy_failed():
|
||||
def test_NeighborExchange_attr_check_recv_type_numpy_failed():
|
||||
"""
|
||||
Feature: NeighborExchange
|
||||
Description: recv_type should be mindspore type, but a numpy type is given
|
||||
|
@ -372,7 +414,7 @@ def test_NeighborExchage_attr_check_recv_type_numpy_failed():
|
|||
_executor.compile(net, _x1)
|
||||
|
||||
|
||||
def test_NeighborExchage_attr_invalid_grpup_failed():
|
||||
def test_NeighborExchange_attr_invalid_grpup_failed():
|
||||
"""
|
||||
Feature: NeighborExchange
|
||||
Description: group should be str, but a tuple is given
|
||||
|
|
Loading…
Reference in New Issue