gatherv2_support_host_and_device

This commit is contained in:
lichenever 2020-06-01 10:32:32 +08:00
parent b9ba99bb13
commit 1437966c98
5 changed files with 165 additions and 11 deletions

View File

@ -44,6 +44,18 @@ Status GatherV2PInfo::GetAttrs() {
}
axis_ = axis;
// get target
auto target_iter = attrs_.find(TARGET);
if (target_iter != attrs_.end()) {
MS_EXCEPTION_IF_NULL(target_iter->second);
if (target_iter->second->isa<StringImm>()) {
target_ = target_iter->second->cast<StringImmPtr>()->value();
} else {
MS_LOG(ERROR) << name_ << " : The value of target is not a string.";
return FAILED;
}
}
return SUCCESS;
}
@ -61,8 +73,8 @@ Status GatherV2PInfo::CheckStrategy(const StrategyPtr &strategy) {
auto param_shape = inputs_shape_.at(0);
auto param_strategy = strategy->GetInputDim().at(0);
auto slice_shape = param_shape.at(param_shape.size() - 1) / param_strategy.at(param_strategy.size() - 1);
if (slice_shape % 8 != 0) {
MS_LOG(ERROR) << name_ << ": Last dim of param slice shape need 32Byte aligned.";
if (slice_shape % 8 != 0 && slice_shape != 1) {
MS_LOG(DEBUG) << name_ << ": Last dim of param slice shape need 32Byte aligned.";
return FAILED;
}
@ -74,20 +86,20 @@ Status GatherV2PInfo::CheckStrategy(const StrategyPtr &strategy) {
// don't support scalar index
if (inputs_shape_.at(1).size() == 0) {
MS_LOG(ERROR) << name_ << ": Don't support scalar index.";
MS_LOG(DEBUG) << name_ << ": Don't support scalar index.";
return FAILED;
}
// axis=0, index_shape(0)%param_strategy(0) must be 0
Shape index_shape = inputs_shape_.at(1);
if ((axis_ == 0) && (index_shape.at(0) % param_strategy.at(0) != 0)) {
MS_LOG(ERROR) << name_ << ": index_shape(0) can't be divided by param_strategy(0).";
MS_LOG(DEBUG) << name_ << ": index_shape(0) can't be divided by param_strategy(0).";
return FAILED;
}
// axis != 0, param_shape(0)%(param_strategy(0)*param_strategy(axis)) must be 0
if (axis_ != 0 && param_shape.at(0) % (param_strategy.at(0) * param_strategy.at(IntToSize(axis_))) != 0) {
MS_LOG(ERROR) << name_ << ": index_shape(0) can't be divided by (param_strategy(0)*param_strategy(axis)).";
MS_LOG(DEBUG) << name_ << ": index_shape(0) can't be divided by (param_strategy(0)*param_strategy(axis)).";
return FAILED;
}
@ -95,7 +107,7 @@ Status GatherV2PInfo::CheckStrategy(const StrategyPtr &strategy) {
auto index_strategy = strategy->GetInputDim().at(1);
auto product_i = std::accumulate(index_strategy.begin(), index_strategy.end(), 1, std::multiplies<int>());
if ((param_strategy.at(IntToSize(axis_)) != 1) && (product_i != 1)) {
MS_LOG(ERROR) << name_ << ": param is splited at dim (axis)" << axis_ << " ,index can't be splited.";
MS_LOG(DEBUG) << name_ << ": param is splited at dim (axis)" << axis_ << " ,index can't be splited.";
return FAILED;
}
@ -104,7 +116,7 @@ Status GatherV2PInfo::CheckStrategy(const StrategyPtr &strategy) {
size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size();
auto product_p = std::accumulate(param_strategy.begin(), param_strategy.end(), 1, std::multiplies<int>());
if (IntToSize(product_p) != dev_num && param_strategy.at(IntToSize(axis_)) != 1) {
MS_LOG(ERROR) << name_ << ": Invalid strategy. Don't support repeated calc.";
MS_LOG(DEBUG) << name_ << ": Invalid strategy. Don't support repeated calc.";
return FAILED;
}
@ -290,18 +302,85 @@ Status GatherV2PInfo::InferBias() {
}
Status GatherV2PInfo::InferGroup() {
std::vector<Group> group_list;
auto param_strategy = strategy_->GetInputDim().at(0);
size_t dim = IntToSize(axis_);
if (param_strategy.at(IntToSize(axis_)) != 1 && inputs_shape_.at(0).size() == 2) {
dim = (axis_ + 1) % 2;
}
if (CreateGroupByDim(dim, &group_list) != SUCCESS) {
CheckGlobalDeviceManager();
MS_EXCEPTION_IF_NULL(g_device_manager);
int32_t rank = g_device_manager->global_rank();
RankList dev_list = g_device_manager->GetDeviceListByStageId(0);
DeviceMatrix dev_matrix(rank, dev_list, dev_matrix_shape_);
RankList group_devices;
if (dev_matrix.GetDevicesAlongDim(SizeToUint(dim), &group_devices) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Create group failed.";
return FAILED;
}
if (group_devices.size() == 1) {
MS_LOG(INFO) << "the group is empty";
return SUCCESS;
}
group_ = group_list.at(0);
group_ = g_device_manager->CreateGroup(group_devices);
return SUCCESS;
}
std::vector<int32_t> GetRankFromGroup(const Group &group) {
std::vector<int32_t> rank_list;
auto device_list = group.GetDevicesList();
for (auto &device : device_list) {
rank_list.insert(rank_list.end(), device.rank() % 8);
}
return rank_list;
}
Status GatherV2PInfo::InferForwardCommunication() {
forward_op_.clear();
if (target_ != CPU) {
return SUCCESS;
}
auto param_strategy = strategy_->GetInputDim().at(0);
// don't split axis, no need forward communication
if (param_strategy.at(IntToSize(axis_)) == 1) {
return SUCCESS;
}
// split axis
OperatorName operator_name;
if (InferGroup() != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Infer Group failed.";
return FAILED;
}
auto group_size = group_.GetDevNum();
Attr attr_group;
// group size <= 8
std::vector<int32_t> rank_list;
if (group_size <= 8) {
reduce_scatter_flag_ = false;
operator_name = HOST_REDUCE_SCATTER;
rank_list = GetRankFromGroup(group_);
attr_group = std::make_pair(GROUP, MakeValue(rank_list));
} else {
// group size > 8
reduce_scatter_flag_ = true;
split_num_ = SizeToInt(group_size / 8);
CheckGlobalDeviceManager();
operator_name = REDUCE_SCATTER;
int32_t rank = g_device_manager->global_rank();
size_t repeat = group_size / 8;
for (size_t i = 0; i < repeat; ++i) {
rank_list.push_back(rank + SizeToInt(i * 8));
}
Group g = g_device_manager->CreateGroup(rank_list);
attr_group = std::make_pair(GROUP, MakeValue(g.name()));
}
Attr attr_op = std::make_pair(OP, MakeValue(REDUCE_OP_SUM));
OperatorAttrs attrs = {attr_op, attr_group};
OperatorParams params;
OperatorArgs args = std::make_pair(attrs, params);
Operator op = std::make_pair(operator_name, args);
forward_op_.push_back(op);
return SUCCESS;
}
@ -346,6 +425,10 @@ Status GatherV2PInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
ReplaceGraphPtr GatherV2PInfo::replace_graph(const CNodePtr &cnode) {
auto param_strategy = strategy_->GetInputDim().at(0);
// target_ == CPU, no need to raplace graph
if (target_ == CPU) {
return nullptr;
}
if (param_strategy.at(IntToSize(axis_)) != 1 && ComputeReplaceGraph(cnode) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": ComputeReplaceGraph failed.";
return nullptr;
@ -353,11 +436,34 @@ ReplaceGraphPtr GatherV2PInfo::replace_graph(const CNodePtr &cnode) {
return replace_graph_;
}
Status GatherV2PInfo::ComputeReplaceOp() {
if (InferBias() != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Infer offset failed.";
return FAILED;
}
OperatorName op_name = EMBEDDING_LOOKUP;
OperatorAttrs attrs;
Attr param_offset = std::make_pair("offset", MakeValue(bias_));
Attr param_flag = std::make_pair("reduce_scatter_flag", MakeValue(reduce_scatter_flag_));
Attr param_split_num = std::make_pair("split_num", MakeValue(split_num_));
OperatorParams params = {std::make_pair(param_offset, 4), std::make_pair(param_flag, 5),
std::make_pair(param_split_num, 6)};
OperatorArgs args = std::make_pair(attrs, params);
Operator op = std::make_pair(op_name, args);
replace_op_.push_back(op);
return SUCCESS;
}
Status GatherV2PInfo::Init(const StrategyPtr &strategy) {
if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Init failed.";
return FAILED;
}
// only target_ == CPU, we need to replace op
if (target_ == CPU && ComputeReplaceOp() != SUCCESS) {
MS_LOG(ERROR) << name_ << ": ComputeReplaceOp failed.";
}
MS_LOG(INFO) << name_ << ": Init success.";
return SUCCESS;
}

View File

@ -49,7 +49,7 @@ class GatherV2PInfo : public OperatorInfo {
protected:
Status CheckStrategy(const StrategyPtr &strategy) override;
Status InferMirrorOps() override;
Status InferForwardCommunication() override { return SUCCESS; }
Status InferForwardCommunication() override;
Status InferTensorInfo() override;
Status InferDevMatrixShape() override;
Status InferTensorMap() override;
@ -57,14 +57,18 @@ class GatherV2PInfo : public OperatorInfo {
private:
Status ComputeReplaceGraph(const CNodePtr &cnode);
Status ComputeReplaceOp();
Status InferBias();
Status InferGroup();
int32_t axis_;
std::string target_;
int32_t bias_;
int32_t slice_size_;
Shape out_dev_matrix_shape_;
Group group_;
bool reduce_scatter_flag_ = false;
int32_t split_num_ = 1;
};
} // namespace parallel
} // namespace mindspore

View File

@ -76,6 +76,8 @@ constexpr char DEPEND[] = "depend";
constexpr char BATCH_PARALLEL[] = "BatchParallel";
constexpr char ACTIVATION_TYPE[] = "activation_type";
constexpr char TARGET[] = "target";
constexpr char CPU[] = "CPU";
constexpr char TRANSPOSE_A[] = "transpose_a";
constexpr char TRANSPOSE_B[] = "transpose_b";
constexpr char SHAPE[] = "shape";
@ -141,6 +143,8 @@ constexpr char MIRROR_OPERATOR[] = "_MirrorOperator";
constexpr char STRIDED_SLICE[] = "StridedSlice";
constexpr char ALL_GATHER[] = "AllGather";
constexpr char REDUCE_SCATTER[] = "ReduceScatter";
constexpr char HOST_REDUCE_SCATTER[] = "HostReduceScatter";
constexpr char EMBEDDING_LOOKUP[] = "EmbeddingLookup";
constexpr char CONCAT[] = "Concat";
constexpr char SOFTMAX_CROSS_ENTROPY_WITH_LOGITS[] = "SoftmaxCrossEntropyWithLogits";
constexpr char SIGMOID_CROSS_ENTROPY_WITH_LOGITS[] = "SigmoidCrossEntropyWithLogits";

View File

@ -534,6 +534,10 @@ std::vector<AnfNodePtr> ReplaceOpInput(const Operator &replace_op, const std::st
MS_LOG(EXCEPTION) << "Failure: " << node->ToString() << " size is smaller than 2";
}
std::vector<AnfNodePtr> replace_input = {NewValueNode(pyop_instance), node->input(1)};
auto prim = GetValueNode<PrimitivePtr>(node->input(0));
if (prim->name() == GATHERV2) {
replace_input = {NewValueNode(pyop_instance), node->input(1), node->input(2), node->input(3)};
}
if (!params.empty()) {
Param param_first = *(params.begin());
int32_t first_position = param_first.second;

View File

@ -182,3 +182,39 @@ def test_gatherv2_auto1():
x = Tensor(np.ones([64, 32]), dtype=ms.float32)
y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
_executor.compile(net, x, y)
def test_gatherv2_cpu0():
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
strategy1 = ((8, 1), (1, 1))
strategy2 = ((4, 2, 1), (4, 2, 1))
net = NetWithLoss(Net(0, strategy1, strategy2))
net.set_auto_parallel()
x = Tensor(np.ones([64, 64]), dtype=ms.float32)
y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
_executor.compile(net, x, y)
def test_gatherv2_cpu1():
context.set_auto_parallel_context(device_num=16, global_rank=0, parallel_mode="semi_auto_parallel")
strategy1 = ((16, 1), (1, 1))
strategy2 = ((4, 2, 1), (4, 2, 1))
net = NetWithLoss(Net(0, strategy1, strategy2))
net.set_auto_parallel()
x = Tensor(np.ones([64, 64]), dtype=ms.float32)
y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
_executor.compile(net, x, y)
def test_gatherv2_cpu2():
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
strategy1 = ((1, 8), (1, 1))
strategy2 = ((4, 2, 1), (4, 2, 1))
net = NetWithLoss(Net(0, strategy1, strategy2))
net.set_auto_parallel()
x = Tensor(np.ones([64, 64]), dtype=ms.float32)
y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
_executor.compile(net, x, y)