forked from mindspore-Ecosystem/mindspore
!22096 add parallel sparse attention ops
Merge pull request !22096 from yao_yf/parallel_sparse_attention_ops
Commit 5b8f9ff398
@@ -1744,5 +1744,35 @@ void VirtualDatasetCost::CalculateInputsInMemory(const std::map<size_t, bool> &)
    is_inputs_should_in_memory_[i] = is_parameter_[i];
  }
}

double MatmulDDSCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
                                                int64_t) const {
  double result = 0.0;
  if (inputs_type_lengths_.size() != inputs.size()) {
    MS_LOG(EXCEPTION) << "Invalid inputs type size " << inputs_type_lengths_.size() << " for MatmulDDS cost";
  }

  for (size_t index = 0; index < inputs.size(); ++index) {
    TensorInfo tensor_info = inputs[index];
    Shape slice_shape = tensor_info.slice_shape();
    result += ListProduct(slice_shape) * static_cast<double>(inputs_type_lengths_[index]);
  }
  return result;
}

// Not taking account of output
void MatmulDDSCost::CalculateOutputInMemory() {
  is_output_should_in_memory_ =
    (std::find(is_parameter_involve_.begin(), is_parameter_involve_.end(), true) != is_parameter_involve_.end());
}

// Taking account of input
void MatmulDDSCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
  bool keep_mem =
    (std::find(is_parameter_.begin(), is_parameter_.end(), true) != is_parameter_.end()) ||
    (std::find(is_parameter_involve_.begin(), is_parameter_involve_.end(), true) != is_parameter_involve_.end());
  std::fill(is_inputs_should_in_memory_.begin(), is_inputs_should_in_memory_.end(), keep_mem);
}
}  // namespace parallel
}  // namespace mindspore
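For intuition, a minimal Python sketch (not part of the patch) of what GetForwardComputationCost computes: the per-device cost is the sum, over all inputs, of the slice-shape volume times the element byte width. The slice shapes and byte widths below are hypothetical values chosen only to illustrate the formula.

from math import prod

def matmul_dds_forward_computation_cost(slice_shapes, type_lengths):
    # Sum of prod(slice_shape) * bytes_per_element over all inputs.
    assert len(slice_shapes) == len(type_lengths)
    return sum(prod(shape) * float(length)
               for shape, length in zip(slice_shapes, type_lengths))

# Hypothetical per-device slices of q, k, local_mask, global_mask, all fp16 (2 bytes).
print(matmul_dds_forward_computation_cost(
    [[16, 512, 16, 16], [16, 512, 16, 16], [64, 32, 16, 16], [128, 64, 16, 16]],
    [2, 2, 2, 2]))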
@@ -1051,6 +1051,44 @@ class GatherV2PCost : public GatherV2Cost {
  int64_t axis_;
  Shape strategy_;
};

class MatmulDDSCost : public OperatorCost {
 public:
  MatmulDDSCost() : OperatorCost() {}
  ~MatmulDDSCost() override = default;

  // per device communication cost
  double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                     int64_t stage_id) const override {
    return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
  }
  double GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                            int64_t stage_id) const override {
    return 0.0;
  };
  double GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                             int64_t stage_id) const override {
    return 0.0;
  };

  // per device computation cost
  double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                            int64_t stage_id) const override {
    return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
  }
  double GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                                   int64_t stage_id) const override;
  double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                                    int64_t stage_id) const override {
    return 0.0;
  };
  // Not taking account of output
  void CalculateOutputInMemory() override;
  // Taking account of input
  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};
using MatmulDDSCostPtr = std::shared_ptr<MatmulDDSCost>;

}  // namespace parallel
}  // namespace mindspore
#endif  // PARALLEL_AUTO_PARALLEL_OPERATOR_COSTMODEL_H_
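For orientation, a minimal Python sketch (not part of the patch) of how MatmulDDSCost composes its per-device costs: both communication terms and the backward computation term return zero, so only the forward computation cost contributes to the totals.

def matmul_dds_total_costs(forward_computation_cost):
    # Mirrors the cost composition in the MatmulDDSCost class declaration.
    forward_comm, backward_comm = 0.0, 0.0   # GetForwardCommCost / GetBackwardCommCost
    backward_computation = 0.0               # GetBackwardComputationCost
    comm_cost = forward_comm + backward_comm                             # GetCommCost
    computation_cost = forward_computation_cost + backward_computation   # GetComputationCost
    return comm_cost, computation_cost

print(matmul_dds_total_costs(1.5e6))  # hypothetical forward computation cost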
@@ -204,6 +204,7 @@ REGISTER(MaxPoolInfo);
REGISTER(AvgPoolInfo);
REGISTER(GatherDInfo);
REGISTER(ReduceAnyInfo);
REGISTER(MatmulDDSInfo);
}  // namespace parallel
}  // namespace mindspore
@@ -165,7 +165,7 @@ AnfNodePtr GenerateGraph::PushBack(const std::vector<AnfNodePtr> &inputs) {
  if (inputs.size() < 2) {
    MS_LOG(EXCEPTION) << "inputs.size() must be more than 1";
  }
  (void)manager_->Replace(inputs.at(1), cnode);  // using Replace function to insert cnode after inputs[1]
  auto new_anf_node_ptr = cnode->cast<AnfNodePtr>();
  MS_EXCEPTION_IF_NULL(new_anf_node_ptr);
  return new_anf_node_ptr;
@@ -0,0 +1,320 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "frontend/parallel/ops_info/matmul_dds_info.h"

#include <memory>
#include <utility>
#include <vector>

#include "frontend/parallel/device_manager.h"
#include "frontend/parallel/device_matrix.h"
#include "frontend/parallel/step_parallel.h"
#include "frontend/parallel/graph_util/generate_graph.h"
#include "utils/log_adapter.h"

namespace mindspore {
namespace parallel {
/*
 * MatmulDDS has 4 inputs.
 * q, k: 4D float tensors used in the transformer model; the shape is
 * [num_heads * size_per_head // 16, bs * seq_len // 16, 16, 16], where num_heads * size_per_head = embedding_size.
 * The shape is reshaped for the cube unit; the origin shape is
 * (bs * seq_len, embedding_size) <=> (bs, num_heads, seq_len, size_per_head).
 * local_mask: the local mask in sparse attention; the shape is
 * (seq_len // 16, bs * block_size // 16, 16, 16).
 * block_num = seq_len // block_size, block_size = 64, always.
 * global_mask: the global mask in sparse attention;
 * the shape is (bs * global_size // 16, seq_len // 16, 16, 16).
 * seq_len = 1024, global_size = 256, always.
 * Only bs and num_heads can be split, so q[0] should be at least size_per_head
 * and q[1] should be at least seq_len // 16. The strategy check can use bs/heads from the attrs.
 */
Status MatmulDDSInfo::CheckStrategy(const StrategyPtr &strategy) {
  if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Invalid strategy.";
    return FAILED;
  }
  Strategys stras = strategy->GetInputDim();
  if (stras.size() != MATMUL_DDS_INPUTS_SIZE) {
    MS_LOG(ERROR) << name_ << ": Invalid strategy. The strategys size should be 4.";
    return FAILED;
  }
  for (auto stra : stras) {
    if (stra.size() != MATMUL_DDS_STRATEGY_SIZE) {
      MS_LOG(ERROR) << name_
                    << ": Invalid strategy. The strategy size should be 4, but in current dim, "
                       "the strategy is "
                    << stra;
      return FAILED;
    }
  }
  if (stras[0][0] != stras[1][0] || num_heads_ % stras[0][0] != 0) {
    MS_LOG(ERROR) << name_ << ": Invalid strategy. The strategys[0][0]:" << stras[0][0]
                  << " should be equal to strategys[1][0]:" << stras[1][0]
                  << ", and num_heads: " << num_heads_ << " should be divisible by it.";
    return FAILED;
  }
  if (stras[0][1] != stras[1][1] || stras[0][1] != stras[2][1] || stras[0][1] != stras[3][0]) {
    MS_LOG(ERROR) << name_ << ": Invalid strategy. The strategys[0][1]:" << stras[0][1]
                  << ", strategys[1][1]:" << stras[1][1] << ", strategys[2][1]:" << stras[2][1]
                  << ", strategys[3][0]:" << stras[3][0] << " should be the same.";
    return FAILED;
  }
  if (batch_size_ % stras[0][1] != 0) {
    MS_LOG(ERROR) << name_ << ": Invalid strategy. The batch_size:" << batch_size_
                  << " should be divisible by strategys[0][1]:" << stras[0][1] << ".";
    return FAILED;
  }
  for (size_t i = 2; i < stras[0].size(); ++i) {
    if (stras[0][i] != 1) {
      MS_LOG(ERROR) << name_ << ": Invalid strategy. The strategys[0][" << i << "] only supports 1.";
      return FAILED;
    }
  }
  for (size_t i = 2; i < stras[1].size(); ++i) {
    if (stras[1][i] != 1) {
      MS_LOG(ERROR) << name_ << ": Invalid strategy. The strategys[1][" << i << "] only supports 1.";
      return FAILED;
    }
  }
  for (size_t i = 0; i < stras[2].size(); ++i) {
    if (i != 1 && stras[2][i] != 1) {
      MS_LOG(ERROR) << name_ << ": Invalid strategy. The strategys[2][" << i << "] only supports 1.";
      return FAILED;
    }
  }
  for (size_t i = 1; i < stras[3].size(); ++i) {
    if (stras[3][i] != 1) {
      MS_LOG(ERROR) << name_ << ": Invalid strategy. The strategys[3][" << i << "] only supports 1.";
      return FAILED;
    }
  }
  return SUCCESS;
}

/*
 * The device matrix is extended from strategy 0.
 */
Status MatmulDDSInfo::InferDevMatrixShape() {
  Strategys stra = strategy_->GetInputDim();
  Dimensions input_strategy = stra.at(0);
  input_strategy_ = input_strategy;
  dev_matrix_shape_ = input_strategy;
  dev_matrix_shape_.push_back(1);
  dev_matrix_shape_.push_back(1);
  dev_matrix_shape_.push_back(1);
  return SUCCESS;
}

/*
 * q: [num_heads * size_per_head // 16, bs * seq_len // 16, 16, 16]
 * k: [num_heads * size_per_head // 16, bs * seq_len // 16, 16, 16]
 * local_mask: (block_num * block_size // 16, bs * block_size // 16, 16, 16)
 * global_mask: (bs * global_size // 16, seq_len // 16, 16, 16)
 * local_prob: (bs, num_heads, block_num, block_size // 16, block_size // 16, 16, 16)
 * global_prob: (bs, num_heads, block_num, global_size // 16, block_size // 16, 16, 16)
 * device_matrix: [num_heads_stra, bs_stra, 1, 1, 1, 1, 1]
 */
Status MatmulDDSInfo::InferTensorMap() {
  TensorMap input_tensor_map_q;
  // input_tensor_map_q [6, 5, -1, -1]
  for (size_t i = 0; i < inputs_shape_[0].size(); ++i) {
    if (i <= 1) {
      input_tensor_map_q.push_back((int64_t)(inputs_shape_[0].size() + 3 - i - 1));
    } else {
      input_tensor_map_q.push_back((int64_t)(MAP_NONE));
    }
  }
  TensorMap input_tensor_map_k;
  // input_tensor_map_k [6, 5, -1, -1]
  for (size_t i = 0; i < inputs_shape_[1].size(); ++i) {
    if (i <= 1) {
      input_tensor_map_k.push_back((int64_t)(inputs_shape_[1].size() + 3 - i - 1));
    } else {
      input_tensor_map_k.push_back((int64_t)(MAP_NONE));
    }
  }
  TensorMap input_tensor_map_local_mask;
  // input_tensor_map_local_mask [-1, 5, -1, -1]
  for (size_t i = 0; i < inputs_shape_[2].size(); ++i) {
    if (i == 1) {
      input_tensor_map_local_mask.push_back((int64_t)(inputs_shape_[2].size() + 3 - 2));
    } else {
      input_tensor_map_local_mask.push_back((int64_t)(MAP_NONE));
    }
  }
  TensorMap input_tensor_map_global_mask;
  // input_tensor_map_global_mask [5, -1, -1, -1]
  for (size_t i = 0; i < inputs_shape_[3].size(); ++i) {
    if (i == 0) {
      input_tensor_map_global_mask.push_back((int64_t)(inputs_shape_[3].size() + 3 - 2));
    } else {
      input_tensor_map_global_mask.push_back((int64_t)(MAP_NONE));
    }
  }
  TensorMap output_tensor_map_local_prob;
  // output_tensor_map_local_prob [5, 6, -1, -1, -1, -1, -1]
  for (size_t i = 0; i < dev_matrix_shape_.size(); ++i) {
    if (i == 0) {
      output_tensor_map_local_prob.push_back((int64_t)(dev_matrix_shape_.size() - 2));
    } else if (i == 1) {
      output_tensor_map_local_prob.push_back((int64_t)(dev_matrix_shape_.size() - 1));
    } else {
      output_tensor_map_local_prob.push_back((int64_t)(MAP_NONE));
    }
  }
  TensorMap output_tensor_map_global_prob;
  // output_tensor_map_global_prob [5, 6, -1, -1, -1, -1, -1]
  for (size_t i = 0; i < dev_matrix_shape_.size(); ++i) {
    if (i == 0) {
      output_tensor_map_global_prob.push_back((int64_t)(dev_matrix_shape_.size() - 2));
    } else if (i == 1) {
      output_tensor_map_global_prob.push_back((int64_t)(dev_matrix_shape_.size() - 1));
    } else {
      output_tensor_map_global_prob.push_back((int64_t)(MAP_NONE));
    }
  }
  inputs_tensor_map_.push_back(input_tensor_map_q);
  inputs_tensor_map_.push_back(input_tensor_map_k);
  inputs_tensor_map_.push_back(input_tensor_map_local_mask);
  inputs_tensor_map_.push_back(input_tensor_map_global_mask);
  outputs_tensor_map_.push_back(output_tensor_map_local_prob);
  outputs_tensor_map_.push_back(output_tensor_map_global_prob);
  return SUCCESS;
}

Status MatmulDDSInfo::GetAttrs() {
  if ((inputs_shape_.size() != MATMUL_DDS_INPUTS_SIZE) || (outputs_shape_.size() != MATMUL_DDS_OUTPUTS_SIZE)) {
    MS_LOG(ERROR) << name_ << ": Inputs shape size " << inputs_shape_.size() << " or outputs shape size "
                  << outputs_shape_.size() << " is wrong.";
    return FAILED;
  }
  auto iter = attrs_.find(BS);
  if (iter != attrs_.end()) {
    MS_EXCEPTION_IF_NULL(iter->second);
    if (iter->second->isa<Int64Imm>()) {
      batch_size_ = iter->second->cast<Int64ImmPtr>()->value();
    } else {
      MS_LOG(ERROR) << name_ << ": The value of bs is not int64_t.";
      return FAILED;
    }
  }
  iter = attrs_.find(HEADS);
  if (iter != attrs_.end()) {
    MS_EXCEPTION_IF_NULL(iter->second);
    if (iter->second->isa<Int64Imm>()) {
      num_heads_ = iter->second->cast<Int64ImmPtr>()->value();
    } else {
      MS_LOG(ERROR) << name_ << ": The value of heads is not int64_t.";
      return FAILED;
    }
  }
  return SUCCESS;
}

Status MatmulDDSInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
  int64_t new_batch_size;
  int64_t new_num_heads;
  int64_t batch_shard_num = strategy_->GetInputDim()[0][1];
  int64_t heads_shard_num = strategy_->GetInputDim()[0][0];
  new_batch_size = batch_size_ / batch_shard_num;
  new_num_heads = num_heads_ / heads_shard_num;
  ValuePtr new_bs_value = MakeValue(new_batch_size);
  ValuePtr new_heads_value = MakeValue(new_num_heads);
  Attr attr_batch_size = std::make_pair(BS, new_bs_value);
  Attr attr_num_heads = std::make_pair(HEADS, new_heads_value);
  OperatorAttrs attrs = {attr_batch_size, attr_num_heads};
  GenerateGraph gen_g = GenerateGraph(attrs_);
  if (gen_g.Init(cnode) != SUCCESS) {
    MS_LOG(ERROR) << "GenerateGraph Init failed";
    return FAILED;
  }
  auto matmul_dds_node =
    gen_g.PushBack({gen_g.NewOpInst(MATMUL_DDS, attrs), gen_g.virtual_input_node(), gen_g.virtual_input_node(),
                    gen_g.virtual_input_node(), gen_g.virtual_input_node()});
  std::vector<std::pair<AnfNodePtr, int64_t>> input_nodes = {
    std::make_pair(matmul_dds_node, 1),
    std::make_pair(matmul_dds_node, 2),
    std::make_pair(matmul_dds_node, 3),
    std::make_pair(matmul_dds_node, 4),
  };
  replace_graph_ = std::make_shared<std::pair<std::vector<std::pair<AnfNodePtr, int64_t>>, AnfNodePtr>>(
    std::make_pair(input_nodes, matmul_dds_node));
  return SUCCESS;
}

ReplaceGraphPtr MatmulDDSInfo::replace_graph(const CNodePtr &cnode) {
  if (ComputeReplaceGraph(cnode) != SUCCESS) {
    MS_LOG(EXCEPTION) << name_ << ": ComputeReplaceGraph failed.";
  }
  return replace_graph_;
}

Status MatmulDDSInfo::Init(const StrategyPtr &strategy) {
  if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Init failed.";
    return FAILED;
  }
  MS_LOG(INFO) << name_ << ": Init success.";
  return SUCCESS;
}

Status MatmulDDSInfo::InitForCostModel(const StrategyPtr &strategy) {
  if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Init for cost model failed.";
    return FAILED;
  }

  MS_LOG(INFO) << name_ << ": Init for cost model success.";
  return SUCCESS;
}

std::vector<StrategyPtr> MatmulDDSInfo::GenerateOpStrategies(int64_t stage_id) {
  // generate the first input's strategy
  Shape input0_split = {1, 1, 0, 0};
  Shapes splittable_input = {input0_split};
  Shapes tmp_inputs_shape = {inputs_shape_[0]};

  std::vector<StrategyPtr> sp_vector;
  if (GenerateStrategiesForIndependentInputs(stage_id, tmp_inputs_shape, splittable_input, &sp_vector) != SUCCESS) {
    MS_LOG(EXCEPTION) << name_ << ": Generate strategies failed";
  }

  // the other inputs' strategies are derived from the first input's strategy
  for (auto &sp : sp_vector) {
    if ((sp == nullptr) || sp->GetInputDim().empty()) {
      MS_LOG(EXCEPTION) << name_ << ": The strategy is null or empty";
    }
    Strategys tmp_strategy;
    Dimensions q_strategy = sp->GetInputDim()[0];
    Dimensions k_strategy = q_strategy;
    Dimensions local_mask_strategy = {1, q_strategy[0], 1, 1};
    Dimensions global_mask_strategy = {q_strategy[0], 1, 1, 1};

    tmp_strategy.push_back(q_strategy);            // q
    tmp_strategy.push_back(k_strategy);            // k
    tmp_strategy.push_back(local_mask_strategy);   // local_mask
    tmp_strategy.push_back(global_mask_strategy);  // global_mask
    sp->ResetInputs(tmp_strategy);
  }
  return sp_vector;
}

Status MatmulDDSInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }

}  // namespace parallel
}  // namespace mindspore
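As a quick sanity check of the sharding rules that CheckStrategy enforces above, a minimal Python sketch (not part of the patch). The bs/num_heads/dp/mp values mirror the mixed-parallel configuration used in the unit tests further below.

def check_matmul_dds_strategy(stras, batch_size, num_heads):
    # Mirror of the MatmulDDSInfo::CheckStrategy constraints, in Python.
    q, k, local_mask, global_mask = stras
    assert all(len(s) == 4 for s in stras)
    assert q[0] == k[0] and num_heads % q[0] == 0             # heads-dim shard
    assert q[1] == k[1] == local_mask[1] == global_mask[0]    # batch-dim shard
    assert batch_size % q[1] == 0
    assert q[2:] == [1, 1] and k[2:] == [1, 1]
    assert local_mask[0] == 1 and local_mask[2:] == [1, 1]
    assert global_mask[1:] == [1, 1, 1]
    return True

dp, mp = 2, 8  # bs=128, num_heads=32 on 16 devices
print(check_matmul_dds_strategy(
    [[mp, dp, 1, 1], [mp, dp, 1, 1], [1, dp, 1, 1], [dp, 1, 1, 1]], 128, 32))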
@@ -0,0 +1,64 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_MATMUL_DDS_INFO_H_
#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_MATMUL_DDS_INFO_H_

#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "ir/value.h"
#include "frontend/parallel/ops_info/operator_info.h"
#include "frontend/parallel/strategy.h"

namespace mindspore {
namespace parallel {
/*
 * Parallel class for the MatmulDDS primitive.
 */
class MatmulDDSInfo : public OperatorInfo {
 public:
  MatmulDDSInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
                const PrimitiveAttrs &attrs)
      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<MatmulDDSCost>()) {}
  ~MatmulDDSInfo() override = default;
  Status Init(const StrategyPtr &strategy) override;
  Status InitForCostModel(const StrategyPtr &strategy) override;

  std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override;
  Status SetCostUnderStrategy(const StrategyPtr &strategy) override;
  ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override;

 protected:
  Status CheckStrategy(const StrategyPtr &strategy) override;
  Status InferForwardCommunication() override { return SUCCESS; }
  Status InferDevMatrixShape() override;
  Status InferTensorMap() override;
  Status GetAttrs() override;
  Status InferAsLossDivisor() override { return SUCCESS; }
  Status ComputeReplaceGraph(const CNodePtr &cnode);

 private:
  Dimensions input_strategy_;
  int64_t batch_size_ = 0;
  int64_t num_heads_ = 0;
};
}  // namespace parallel
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_MATMUL_DDS_INFO_H_
@@ -59,5 +59,6 @@
#include "frontend/parallel/ops_info/batchnorm_info.h"
#include "frontend/parallel/ops_info/maxpool_info.h"
#include "frontend/parallel/ops_info/gatherd_info.h"
#include "frontend/parallel/ops_info/matmul_dds_info.h"

#endif  // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_HEAD_FILES_H_
@@ -74,6 +74,9 @@ constexpr size_t TRANSFER_CONCAT_TENSOR_DIM_INDEX = 0;
constexpr size_t TRANSFER_CONCAT_DEV_DIM_INDEX = 1;
constexpr size_t TRANSFER_CONCAT_SPLIT_COUNT_INDEX = 2;
constexpr size_t TRANSFER_SPLIT_ARGS_SIZE = 3;
constexpr size_t MATMUL_DDS_INPUTS_SIZE = 4;
constexpr size_t MATMUL_DDS_OUTPUTS_SIZE = 2;
constexpr size_t MATMUL_DDS_STRATEGY_SIZE = 4;
constexpr double EPS = 1e-6;
constexpr double INF = 1e20;
constexpr double COST_FACTOR = 2.0;

@@ -142,6 +145,8 @@ constexpr char RANGE_MAX[] = "range_max";
constexpr char REMOVE_ACCIDENTAL_HITS[] = "remove_accidental_hits";
constexpr char UNIQUE_STRING[] = "unique";
constexpr char AXIS[] = "axis";
constexpr char BS[] = "bs";
constexpr char HEADS[] = "heads";
constexpr char AXES[] = "axes";
constexpr char START[] = "start";
constexpr char LIMIT[] = "limit";

@@ -319,6 +324,7 @@ constexpr char BROADCAST_TO[] = "BroadcastTo";
constexpr char SQRT[] = "Sqrt";
constexpr char ASSIGN[] = "Assign";
constexpr char GET_NEXT[] = "GetNext";
constexpr char MATMUL_DDS[] = "MatmulDDS";
constexpr char SQUEEZE[] = "Squeeze";
constexpr char NEG[] = "Neg";
constexpr char ABS[] = "Abs";
@@ -45,7 +45,7 @@ class PReLUInfo : public OperatorInfo {

 protected:
  Status CheckStrategy(const StrategyPtr &strategy) override;
  Status InferForwardCommunication() override { return SUCCESS; }
  Status InferDevMatrixShape() override;
  Status InferTensorMap() override;
  Status GetAttrs() override;
@@ -170,7 +170,8 @@ bool IsSplittableOperator(const std::string &op_name) {
    BESSELI0E, BESSELI1E, FLOORMOD, ASSIGN, ASSIGN_ADD, ATAN2, DIVNONAN, LOGICALAND, LOGICALOR, ELU, RELU6, RELUV2,
    SOFTPLUS, SOFTSIGN, GREATEREQUAL, LESSEQUAL, LESS, APPROXIMATEEQUAL, MOD, UNIQUE, UNSORTED_SEGMENT_SUM,
    UNSORTED_SEGMENT_MIN, REPEAT_ELEMENTS, TENSOR_DOT, RANGE, UNIFORM_CANDIDATE_SAMPLER, SLICE, SELECT, GATHERD,
    UNSORTED_SEGMENT_MAX, GATHER_ND, TOPK, SCATTER_UPDATE, VIRTUAL_OUTPUT, CONV2D_BACK_PROP_INPUT, CONV2D_TRANSPOSE,
    MATMUL_DDS};
  // clang-format on

  auto iter = splittable_op.find(op_name);
@@ -0,0 +1,169 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import numpy as np

import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
from mindspore.common.api import _cell_graph_executor
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.parallel import set_algo_parameters
from mindspore.ops.operations._inner_ops import MatmulDDS
from tests.ut.python.ops.test_math_ops import VirtualLoss

context.set_context(mode=context.GRAPH_MODE)

grad_all = C.GradOperation(get_all=True)


# q: (num_heads * size_per_head // 16, bs * seq_len // 16, 16, 16)
# k: (num_heads * size_per_head // 16, bs * seq_len // 16, 16, 16)
# local_mask: (block_num * block_size // 16, bs * block_size // 16, 16, 16)
# global_mask: (bs * global_size // 16, seq_len // 16, 16, 16)
# local_prob: (bs, num_heads, block_num, block_size // 16, block_size // 16, 16, 16)
# global_prob: (bs, num_heads, block_num, global_size // 16, block_size // 16, 16, 16)
# x: (bs * seq_len, num_heads * size_per_head)
class Net(nn.Cell):
    def __init__(self, batch_size, num_heads, dp, mp, shard=True):
        super(Net, self).__init__()
        self.batch_size = batch_size
        self.num_heads = num_heads
        self.size_per_head = 128
        self.seq_len = 1024
        self.block_size = 64
        self.block_num = self.seq_len // self.block_size
        self.global_size = 256
        self.embedding_size = num_heads * self.size_per_head
        self.cus_matmul = MatmulDDS(batch_size, num_heads)
        self.reduce_sum = P.ReduceSum()
        self.global_mask = Tensor(np.ones((batch_size * self.global_size // 16, self.seq_len // 16, 16, 16)))
        self.local_mask = Tensor(np.ones((self.block_num * self.block_size // 16,
                                          batch_size * self.block_size // 16, 16, 16)))
        self.dense1 = nn.Dense(self.embedding_size, self.embedding_size, has_bias=False)
        self.dense2 = nn.Dense(self.embedding_size, self.embedding_size, has_bias=False)
        self.reshape = P.Reshape()
        self.transpose = P.Transpose()
        self.add = P.Add()
        if shard:
            self.cus_matmul.shard(((mp, dp, 1, 1), (mp, dp, 1, 1), (1, dp, 1, 1), (dp, 1, 1, 1)))
            self.dense1.matmul.shard(((dp, 1), (mp, 1)))
            self.dense2.matmul.shard(((dp, 1), (mp, 1)))
            self.transpose.shard(((dp, 1, mp, 1),))

    def construct(self, x):
        q = self.dense1(x)
        k = self.dense2(x)
        q = self.transpose(self.reshape(q, (-1, 16, self.embedding_size // 16, 16)), (2, 0, 1, 3))
        k = self.transpose(self.reshape(k, (-1, 16, self.embedding_size // 16, 16)), (2, 0, 1, 3))
        local_prob, global_prob = self.cus_matmul(q, k, self.local_mask, self.global_mask)
        local_prob = self.reshape(local_prob, (self.batch_size, self.num_heads, -1))
        global_prob = self.reshape(global_prob, (self.batch_size, self.num_heads, -1))
        local_prob_reduce = self.reduce_sum(local_prob, 2)
        global_prob_reduce = self.reduce_sum(global_prob, 2)
        result = self.add(local_prob_reduce, global_prob_reduce)
        return result


class GradWrap(nn.Cell):
    def __init__(self, network):
        super(GradWrap, self).__init__()
        self.network = network

    def construct(self, x):
        return grad_all(self.network)(x)


class NetWithLoss(nn.Cell):
    def __init__(self, network):
        super(NetWithLoss, self).__init__()
        self.network = network
        self.loss = VirtualLoss()

    def construct(self, x):
        predict = self.network(x)
        return self.loss(predict)


def compile_graph(batch_size, num_heads, dp, mp, auto=False, shard=True):
    if auto:
        context.set_auto_parallel_context(parallel_mode="auto_parallel")
    else:
        context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    x = Tensor(np.ones((batch_size * 1024, num_heads * 128)), ms.float32)
    net = GradWrap(NetWithLoss(Net(batch_size, num_heads, dp, mp, shard=shard)))
    net.set_auto_parallel()
    net.set_train()
    _cell_graph_executor.compile(net, x)


def test_cus_matmul_dds_model_parallel_mix():
    context.set_auto_parallel_context(device_num=16, global_rank=0)
    batch_size = 128
    num_heads = 32
    dp = 2
    mp = 8
    compile_graph(batch_size, num_heads, dp, mp)


def test_cus_matmul_dds_model_parallel_dp():
    context.set_auto_parallel_context(device_num=16, global_rank=0)
    batch_size = 128
    num_heads = 32
    dp = 16
    mp = 1
    compile_graph(batch_size, num_heads, dp, mp)


def test_cus_matmul_dds_model_parallel_mp():
    context.set_auto_parallel_context(device_num=16, global_rank=0)
    batch_size = 128
    num_heads = 32
    dp = 1
    mp = 16
    compile_graph(batch_size, num_heads, dp, mp)


def test_cus_matmul_dds_model_parallel_mix_auto():
    set_algo_parameters(fully_use_devices=False)
    context.set_auto_parallel_context(device_num=16, global_rank=0)
    batch_size = 128
    num_heads = 32
    dp = 2
    mp = 8
    compile_graph(batch_size, num_heads, dp, mp, auto=True)


def test_cus_matmul_dds_model_parallel_dp_auto():
    context.set_auto_parallel_context(device_num=16, global_rank=0)
    batch_size = 128
    num_heads = 32
    dp = 16
    mp = 1
    compile_graph(batch_size, num_heads, dp, mp, auto=True)


def test_cus_matmul_dds_model_parallel_mp_auto():
    context.set_auto_parallel_context(device_num=16, global_rank=0)
    batch_size = 128
    num_heads = 32
    dp = 1
    mp = 16
    compile_graph(batch_size, num_heads, dp, mp, auto=True)


def test_cus_matmul_dds_model_parallel_auto():
    set_algo_parameters(fully_use_devices=False)
    context.set_auto_parallel_context(device_num=16, global_rank=0)
    batch_size = 128
    num_heads = 32
    dp = 1
    mp = 16
    compile_graph(batch_size, num_heads, dp, mp, auto=True, shard=False)
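To make the shapes in the comments above concrete, a small standalone sketch (not part of the test file) that prints the full MatmulDDS input shapes and the per-device slice shapes for the dp=2, mp=8 configuration used in test_cus_matmul_dds_model_parallel_mix:

def matmul_dds_input_shapes(bs, num_heads, size_per_head=128, seq_len=1024,
                            block_size=64, global_size=256):
    # Full input shapes of MatmulDDS, as documented in the comments above.
    block_num = seq_len // block_size
    q = (num_heads * size_per_head // 16, bs * seq_len // 16, 16, 16)
    k = q
    local_mask = (block_num * block_size // 16, bs * block_size // 16, 16, 16)
    global_mask = (bs * global_size // 16, seq_len // 16, 16, 16)
    return q, k, local_mask, global_mask

def slice_shape(shape, strategy):
    # Per-device slice shape under a given shard strategy.
    return tuple(dim // cut for dim, cut in zip(shape, strategy))

dp, mp = 2, 8  # as in test_cus_matmul_dds_model_parallel_mix
q, k, local_mask, global_mask = matmul_dds_input_shapes(bs=128, num_heads=32)
for name, shape, stra in [("q", q, (mp, dp, 1, 1)), ("k", k, (mp, dp, 1, 1)),
                          ("local_mask", local_mask, (1, dp, 1, 1)),
                          ("global_mask", global_mask, (dp, 1, 1, 1))]:
    print(name, shape, "->", slice_shape(shape, stra))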