Fix NeighborExchange empty input case

Signed-off-by: zhoufeng <zhoufeng54@huawei.com>
zhoufeng 2021-08-25 20:36:57 +08:00
parent c24dc871e0
commit e5a1582e4b
7 changed files with 178 additions and 36 deletions

View File

@ -39,7 +39,7 @@ std::string GetKernelFormat(const CNodePtr &kernel_node, size_t index) {
if (parallel_context_instance->enable_parallel_optimizer() && op_name == kBroadcast) {
return kOpFormat_DEFAULT;
}
if (op_name == kReceive || op_name == kHcomSend) {
if (op_name == kReceive || op_name == kHcomSend || op_name == kAllToAllv) {
return kOpFormat_DEFAULT;
}
auto format = AnfAlgo::GetPrevNodeOutputFormat(kernel_node, index);
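As context for the hunk above, a minimal, self-contained sketch of the rule it extends: collective ops whose buffers are exchanged as flat streams, now including AllToAllv, always use the default format instead of inheriting the previous node's output format. The function name PickFormat, the op strings, and the format strings are illustrative stand-ins, not the MindSpore sources.

#include <iostream>
#include <string>
#include <unordered_set>

// Illustrative only: ops whose device buffers are treated as flat streams.
const std::unordered_set<std::string> kDefaultFormatOps = {"Receive", "Send", "AllToAllv"};

// Return the format a kernel should use: default for the ops above,
// otherwise fall back to whatever the producing node emitted.
std::string PickFormat(const std::string &op_name, const std::string &prev_output_format) {
  if (kDefaultFormatOps.count(op_name) > 0) {
    return "DefaultFormat";
  }
  return prev_output_format;
}

int main() {
  std::cout << PickFormat("AllToAllv", "FRACTAL_NZ") << "\n";  // DefaultFormat
  std::cout << PickFormat("MatMul", "FRACTAL_NZ") << "\n";     // FRACTAL_NZ
  return 0;
}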

View File

@ -14,6 +14,9 @@
* limitations under the License.
*/
#include "backend/kernel_compiler/hccl/hcom_all_to_all.h"
#include "runtime/hccl_adapter/hccl_adapter.h"
#include "runtime/device/ascend/ge_runtime/task_info.h"
#include "backend/session/anf_runtime_algorithm.h"
namespace mindspore::kernel {
HcomAllToAllKernel::HcomAllToAllKernel() {}
@ -25,5 +28,87 @@ bool HcomAllToAllKernel::Launch(const std::vector<AddressPtr> &, const std::vect
return true;
}
bool HcomAllToAllKernel::Init(const AnfNodePtr &anf_node) {
bool ret = HcclKernel::Init(anf_node);
if (!ret) {
return ret;
}
if (hccl_data_type_list_.empty()) {
auto recv_type = AnfAlgo::GetNodeAttr<TypePtr>(anf_node, kAttrRecvType);
MS_EXCEPTION_IF_NULL(recv_type);
data_type_ = HcomUtil::ConvertHcclType(recv_type->type_id());
} else {
data_type_ = hccl_data_type_list_[0];
}
workspace_size_list_ = {LongToSize(hccl::HcclAdapter::GetInstance().CalcWorkspaceSize(anf_node, data_type_))};
return true;
}
const std::vector<size_t> &HcomAllToAllKernel::GetOutputSizeList() const {
if (!output_size_list_.empty()) {
return output_size_list_;
}
for (size_t i = 0; i < hccl_kernel_output_shape_list_.size(); ++i) {
size_t size = 0;
if (!HcomUtil::GetHcclOpSize(data_type_, hccl_kernel_output_shape_list_[i], &size)) {
MS_LOG(EXCEPTION) << "AllToAllv get output size failed.";
}
output_size_list_.push_back(size);
}
return output_size_list_;
}
std::vector<TaskInfoPtr> HcomAllToAllKernel::GenTask(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, uint32_t stream_id) {
auto anf_node = anf_node_.lock();
if (!anf_node) {
MS_LOG(EXCEPTION) << "anf_node pointer is expired.";
}
stream_id_ = stream_id;
void *input_data_addr = inputs.empty() ? nullptr : inputs.at(0)->addr;
void *output_data_addr = outputs.empty() ? nullptr : outputs.at(0)->addr;
std::vector<uint8_t> private_def;
std::vector<hccl::HcclTaskInfo> task_info;
bool ret = hccl::HcclAdapter::GetInstance().GenTask(anf_node, data_type_, &task_info);
if (!ret) {
MS_LOG(EXCEPTION) << "Gen Task for " << anf_node->DebugString() << " failed.";
}
std::vector<TaskInfoPtr> results;
for (auto &task : task_info) {
MS_LOG(INFO) << "AlltoAll Task : stream_id=" << stream_id << ", count=" << hccl_count_ << ", root_id=" << root_id_
<< ", op_type=" << static_cast<int>(op_type_) << ", data_type=" << static_cast<int>(data_type_)
<< ", workspace_size=" << task.workspace_size << ", stream_num=" << task.stream_num
<< ", private_def_size=" << task.private_def.size();
private_def.resize(task.private_def.size());
auto sec_ret = memcpy_s(private_def.data(), private_def.size(), task.private_def.data(), task.private_def.size());
if (sec_ret != 0) {
MS_LOG(EXCEPTION) << "Set data memcpy_s failed, ret = " << sec_ret;
}
void *workspace_addr = nullptr;
if (task.workspace_size != 0) {
if (workspace.empty()) {
MS_LOG(EXCEPTION) << "Workspace size list of " << anf_node->DebugString() << " is empty";
}
MS_EXCEPTION_IF_NULL(workspace.at(0));
workspace_addr = workspace.at(0)->addr;
}
results.emplace_back(std::make_shared<ge::model_runner::HcclTaskInfo>(
unique_name_, stream_id, hccl::HcclAdapter::GetHcclType(anf_node), input_data_addr, output_data_addr,
workspace_addr, task.workspace_size, task.stream_num, private_def,
hccl::HcclAdapter::GetInstance().GetHcclOpsKernelInfoStore(), hccl_count_, root_id_, op_type_, data_type_, group_,
NeedDump()));
}
return results;
}
MS_HCCL_REG_KERNEL(AllToAllv, HcomAllToAllKernel);
} // namespace mindspore::kernel
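The central empty-input handling in Init above is the fallback to the recv_type attribute when the kernel has no inputs (empty send list), since the data type can no longer be read from the input type list. A minimal standalone sketch of that decision, using illustrative names rather than the MindSpore API:

#include <cassert>
#include <iostream>
#include <vector>

// Illustrative stand-in for HcclDataType.
enum class DataType { kFloat16, kFloat32, kInt32 };

// Pick the data type for the collective: first input type if any inputs
// exist, otherwise the type the op was configured to receive.
DataType PickDataType(const std::vector<DataType> &input_types, DataType recv_type) {
  if (input_types.empty()) {
    return recv_type;  // empty send list: fall back to the recv_type attribute
  }
  return input_types.front();
}

int main() {
  std::vector<DataType> no_inputs;
  assert(PickDataType(no_inputs, DataType::kFloat32) == DataType::kFloat32);

  std::vector<DataType> some_inputs = {DataType::kFloat16};
  assert(PickDataType(some_inputs, DataType::kFloat32) == DataType::kFloat16);
  std::cout << "ok\n";
  return 0;
}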

View File

@ -26,8 +26,15 @@ class HcomAllToAllKernel : public HcclKernel {
public:
HcomAllToAllKernel();
~HcomAllToAllKernel() override;
bool Init(const AnfNodePtr &anf_node) override;
bool Launch(const std::vector<AddressPtr> &, const std::vector<AddressPtr> &, const std::vector<AddressPtr> &,
void *) override;
const std::vector<size_t> &GetOutputSizeList() const override;
std::vector<TaskInfoPtr> GenTask(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, uint32_t stream_id) override;
private:
HcclDataType data_type_;
};
} // namespace mindspore::kernel
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HCCL_HCOM_ALL_TO_ALL_H_

View File

@ -54,6 +54,14 @@ bool HcomUtil::GetKernelOutputShape(const AnfNodePtr &anf_node, vector<vector<si
return true;
}
::HcclDataType HcomUtil::ConvertHcclType(TypeId type_id) {
auto iter = kConstOpHcomDataTypeMap.find(type_id);
if (iter == kConstOpHcomDataTypeMap.end()) {
MS_LOG(EXCEPTION) << "HcomDataType can't support Current Ascend Data Type : " << type_id;
}
return iter->second;
}
bool HcomUtil::GetHcomDataType(const AnfNodePtr &anf_node, vector<HcclDataType> *data_type_list) {
MS_EXCEPTION_IF_NULL(anf_node);
MS_EXCEPTION_IF_NULL(data_type_list);
@ -69,17 +77,14 @@ bool HcomUtil::GetHcomDataType(const AnfNodePtr &anf_node, vector<HcclDataType>
} else {
type_ptr = AnfAlgo::GetInputDeviceDataType(anf_node, i);
}
auto iter = kConstOpHcomDataTypeMap.find(type_ptr);
if (iter == kConstOpHcomDataTypeMap.end()) {
MS_LOG(EXCEPTION) << "HcomDataType can't support Current Ascend Data Type : " << type_ptr;
}
data_type_list->emplace_back(iter->second);
data_type_list->emplace_back(ConvertHcclType(type_ptr));
}
auto type_base = *(std::begin(*data_type_list));
if (std::any_of(data_type_list->begin(), data_type_list->end(),
[&type_base](HcclDataType type) { return type != type_base; })) {
MS_LOG(ERROR) << "hccl have different data type";
return false;
if (!data_type_list->empty()) {
if (std::any_of(data_type_list->begin(), data_type_list->end(),
[&data_type_list](HcclDataType type) { return type != *(data_type_list->begin()); })) {
MS_LOG(ERROR) << "hccl have different data type";
return false;
}
}
return true;
}
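The rework above also stops dereferencing begin() of a possibly empty type list; the consistency check now runs only when the list is non-empty. A minimal sketch of the same pattern (the helper name AllSame is illustrative, not MindSpore code):

#include <algorithm>
#include <cassert>
#include <vector>

// Return true when every element equals the first one; an empty list is
// trivially consistent, so begin() of an empty vector is never dereferenced.
template <typename T>
bool AllSame(const std::vector<T> &values) {
  if (values.empty()) {
    return true;
  }
  const T &first = values.front();
  return std::all_of(values.begin(), values.end(), [&first](const T &v) { return v == first; });
}

int main() {
  assert(AllSame(std::vector<int>{}));         // empty: consistent by definition
  assert(AllSame(std::vector<int>{3, 3, 3}));  // all equal
  assert(!AllSame(std::vector<int>{3, 4}));    // mismatch
  return 0;
}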

View File

@ -59,6 +59,7 @@ class HcomUtil {
public:
static bool GetKernelInputShape(const AnfNodePtr &anf_node, vector<vector<size_t>> *hccl_kernel_shape_list);
static bool GetKernelOutputShape(const AnfNodePtr &anf_node, vector<vector<size_t>> *hccl_kernel_shape_list);
static ::HcclDataType ConvertHcclType(TypeId type_id);
static bool GetHcomDataType(const AnfNodePtr &anf_node, vector<HcclDataType> *data_type_list);
static bool GetHcclOpSize(const HcclDataType &data_type, const vector<size_t> &shape, size_t *size);
static bool GetHcomTypeSize(const HcclDataType &data_type, uint32_t *size);

View File

@ -131,7 +131,7 @@ void Check(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &in
}
}
abstract::TupleShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
abstract::BaseShapePtr InferShape(const PrimitivePtr &primitive) {
MS_EXCEPTION_IF_NULL(primitive);
auto recv_shapes = primitive->GetAttr(kRecvShapes);
MS_EXCEPTION_IF_NULL(recv_shapes);
@ -147,15 +147,14 @@ abstract::TupleShapePtr InferShape(const PrimitivePtr &primitive, const std::vec
MS_EXCEPTION_IF_NULL(base_shape);
base_shape_list.push_back(base_shape);
}
if (base_shape_list.empty()) {
return std::make_shared<abstract::Shape>();
}
return std::make_shared<abstract::TupleShape>(base_shape_list);
}
TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
TypePtr InferType(const PrimitivePtr &primitive) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
(void)CheckAndConvertUtils::CheckInteger("NeighborExchange infer", SizeToLong(input_args.size()), kEqual, 1,
prim_name);
MS_EXCEPTION_IF_NULL(input_args[0]);
auto recv_shapes = primitive->GetAttr(kRecvShapes);
MS_EXCEPTION_IF_NULL(recv_shapes);
auto shapes_seq = recv_shapes->cast<ValueSequeuePtr>();
@ -165,14 +164,17 @@ TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBaseP
auto recv_type = primitive->GetAttr(kRecvType)->cast<TypePtr>();
MS_EXCEPTION_IF_NULL(recv_type);
std::vector<TypePtr> type_vec(out_num, recv_type);
if (type_vec.empty()) {
return std::make_shared<TypeNone>();
}
return std::make_shared<Tuple>(type_vec);
}
} // namespace
AbstractBasePtr NeighborExchangeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
Check(primitive, input_args);
auto type = InferType(primitive, input_args);
auto shape = InferShape(primitive, input_args);
auto type = InferType(primitive);
auto shape = InferShape(primitive);
return abstract::MakeAbstract(shape, type);
}
REGISTER_PRIMITIVE_EVAL_IMPL(NeighborExchange, prim::kPrimNeighborExchange, NeighborExchangeInfer, nullptr, true);
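A generic illustration of the empty-receive handling added to InferShape and InferType above: when there are no receive shapes, the op's output is described by a None-like placeholder rather than a tuple built from an empty list. Here std::optional stands in for the abstract Shape/TypeNone machinery; this is a sketch, not the MindSpore abstract API.

#include <iostream>
#include <optional>
#include <vector>

// Describe the op's output: nullopt plays the role of the None/empty-shape
// placeholder returned when nothing is received; otherwise a tuple of shapes.
std::optional<std::vector<std::vector<int>>> InferOutputs(
    const std::vector<std::vector<int>> &recv_shapes) {
  if (recv_shapes.empty()) {
    return std::nullopt;  // empty recv list: no tuple output to build
  }
  return recv_shapes;  // one output shape per receiving rank
}

int main() {
  auto none_case = InferOutputs({});
  std::cout << (none_case.has_value() ? "tuple" : "none") << "\n";  // none

  auto tuple_case = InferOutputs({{32, 16}, {1}});
  std::cout << (tuple_case.has_value() ? "tuple" : "none") << "\n";  // tuple
  return 0;
}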

View File

@ -91,7 +91,51 @@ def test_NeighborExchange_single_input_success():
compile_net(net)
def test_NeighborExchage_empty_send_empty_recv_success():
def test_NeighborExchange_empty_send_success():
"""
Feature: NeighborExchange
Description: empty inputs, with valid arguments
Expectation: success
"""
context.set_auto_parallel_context(device_num=8, global_rank=0)
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.alltoallv = NeighborExchange(send_rank_ids=[], recv_rank_ids=[1], recv_shapes=([1],),
send_shapes=(), recv_type=ms.float32)
def construct(self, x1):
self.alltoallv()
return x1
net = Net()
_executor.compile(net, _x1)
def test_NeighborExchange_empty_recv_success():
"""
Feature: NeighborExchange
Description: empty outputs, with valid arguments
Expectation: success
"""
context.set_auto_parallel_context(device_num=8, global_rank=0)
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.alltoallv = NeighborExchange(send_rank_ids=[0], recv_rank_ids=[], recv_shapes=(),
send_shapes=([32, 16],), recv_type=ms.float32)
def construct(self, x1):
self.alltoallv((x1,))
return x1
net = Net()
_executor.compile(net, _x1)
def test_NeighborExchange_empty_send_empty_recv_success():
"""
Feature: NeighborExchange
Description: empty inputs and empty outputs, with valid arguments
@ -102,20 +146,18 @@ def test_NeighborExchage_empty_send_empty_recv_success():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.alltoallv = NeighborExchange(send_rank_ids=[], recv_rank_ids=[],
recv_shapes=(),
send_shapes=(), recv_type=ms.float32, group=("str",))
self.alltoallv = NeighborExchange(send_rank_ids=[], recv_rank_ids=[], recv_shapes=(),
send_shapes=(), recv_type=ms.float32)
def construct(self, x1):
self.alltoallv()
return x1
net = Net()
with pytest.raises(TypeError):
_executor.compile(net, _x1)
_executor.compile(net, _x1)
def test_NeighborExchage_recv_shape_num_diff_with_recv_rank_size_failed():
def test_NeighborExchange_recv_shape_num_diff_with_recv_rank_size_failed():
"""
Feature: NeighborExchange
Description: send_rank_ids and send_shapes are set as 1 input, but gives 2
@ -143,7 +185,7 @@ def test_NeighborExchage_recv_shape_num_diff_with_recv_rank_size_failed():
compile_net(net)
def test_NeighborExchage_send_shape_num_diff_with_send_rank_size_failed():
def test_NeighborExchange_send_shape_num_diff_with_send_rank_size_failed():
"""
Feature: NeighborExchange
Description: send_rank_ids is set as 2 inputs, but send_shapes are set as 1 input
@ -172,7 +214,7 @@ def test_NeighborExchage_send_shape_num_diff_with_send_rank_size_failed():
compile_net(net)
def test_NeighborExchage_send_shape_num_diff_with_input_num_failed():
def test_NeighborExchange_send_shape_num_diff_with_input_num_failed():
"""
Feature: NeighborExchange
Description: send_rank_ids and send_shapes are set as 2 inputs, but has only 1 input
@ -201,7 +243,7 @@ def test_NeighborExchage_send_shape_num_diff_with_input_num_failed():
compile_net(net)
def test_NeighborExchage_send_shape_diff_with_input_shape_failed():
def test_NeighborExchange_send_shape_diff_with_input_shape_failed():
"""
Feature: NeighborExchange
Description: send_shapes is set as [16, 16], but input is [32, 32]
@ -229,7 +271,7 @@ def test_NeighborExchage_send_shape_diff_with_input_shape_failed():
compile_net(net)
def test_NeighborExchage_attr_check_send_rank_ids_is_tuple_failed():
def test_NeighborExchange_attr_check_send_rank_ids_is_tuple_failed():
"""
Feature: NeighborExchange
Description: send_rank_ids should be list, but a tuple is given
@ -252,7 +294,7 @@ def test_NeighborExchage_attr_check_send_rank_ids_is_tuple_failed():
_executor.compile(net, _x1)
def test_NeighborExchage_attr_check_send_rank_ids_is_float_failed():
def test_NeighborExchange_attr_check_send_rank_ids_is_float_failed():
"""
Feature: NeighborExchange
Description: send_rank_ids should be int, but a float is given
@ -276,7 +318,7 @@ def test_NeighborExchage_attr_check_send_rank_ids_is_float_failed():
_executor.compile(net, _x1)
def test_NeighborExchage_attr_check_recv_rank_ids_is_tuple_failed():
def test_NeighborExchange_attr_check_recv_rank_ids_is_tuple_failed():
"""
Feature: NeighborExchange
Description: recv_rank_ids should be list, but a tuple is given
@ -300,7 +342,7 @@ def test_NeighborExchage_attr_check_recv_rank_ids_is_tuple_failed():
_executor.compile(net, _x1)
def test_NeighborExchage_attr_check_recv_rank_ids_is_float_failed():
def test_NeighborExchange_attr_check_recv_rank_ids_is_float_failed():
"""
Feature: NeighborExchange
Description: recv_rank_ids should be int, but a float is given
@ -324,7 +366,7 @@ def test_NeighborExchage_attr_check_recv_rank_ids_is_float_failed():
_executor.compile(net, _x1)
def test_NeighborExchage_attr_check_send_shape_not_tuple_failed():
def test_NeighborExchange_attr_check_send_shape_not_tuple_failed():
"""
Feature: NeighborExchange
Description: send_shapes should be tuple(list), but a list is given
@ -348,7 +390,7 @@ def test_NeighborExchage_attr_check_send_shape_not_tuple_failed():
_executor.compile(net, _x1)
def test_NeighborExchage_attr_check_recv_type_numpy_failed():
def test_NeighborExchange_attr_check_recv_type_numpy_failed():
"""
Feature: NeighborExchange
Description: recv_type should be mindspore type, but a numpy type is given
@ -372,7 +414,7 @@ def test_NeighborExchage_attr_check_recv_type_numpy_failed():
_executor.compile(net, _x1)
def test_NeighborExchage_attr_invalid_grpup_failed():
def test_NeighborExchange_attr_invalid_grpup_failed():
"""
Feature: NeighborExchange
Description: group should be str, but a tuple is given