forked from mindspore-Ecosystem/mindspore
gatherv2_support_host_and_device
parent b9ba99bb13
commit 1437966c98
@@ -44,6 +44,18 @@ Status GatherV2PInfo::GetAttrs() {
   }
   axis_ = axis;
 
+  // get target
+  auto target_iter = attrs_.find(TARGET);
+  if (target_iter != attrs_.end()) {
+    MS_EXCEPTION_IF_NULL(target_iter->second);
+    if (target_iter->second->isa<StringImm>()) {
+      target_ = target_iter->second->cast<StringImmPtr>()->value();
+    } else {
+      MS_LOG(ERROR) << name_ << " : The value of target is not a string.";
+      return FAILED;
+    }
+  }
+
   return SUCCESS;
 }
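A minimal sketch of how the new target attribute is expected to reach GetAttrs() from the Python frontend; the GatherV2 primitive name and the add_prim_attr call are the usual MindSpore APIs, assumed here rather than taken from this diff:

    # Hedged sketch, not part of this commit: attach "target" so GetAttrs() above
    # finds a StringImm under the TARGET key; a non-string value makes it return FAILED.
    from mindspore.ops import operations as P

    gather = P.GatherV2()
    gather.add_prim_attr("target", "CPU")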
@@ -61,8 +73,8 @@ Status GatherV2PInfo::CheckStrategy(const StrategyPtr &strategy) {
   auto param_shape = inputs_shape_.at(0);
   auto param_strategy = strategy->GetInputDim().at(0);
   auto slice_shape = param_shape.at(param_shape.size() - 1) / param_strategy.at(param_strategy.size() - 1);
-  if (slice_shape % 8 != 0) {
-    MS_LOG(ERROR) << name_ << ": Last dim of param slice shape need 32Byte aligned.";
+  if (slice_shape % 8 != 0 && slice_shape != 1) {
+    MS_LOG(DEBUG) << name_ << ": Last dim of param slice shape need 32Byte aligned.";
     return FAILED;
   }
 
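A worked example of the relaxed alignment check above (shapes are illustrative, assuming float32, where 8 elements equal 32 bytes): the last dimension of the parameter slice must stay a multiple of 8 elements, except that a fully split last dimension of size 1 is now accepted.

    # Illustrative numbers only: slice of the last dim = param last dim / strategy last dim.
    param_last_dim = 64
    for strategy_last_dim in (1, 8, 16, 64):
        slice_shape = param_last_dim // strategy_last_dim      # 64, 8, 4, 1
        rejected = slice_shape % 8 != 0 and slice_shape != 1   # only 4 is rejected
        print(strategy_last_dim, slice_shape, rejected)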
@@ -74,20 +86,20 @@ Status GatherV2PInfo::CheckStrategy(const StrategyPtr &strategy) {
 
   // don't support scalar index
   if (inputs_shape_.at(1).size() == 0) {
-    MS_LOG(ERROR) << name_ << ": Don't support scalar index.";
+    MS_LOG(DEBUG) << name_ << ": Don't support scalar index.";
     return FAILED;
   }
 
   // axis=0, index_shape(0)%param_strategy(0) must be 0
   Shape index_shape = inputs_shape_.at(1);
   if ((axis_ == 0) && (index_shape.at(0) % param_strategy.at(0) != 0)) {
-    MS_LOG(ERROR) << name_ << ": index_shape(0) can't be divided by param_strategy(0).";
+    MS_LOG(DEBUG) << name_ << ": index_shape(0) can't be divided by param_strategy(0).";
     return FAILED;
   }
 
   // axis != 0, param_shape(0)%(param_strategy(0)*param_strategy(axis)) must be 0
   if (axis_ != 0 && param_shape.at(0) % (param_strategy.at(0) * param_strategy.at(IntToSize(axis_))) != 0) {
-    MS_LOG(ERROR) << name_ << ": index_shape(0) can't be divided by (param_strategy(0)*param_strategy(axis)).";
+    MS_LOG(DEBUG) << name_ << ": index_shape(0) can't be divided by (param_strategy(0)*param_strategy(axis)).";
     return FAILED;
   }
 
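The two divisibility checks above in plain arithmetic, with assumed example values:

    # axis == 0: the first index dimension must divide evenly across param_strategy(0).
    index_shape0, param_strategy0 = 64, 8
    print(index_shape0 % param_strategy0 == 0)      # True, so this part of the strategy passes

    # axis != 0: param_shape(0) must be divisible by param_strategy(0) * param_strategy(axis).
    param_shape0, param_strategy_axis = 64, 4
    print(param_shape0 % (param_strategy0 * param_strategy_axis) == 0)   # 64 % 32 == 0 -> True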
@@ -95,7 +107,7 @@ Status GatherV2PInfo::CheckStrategy(const StrategyPtr &strategy) {
   auto index_strategy = strategy->GetInputDim().at(1);
   auto product_i = std::accumulate(index_strategy.begin(), index_strategy.end(), 1, std::multiplies<int>());
   if ((param_strategy.at(IntToSize(axis_)) != 1) && (product_i != 1)) {
-    MS_LOG(ERROR) << name_ << ": param is splited at dim (axis)" << axis_ << " ,index can't be splited.";
+    MS_LOG(DEBUG) << name_ << ": param is splited at dim (axis)" << axis_ << " ,index can't be splited.";
     return FAILED;
   }
 
@@ -104,7 +116,7 @@ Status GatherV2PInfo::CheckStrategy(const StrategyPtr &strategy) {
   size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size();
   auto product_p = std::accumulate(param_strategy.begin(), param_strategy.end(), 1, std::multiplies<int>());
   if (IntToSize(product_p) != dev_num && param_strategy.at(IntToSize(axis_)) != 1) {
-    MS_LOG(ERROR) << name_ << ": Invalid strategy. Don't support repeated calc.";
+    MS_LOG(DEBUG) << name_ << ": Invalid strategy. Don't support repeated calc.";
     return FAILED;
   }
 
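The repeated-calculation check above, again with assumed numbers: once the gather axis of the parameter is split, the product of the parameter strategy has to cover the whole device list.

    # Illustrative only: product of the param strategy vs. the device number.
    dev_num = 8
    for param_strategy in ((8, 1), (4, 1), (1, 2)):
        product_p = 1
        for s in param_strategy:
            product_p *= s
        axis = 0
        valid = product_p == dev_num or param_strategy[axis] == 1
        print(param_strategy, valid)    # (8, 1) ok, (4, 1) rejected, (1, 2) ok (axis not split)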
@@ -290,18 +302,85 @@ Status GatherV2PInfo::InferBias() {
 }
 
 Status GatherV2PInfo::InferGroup() {
-  std::vector<Group> group_list;
   auto param_strategy = strategy_->GetInputDim().at(0);
   size_t dim = IntToSize(axis_);
   if (param_strategy.at(IntToSize(axis_)) != 1 && inputs_shape_.at(0).size() == 2) {
     dim = (axis_ + 1) % 2;
   }
-  if (CreateGroupByDim(dim, &group_list) != SUCCESS) {
+  CheckGlobalDeviceManager();
+  MS_EXCEPTION_IF_NULL(g_device_manager);
+  int32_t rank = g_device_manager->global_rank();
+  RankList dev_list = g_device_manager->GetDeviceListByStageId(0);
+  DeviceMatrix dev_matrix(rank, dev_list, dev_matrix_shape_);
+  RankList group_devices;
+  if (dev_matrix.GetDevicesAlongDim(SizeToUint(dim), &group_devices) != SUCCESS) {
     MS_LOG(ERROR) << name_ << ": Create group failed.";
     return FAILED;
   }
+  if (group_devices.size() == 1) {
+    MS_LOG(INFO) << "the group is empty";
+    return SUCCESS;
+  }
 
-  group_ = group_list.at(0);
+  group_ = g_device_manager->CreateGroup(group_devices);
   return SUCCESS;
 }
+
+std::vector<int32_t> GetRankFromGroup(const Group &group) {
+  std::vector<int32_t> rank_list;
+  auto device_list = group.GetDevicesList();
+  for (auto &device : device_list) {
+    rank_list.insert(rank_list.end(), device.rank() % 8);
+  }
+  return rank_list;
+}
+
+Status GatherV2PInfo::InferForwardCommunication() {
+  forward_op_.clear();
+  if (target_ != CPU) {
+    return SUCCESS;
+  }
+  auto param_strategy = strategy_->GetInputDim().at(0);
+  // don't split axis, no need forward communication
+  if (param_strategy.at(IntToSize(axis_)) == 1) {
+    return SUCCESS;
+  }
+  // split axis
+  OperatorName operator_name;
+  if (InferGroup() != SUCCESS) {
+    MS_LOG(ERROR) << name_ << ": Infer Group failed.";
+    return FAILED;
+  }
+  auto group_size = group_.GetDevNum();
+  Attr attr_group;
+  // group size <= 8
+  std::vector<int32_t> rank_list;
+  if (group_size <= 8) {
+    reduce_scatter_flag_ = false;
+    operator_name = HOST_REDUCE_SCATTER;
+    rank_list = GetRankFromGroup(group_);
+    attr_group = std::make_pair(GROUP, MakeValue(rank_list));
+  } else {
+    // group size > 8
+    reduce_scatter_flag_ = true;
+    split_num_ = SizeToInt(group_size / 8);
+    CheckGlobalDeviceManager();
+    operator_name = REDUCE_SCATTER;
+    int32_t rank = g_device_manager->global_rank();
+    size_t repeat = group_size / 8;
+    for (size_t i = 0; i < repeat; ++i) {
+      rank_list.push_back(rank + SizeToInt(i * 8));
+    }
+    Group g = g_device_manager->CreateGroup(rank_list);
+    attr_group = std::make_pair(GROUP, MakeValue(g.name()));
+  }
+  Attr attr_op = std::make_pair(OP, MakeValue(REDUCE_OP_SUM));
+  OperatorAttrs attrs = {attr_op, attr_group};
+  OperatorParams params;
+  OperatorArgs args = std::make_pair(attrs, params);
+  Operator op = std::make_pair(operator_name, args);
+
+  forward_op_.push_back(op);
+  return SUCCESS;
+}
 
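The new forward communication path splits on a node size of 8 devices: groups of at most 8 ranks use HostReduceScatter with node-local rank ids, larger groups use ReduceScatter over one rank per 8-device node plus a host-side split. A sketch of the rank-list arithmetic, with assumed example values for the global rank and device ids:

    # group_size > 8: one rank every 8 devices, split_num_ = group_size / 8.
    rank, group_size = 3, 16
    split_num = group_size // 8                                   # 2
    rank_list = [rank + i * 8 for i in range(group_size // 8)]    # [3, 11]

    # group_size <= 8: HostReduceScatter gets the node-local ranks, as GetRankFromGroup does.
    group_device_ids = [8, 9, 10, 11]
    host_rank_list = [d % 8 for d in group_device_ids]            # [0, 1, 2, 3]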
@@ -346,6 +425,10 @@ Status GatherV2PInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
 
 ReplaceGraphPtr GatherV2PInfo::replace_graph(const CNodePtr &cnode) {
   auto param_strategy = strategy_->GetInputDim().at(0);
+  // target_ == CPU, no need to replace graph
+  if (target_ == CPU) {
+    return nullptr;
+  }
   if (param_strategy.at(IntToSize(axis_)) != 1 && ComputeReplaceGraph(cnode) != SUCCESS) {
     MS_LOG(ERROR) << name_ << ": ComputeReplaceGraph failed.";
     return nullptr;
@@ -353,11 +436,34 @@ ReplaceGraphPtr GatherV2PInfo::replace_graph(const CNodePtr &cnode) {
   return replace_graph_;
 }
 
+Status GatherV2PInfo::ComputeReplaceOp() {
+  if (InferBias() != SUCCESS) {
+    MS_LOG(ERROR) << name_ << ": Infer offset failed.";
+    return FAILED;
+  }
+  OperatorName op_name = EMBEDDING_LOOKUP;
+  OperatorAttrs attrs;
+  Attr param_offset = std::make_pair("offset", MakeValue(bias_));
+  Attr param_flag = std::make_pair("reduce_scatter_flag", MakeValue(reduce_scatter_flag_));
+  Attr param_split_num = std::make_pair("split_num", MakeValue(split_num_));
+  OperatorParams params = {std::make_pair(param_offset, 4), std::make_pair(param_flag, 5),
+                           std::make_pair(param_split_num, 6)};
+  OperatorArgs args = std::make_pair(attrs, params);
+  Operator op = std::make_pair(op_name, args);
+  replace_op_.push_back(op);
+
+  return SUCCESS;
+}
+
 Status GatherV2PInfo::Init(const StrategyPtr &strategy) {
   if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
     MS_LOG(ERROR) << name_ << ": Init failed.";
     return FAILED;
   }
+  // only target_ == CPU, we need to replace op
+  if (target_ == CPU && ComputeReplaceOp() != SUCCESS) {
+    MS_LOG(ERROR) << name_ << ": ComputeReplaceOp failed.";
+  }
   MS_LOG(INFO) << name_ << ": Init success.";
   return SUCCESS;
 }
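The replace op built in ComputeReplaceOp() can be pictured as a (name, (attrs, params)) pair in which each param records the input position it is appended at; the Python mirror below only illustrates that structure, with placeholder values.

    # Placeholder values; the (param, position) pairs match ComputeReplaceOp() above.
    bias_, reduce_scatter_flag_, split_num_ = 0, False, 1
    replace_op = ("EmbeddingLookup",
                  ((),  # no operator attrs
                   [(("offset", bias_), 4),
                    (("reduce_scatter_flag", reduce_scatter_flag_), 5),
                    (("split_num", split_num_), 6)]))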
@@ -49,7 +49,7 @@ class GatherV2PInfo : public OperatorInfo {
  protected:
   Status CheckStrategy(const StrategyPtr &strategy) override;
   Status InferMirrorOps() override;
-  Status InferForwardCommunication() override { return SUCCESS; }
+  Status InferForwardCommunication() override;
   Status InferTensorInfo() override;
   Status InferDevMatrixShape() override;
   Status InferTensorMap() override;
@@ -57,14 +57,18 @@ class GatherV2PInfo : public OperatorInfo {
 
  private:
   Status ComputeReplaceGraph(const CNodePtr &cnode);
+  Status ComputeReplaceOp();
   Status InferBias();
   Status InferGroup();
 
   int32_t axis_;
+  std::string target_;
   int32_t bias_;
   int32_t slice_size_;
   Shape out_dev_matrix_shape_;
   Group group_;
+  bool reduce_scatter_flag_ = false;
+  int32_t split_num_ = 1;
 };
 } // namespace parallel
 } // namespace mindspore
@@ -76,6 +76,8 @@ constexpr char DEPEND[] = "depend";
 constexpr char BATCH_PARALLEL[] = "BatchParallel";
 
 constexpr char ACTIVATION_TYPE[] = "activation_type";
+constexpr char TARGET[] = "target";
+constexpr char CPU[] = "CPU";
 constexpr char TRANSPOSE_A[] = "transpose_a";
 constexpr char TRANSPOSE_B[] = "transpose_b";
 constexpr char SHAPE[] = "shape";
@@ -141,6 +143,8 @@ constexpr char MIRROR_OPERATOR[] = "_MirrorOperator";
 constexpr char STRIDED_SLICE[] = "StridedSlice";
 constexpr char ALL_GATHER[] = "AllGather";
 constexpr char REDUCE_SCATTER[] = "ReduceScatter";
+constexpr char HOST_REDUCE_SCATTER[] = "HostReduceScatter";
+constexpr char EMBEDDING_LOOKUP[] = "EmbeddingLookup";
 constexpr char CONCAT[] = "Concat";
 constexpr char SOFTMAX_CROSS_ENTROPY_WITH_LOGITS[] = "SoftmaxCrossEntropyWithLogits";
 constexpr char SIGMOID_CROSS_ENTROPY_WITH_LOGITS[] = "SigmoidCrossEntropyWithLogits";
@@ -534,6 +534,10 @@ std::vector<AnfNodePtr> ReplaceOpInput(const Operator &replace_op, const std::st
     MS_LOG(EXCEPTION) << "Failure: " << node->ToString() << " size is smaller than 2";
   }
   std::vector<AnfNodePtr> replace_input = {NewValueNode(pyop_instance), node->input(1)};
+  auto prim = GetValueNode<PrimitivePtr>(node->input(0));
+  if (prim->name() == GATHERV2) {
+    replace_input = {NewValueNode(pyop_instance), node->input(1), node->input(2), node->input(3)};
+  }
   if (!params.empty()) {
     Param param_first = *(params.begin());
     int32_t first_position = param_first.second;
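Together with ComputeReplaceOp(), this keeps the first three real inputs of a GatherV2 node and lets the replacement params land behind them at positions 4-6; the input names below are assumptions, only the ordering follows from the code.

    # Assumed names; ordering taken from the diff: kept inputs 1-3, then params at 4-6.
    new_inputs = ["EmbeddingLookup instance",          # NewValueNode(pyop_instance)
                  "param", "indices", "axis",          # node->input(1), input(2), input(3)
                  "offset", "reduce_scatter_flag", "split_num"]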
@@ -182,3 +182,39 @@ def test_gatherv2_auto1():
     x = Tensor(np.ones([64, 32]), dtype=ms.float32)
     y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
     _executor.compile(net, x, y)
+
+
+def test_gatherv2_cpu0():
+    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
+    strategy1 = ((8, 1), (1, 1))
+    strategy2 = ((4, 2, 1), (4, 2, 1))
+    net = NetWithLoss(Net(0, strategy1, strategy2))
+    net.set_auto_parallel()
+
+    x = Tensor(np.ones([64, 64]), dtype=ms.float32)
+    y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
+    _executor.compile(net, x, y)
+
+
+def test_gatherv2_cpu1():
+    context.set_auto_parallel_context(device_num=16, global_rank=0, parallel_mode="semi_auto_parallel")
+    strategy1 = ((16, 1), (1, 1))
+    strategy2 = ((4, 2, 1), (4, 2, 1))
+    net = NetWithLoss(Net(0, strategy1, strategy2))
+    net.set_auto_parallel()
+
+    x = Tensor(np.ones([64, 64]), dtype=ms.float32)
+    y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
+    _executor.compile(net, x, y)
+
+
+def test_gatherv2_cpu2():
+    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
+    strategy1 = ((1, 8), (1, 1))
+    strategy2 = ((4, 2, 1), (4, 2, 1))
+    net = NetWithLoss(Net(0, strategy1, strategy2))
+    net.set_auto_parallel()
+
+    x = Tensor(np.ones([64, 64]), dtype=ms.float32)
+    y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
+    _executor.compile(net, x, y)
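As a concrete reading of test_gatherv2_cpu0 (shapes taken from the test, the per-device arithmetic is only illustrative): strategy1 = ((8, 1), (1, 1)) splits the 64 x 64 parameter into eight row blocks and leaves the indices whole, so each device holds an 8 x 64 slice and the offset passed to the EmbeddingLookup replacement selects its own block.

    # Shapes from test_gatherv2_cpu0; the slice/offset arithmetic is illustrative.
    param_shape = (64, 64)
    strategy1_param = (8, 1)
    slice_shape = tuple(s // c for s, c in zip(param_shape, strategy1_param))   # (8, 64)
    offsets = [r * slice_shape[0] for r in range(8)]                            # 0, 8, ..., 56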