!22096 add parallel sparse attention ops

Merge pull request !22096 from yao_yf/parallel_sparse_attention_ops
This commit is contained in:
i-robot 2021-08-30 10:43:05 +00:00 committed by Gitee
commit 5b8f9ff398
11 changed files with 633 additions and 3 deletions

View File

@ -1744,5 +1744,35 @@ void VirtualDatasetCost::CalculateInputsInMemory(const std::map<size_t, bool> &)
is_inputs_should_in_memory_[i] = is_parameter_[i];
}
}
double MatmulDDSCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
int64_t) const {
double result = 0.0;
if (inputs_type_lengths_.size() != inputs.size()) {
MS_LOG(EXCEPTION) << "Invalid inputs type size " << inputs_type_lengths_.size() << " for layer norm cost";
}
for (size_t index = 0; index < inputs.size(); ++index) {
TensorInfo tensor_info = inputs[index];
Shape slice_shape = tensor_info.slice_shape();
result += ListProduct(slice_shape) * static_cast<double>(inputs_type_lengths_[index]);
}
return result;
}
// Not taking account of output
void MatmulDDSCost::CalculateOutputInMemory() {
is_output_should_in_memory_ =
(std::find(is_parameter_involve_.begin(), is_parameter_involve_.end(), true) != is_parameter_involve_.end());
}
// Taking account of input
void MatmulDDSCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
bool keep_mem =
(std::find(is_parameter_.begin(), is_parameter_.end(), true) != is_parameter_.end()) ||
(std::find(is_parameter_involve_.begin(), is_parameter_involve_.end(), true) != is_parameter_involve_.end());
std::fill(is_inputs_should_in_memory_.begin(), is_inputs_should_in_memory_.end(), keep_mem);
}
} // namespace parallel
} // namespace mindspore

View File

@ -1051,6 +1051,44 @@ class GatherV2PCost : public GatherV2Cost {
int64_t axis_;
Shape strategy_;
};
class MatmulDDSCost : public OperatorCost {
public:
MatmulDDSCost() : OperatorCost() {}
~MatmulDDSCost() override = default;
// per device communication cost
double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
int64_t stage_id) const override {
return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
}
double GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
int64_t stage_id) const override {
return 0.0;
};
double GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
int64_t stage_id) const override {
return 0.0;
};
// per device computation cost
double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
int64_t stage_id) const override {
return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
}
double GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
int64_t stage_id) const override;
double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
int64_t stage_id) const override {
return 0.0;
};
// Not taking account of output
void CalculateOutputInMemory() override;
// Taking account of input
void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};
using MatmulDDSCostPtr = std::shared_ptr<MatmulDDSCost>;
} // namespace parallel
} // namespace mindspore
#endif // PARALLEL_AUTO_PARALLEL_OPERATOR_COSTMODEL_H_

View File

@ -204,6 +204,7 @@ REGISTER(MaxPoolInfo);
REGISTER(AvgPoolInfo);
REGISTER(GatherDInfo);
REGISTER(ReduceAnyInfo);
REGISTER(MatmulDDSInfo);
} // namespace parallel
} // namespace mindspore

View File

@ -165,7 +165,7 @@ AnfNodePtr GenerateGraph::PushBack(const std::vector<AnfNodePtr> &inputs) {
if (inputs.size() < 2) {
MS_LOG(EXCEPTION) << "inputs.size() must be more than 1";
}
(void)manager_->Replace(inputs.at(1), cnode); // using Replace function to insert cnode after inputs[0]
(void)manager_->Replace(inputs.at(1), cnode); // using Replace function to insert cnode after inputs[1]
auto new_anf_node_ptr = cnode->cast<AnfNodePtr>();
MS_EXCEPTION_IF_NULL(new_anf_node_ptr);
return new_anf_node_ptr;

View File

@ -0,0 +1,320 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "frontend/parallel/ops_info/matmul_dds_info.h"
#include <memory>
#include <utility>
#include <vector>
#include "frontend/parallel/device_manager.h"
#include "frontend/parallel/device_matrix.h"
#include "frontend/parallel/step_parallel.h"
#include "frontend/parallel/graph_util/generate_graph.h"
#include "utils/log_adapter.h"
namespace mindspore {
namespace parallel {
/*
* MatmulDDS has 4 input
* q, k: A 4D float used in transformer model,
* the shape is [num_heads * size_per_head // 16, bs * seq_len // 16, 16, 16], num_heads*size_per_head = embedding_size
* The shape is reshaped for cube. origin shape is
* (bs*seq_len, embedding_size) <=> (bs, num_heads, seq_len, size_per_head)
* local_mask: Local mask in sparse attention, the shape is
* (seq_len // 16, bs * block_size // 16, 16, 16).
* block_num = seq_len // block_size, block_size = 64, always.
* global_mask: Global mask in sparse attention,
* the shape is (bs * global_size // 16, seq_len // 16, 16, 16)
* seq_len = 1024, global_size = 256, always.
* Only bs and num_heads can be splited, thus the q[0] should at least be size_per_head,
* q[1] should at least be seq_len // 16. The strategy check can use bs/head from attrs.
*/
Status MatmulDDSInfo::CheckStrategy(const StrategyPtr &strategy) {
if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Invalid strategy.";
return FAILED;
}
Strategys stras = strategy->GetInputDim();
if (stras.size() != MATMUL_DDS_INPUTS_SIZE) {
MS_LOG(ERROR) << name_ << ": Invalid strategy. The strategys size should be 4.";
return FAILED;
}
for (auto stra : stras) {
if (stra.size() != MATMUL_DDS_STRATEGY_SIZE) {
MS_LOG(ERROR) << name_
<< ": Invalid strategy. The strategy size should be 4, but in current dim, "
"the strategy is"
<< stra;
return FAILED;
}
}
if (stras[0][0] != stras[1][0] || num_heads_ % stras[0][0] != 0) {
MS_LOG(ERROR) << name_ << ": Invalid strategy. The strategys[0][0]:" << stras[0][0]
<< " should be equal to strategys[1][0]:" << stras[1][0]
<< " ,and should be divisible by num_heads: " << num_heads_;
return FAILED;
}
if (stras[0][1] != stras[1][1] || stras[0][1] != stras[2][1] || stras[0][1] != stras[3][0]) {
MS_LOG(ERROR) << name_ << ": Invalid strategy. The strategys[0][1]:" << stras[0][1]
<< ", strategys[1][1]:" << stras[1][1] << ", strategys[2][1]:" << stras[2][1]
<< ", strategys[3][0]:" << stras[3][0] << " should be the same.";
return FAILED;
}
if (batch_size_ % stras[0][1] != 0) {
MS_LOG(ERROR) << name_ << ": Invalid strategy. The strategys[0][1]:" << stras[0][1]
<< " should be divisible by batch_sizes:" << batch_size_;
return FAILED;
}
for (size_t i = 2; i < stras[0].size(); ++i) {
if (stras[0][i] != 1) {
MS_LOG(ERROR) << name_ << ": Invalid strategy. The strategys[0][" << i << "] only support 1";
return FAILED;
}
}
for (size_t i = 2; i < stras[1].size(); ++i) {
if (stras[1][i] != 1) {
MS_LOG(ERROR) << name_ << ": Invalid strategy. The strategys[1][" << i << "] only support 1";
return FAILED;
}
}
for (size_t i = 0; i < stras[2].size(); ++i) {
if (i != 1 && stras[2][i] != 1) {
MS_LOG(ERROR) << name_ << ": Invalid strategy. The strategys[2][" << i << "] only support 1";
return FAILED;
}
}
for (size_t i = 1; i < stras[3].size(); ++i) {
if (stras[3][i] != 1) {
MS_LOG(ERROR) << name_ << ": Invalid strategy. The strategys[3][" << i << "] only support 1";
return FAILED;
}
}
return SUCCESS;
}
/*
* device matrix is extended by the strategy0.
*/
Status MatmulDDSInfo::InferDevMatrixShape() {
Strategys stra = strategy_->GetInputDim();
Dimensions input_strategy = stra.at(0);
input_strategy_ = input_strategy;
dev_matrix_shape_ = input_strategy;
dev_matrix_shape_.push_back(1);
dev_matrix_shape_.push_back(1);
dev_matrix_shape_.push_back(1);
return SUCCESS;
}
/*
* q: [num_heads * size_per_head // 16, bs * seq_len // 16, 16, 16]
* k: [num_heads * size_per_head // 16, bs * seq_len // 16, 16, 16]
* local_mask: (block_num * block_size // 16, bs * block_size // 16, 16, 16)
* global_mask: (bs * global_size // 16, seq_len // 16, 16, 16)
* local_prob: (bs, num_heads, block_num, block_size // 16, block_size // 16, 16, 16)
* global_prob: (bs, num_heads, block_num, global_size // 16, block_size // 16, 16, 16)
* device_matrix: [num_heads_stra, bs_stra, 1, 1, 1, 1, 1]
*/
Status MatmulDDSInfo::InferTensorMap() {
TensorMap input_tensor_map_q;
// input_tensor_map_q [6, 5, -1, -1]
for (size_t i = 0; i < inputs_shape_[0].size(); ++i) {
if (i <= 1) {
input_tensor_map_q.push_back((int64_t)(inputs_shape_[0].size() + 3 - i - 1));
} else {
input_tensor_map_q.push_back((int64_t)(MAP_NONE));
}
}
TensorMap input_tensor_map_k;
// input_tensor_map_k [6, 5, -1, -1]
for (size_t i = 0; i < inputs_shape_[1].size(); ++i) {
if (i <= 1) {
input_tensor_map_k.push_back((int64_t)(inputs_shape_[1].size() + 3 - i - 1));
} else {
input_tensor_map_k.push_back((int64_t)(MAP_NONE));
}
}
TensorMap input_tensor_map_local_mask;
// input_tensor_map_local_mask [-1, 5, -1, -1]
for (size_t i = 0; i < inputs_shape_[2].size(); ++i) {
if (i == 1) {
input_tensor_map_local_mask.push_back((int64_t)(inputs_shape_[2].size() + 3 - 2));
} else {
input_tensor_map_local_mask.push_back((int64_t)(MAP_NONE));
}
}
TensorMap input_tensor_map_global_mask;
// input_tensor_map_local_mask [5, -1, -1, -1]
for (size_t i = 0; i < inputs_shape_[3].size(); ++i) {
if (i == 0) {
input_tensor_map_global_mask.push_back((int64_t)(inputs_shape_[3].size() + 3 - 2));
} else {
input_tensor_map_global_mask.push_back((int64_t)(MAP_NONE));
}
}
TensorMap output_tensor_map_local_prob;
// output_tensor_map_local_prob [5, 6, -1, -1, -1, -1, -1]
for (size_t i = 0; i < dev_matrix_shape_.size(); ++i) {
if (i == 0) {
output_tensor_map_local_prob.push_back((int64_t)(dev_matrix_shape_.size() - 2));
} else if (i == 1) {
output_tensor_map_local_prob.push_back((int64_t)(dev_matrix_shape_.size() - 1));
} else {
output_tensor_map_local_prob.push_back((int64_t)(MAP_NONE));
}
}
TensorMap output_tensor_map_global_prob;
// output_tensor_map_global_prob [5, 6, -1, -1, -1, -1, -1]
for (size_t i = 0; i < dev_matrix_shape_.size(); ++i) {
if (i == 0) {
output_tensor_map_global_prob.push_back((int64_t)(dev_matrix_shape_.size() - 2));
} else if (i == 1) {
output_tensor_map_global_prob.push_back((int64_t)(dev_matrix_shape_.size() - 1));
} else {
output_tensor_map_global_prob.push_back((int64_t)(MAP_NONE));
}
}
inputs_tensor_map_.push_back(input_tensor_map_q);
inputs_tensor_map_.push_back(input_tensor_map_k);
inputs_tensor_map_.push_back(input_tensor_map_local_mask);
inputs_tensor_map_.push_back(input_tensor_map_global_mask);
outputs_tensor_map_.push_back(output_tensor_map_local_prob);
outputs_tensor_map_.push_back(output_tensor_map_global_prob);
return SUCCESS;
}
Status MatmulDDSInfo::GetAttrs() {
if ((inputs_shape_.size() != MATMUL_DDS_INPUTS_SIZE) || (outputs_shape_.size() != MATMUL_DDS_OUTPUTS_SIZE)) {
MS_LOG(ERROR) << name_ << ": Inputs shape size " << inputs_shape_.size() << " or outputs shape size "
<< outputs_shape_.size() << " is wrong.";
return FAILED;
}
auto iter = attrs_.find(BS);
if (iter != attrs_.end()) {
MS_EXCEPTION_IF_NULL(iter->second);
if (iter->second->isa<Int64Imm>()) {
batch_size_ = iter->second->cast<Int64ImmPtr>()->value();
} else {
MS_LOG(ERROR) << name_ << ": The value of axis is not int64_t.";
return FAILED;
}
}
iter = attrs_.find(HEADS);
if (iter != attrs_.end()) {
MS_EXCEPTION_IF_NULL(iter->second);
if (iter->second->isa<Int64Imm>()) {
num_heads_ = iter->second->cast<Int64ImmPtr>()->value();
} else {
MS_LOG(ERROR) << name_ << ": The value of axis is not int64_t.";
return FAILED;
}
}
return SUCCESS;
}
Status MatmulDDSInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
int64_t new_batch_size;
int64_t new_num_heads;
int64_t batch_shard_num = strategy_->GetInputDim()[0][1];
int64_t heads_shard_num = strategy_->GetInputDim()[0][0];
new_batch_size = batch_size_ / batch_shard_num;
new_num_heads = num_heads_ / heads_shard_num;
ValuePtr new_bs_value = MakeValue(new_batch_size);
ValuePtr new_heads_value = MakeValue(new_num_heads);
Attr attr_batch_size = std::make_pair(BS, new_bs_value);
Attr attr_num_heads = std::make_pair(HEADS, new_heads_value);
OperatorAttrs attrs = {attr_batch_size, attr_num_heads};
GenerateGraph gen_g = GenerateGraph(attrs_);
if (gen_g.Init(cnode) != SUCCESS) {
MS_LOG(ERROR) << "GenerateGraph Init failed";
return FAILED;
}
auto matmul_dds_node =
gen_g.PushBack({gen_g.NewOpInst(MATMUL_DDS, attrs), gen_g.virtual_input_node(), gen_g.virtual_input_node(),
gen_g.virtual_input_node(), gen_g.virtual_input_node()});
std::vector<std::pair<AnfNodePtr, int64_t>> input_nodes = {
std::make_pair(matmul_dds_node, 1),
std::make_pair(matmul_dds_node, 2),
std::make_pair(matmul_dds_node, 3),
std::make_pair(matmul_dds_node, 4),
};
replace_graph_ = std::make_shared<std::pair<std::vector<std::pair<AnfNodePtr, int64_t>>, AnfNodePtr>>(
std::make_pair(input_nodes, matmul_dds_node));
return SUCCESS;
}
ReplaceGraphPtr MatmulDDSInfo::replace_graph(const CNodePtr &cnode) {
if (ComputeReplaceGraph(cnode) != SUCCESS) {
MS_LOG(EXCEPTION) << name_ << ": ComputeReplaceGraph failed.";
}
return replace_graph_;
}
Status MatmulDDSInfo::Init(const StrategyPtr &strategy) {
if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Init failed.";
return FAILED;
}
MS_LOG(INFO) << name_ << ": Init success.";
return SUCCESS;
}
Status MatmulDDSInfo::InitForCostModel(const StrategyPtr &strategy) {
if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Init for cost model failed.";
return FAILED;
}
MS_LOG(INFO) << name_ << ": Init for cost model success.";
return SUCCESS;
}
std::vector<StrategyPtr> MatmulDDSInfo::GenerateOpStrategies(int64_t stage_id) {
// to generate the first input's strategy
Shape input0_split = {1, 1, 0, 0};
Shapes splittable_input = {input0_split};
Shapes tmp_inputs_shape = {inputs_shape_[0]};
std::vector<StrategyPtr> sp_vector;
if (GenerateStrategiesForIndependentInputs(stage_id, tmp_inputs_shape, splittable_input, &sp_vector) != SUCCESS) {
MS_LOG(EXCEPTION) << name_ << ": Generate strategies failed";
}
// the others strategies are set by the first input's strategy
for (auto &sp : sp_vector) {
if ((sp == nullptr) || sp->GetInputDim().empty()) {
MS_LOG(EXCEPTION) << name_ << ": The strategy is null or empty";
}
Strategys tmp_strategy;
Dimensions q_strategy = sp->GetInputDim()[0];
Dimensions k_strategy = q_strategy;
Dimensions local_mask_strategy = {1, q_strategy[0], 1, 1};
Dimensions global_mask_strategy = {q_strategy[0], 1, 1, 1};
tmp_strategy.push_back(q_strategy); // q
tmp_strategy.push_back(k_strategy); // k
tmp_strategy.push_back(local_mask_strategy); // local_mask
tmp_strategy.push_back(global_mask_strategy); // global_mask
sp->ResetInputs(tmp_strategy);
}
return sp_vector;
}
Status MatmulDDSInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }
} // namespace parallel
} // namespace mindspore

View File

@ -0,0 +1,64 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_MATMUL_DDS_INFO_H_
#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_MATMUL_DDS_INFO_H_
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "ir/value.h"
#include "frontend/parallel/ops_info/operator_info.h"
#include "frontend/parallel/strategy.h"
namespace mindspore {
namespace parallel {
/*
* parallel class for MatmulDDS Primitive
*/
class MatmulDDSInfo : public OperatorInfo {
public:
MatmulDDSInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<MatmulDDSCost>()) {}
~MatmulDDSInfo() override = default;
Status Init(const StrategyPtr &strategy) override;
Status InitForCostModel(const StrategyPtr &strategy) override;
std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override;
Status SetCostUnderStrategy(const StrategyPtr &strategy) override;
ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override;
protected:
Status CheckStrategy(const StrategyPtr &strategy) override;
Status InferForwardCommunication() override { return SUCCESS; }
Status InferDevMatrixShape() override;
Status InferTensorMap() override;
Status GetAttrs() override;
Status InferAsLossDivisor() override { return SUCCESS; }
Status ComputeReplaceGraph(const CNodePtr &cnode);
private:
Dimensions input_strategy_;
int64_t batch_size_ = 0;
int64_t num_heads_ = 0;
};
} // namespace parallel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_MATMUL_DDS_INFO_H_

View File

@ -59,5 +59,6 @@
#include "frontend/parallel/ops_info/batchnorm_info.h"
#include "frontend/parallel/ops_info/maxpool_info.h"
#include "frontend/parallel/ops_info/gatherd_info.h"
#include "frontend/parallel/ops_info/matmul_dds_info.h"
#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_HEAD_FILES_H_

View File

@ -74,6 +74,9 @@ constexpr size_t TRANSFER_CONCAT_TENSOR_DIM_INDEX = 0;
constexpr size_t TRANSFER_CONCAT_DEV_DIM_INDEX = 1;
constexpr size_t TRANSFER_CONCAT_SPLIT_COUNT_INDEX = 2;
constexpr size_t TRANSFER_SPLIT_ARGS_SIZE = 3;
constexpr size_t MATMUL_DDS_INPUTS_SIZE = 4;
constexpr size_t MATMUL_DDS_OUTPUTS_SIZE = 2;
constexpr size_t MATMUL_DDS_STRATEGY_SIZE = 4;
constexpr double EPS = 1e-6;
constexpr double INF = 1e20;
constexpr double COST_FACTOR = 2.0;
@ -142,6 +145,8 @@ constexpr char RANGE_MAX[] = "range_max";
constexpr char REMOVE_ACCIDENTAL_HITS[] = "remove_accidental_hits";
constexpr char UNIQUE_STRING[] = "unique";
constexpr char AXIS[] = "axis";
constexpr char BS[] = "bs";
constexpr char HEADS[] = "heads";
constexpr char AXES[] = "axes";
constexpr char START[] = "start";
constexpr char LIMIT[] = "limit";
@ -319,6 +324,7 @@ constexpr char BROADCAST_TO[] = "BroadcastTo";
constexpr char SQRT[] = "Sqrt";
constexpr char ASSIGN[] = "Assign";
constexpr char GET_NEXT[] = "GetNext";
constexpr char MATMUL_DDS[] = "MatmulDDS";
constexpr char SQUEEZE[] = "Squeeze";
constexpr char NEG[] = "Neg";
constexpr char ABS[] = "Abs";

View File

@ -45,7 +45,7 @@ class PReLUInfo : public OperatorInfo {
protected:
Status CheckStrategy(const StrategyPtr &strategy) override;
Status InferForwardCommunication() { return SUCCESS; }
Status InferForwardCommunication() override { return SUCCESS; }
Status InferDevMatrixShape() override;
Status InferTensorMap() override;
Status GetAttrs() override;

View File

@ -170,7 +170,8 @@ bool IsSplittableOperator(const std::string &op_name) {
BESSELI0E, BESSELI1E, FLOORMOD, ASSIGN, ASSIGN_ADD, ATAN2, DIVNONAN, LOGICALAND, LOGICALOR, ELU, RELU6, RELUV2,
SOFTPLUS, SOFTSIGN, GREATEREQUAL, LESSEQUAL, LESS, APPROXIMATEEQUAL, MOD, UNIQUE, UNSORTED_SEGMENT_SUM,
UNSORTED_SEGMENT_MIN, REPEAT_ELEMENTS, TENSOR_DOT, RANGE, UNIFORM_CANDIDATE_SAMPLER, SLICE, SELECT, GATHERD,
UNSORTED_SEGMENT_MAX, GATHER_ND, TOPK, SCATTER_UPDATE, VIRTUAL_OUTPUT, CONV2D_BACK_PROP_INPUT, CONV2D_TRANSPOSE};
UNSORTED_SEGMENT_MAX, GATHER_ND, TOPK, SCATTER_UPDATE, VIRTUAL_OUTPUT, CONV2D_BACK_PROP_INPUT, CONV2D_TRANSPOSE,
MATMUL_DDS};
// clang-format on
auto iter = splittable_op.find(op_name);

View File

@ -0,0 +1,169 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
from mindspore.common.api import _cell_graph_executor
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.parallel import set_algo_parameters
from mindspore.ops.operations._inner_ops import MatmulDDS
from tests.ut.python.ops.test_math_ops import VirtualLoss
context.set_context(mode=context.GRAPH_MODE)
grad_all = C.GradOperation(get_all=True)
# q: (num_heads * size_per_head // 16, bs * seq_len // 16, 16, 16)
# k: (num_heads * size_per_head // 16, bs * seq_len // 16, 16, 16)
# local_mask: (block_num * block_size // 16, bs * block_size // 16, 16, 16)
# global_mask: (bs * global_size // 16, seq_len // 16, 16, 16)
# local_prob: (bs, num_heads, block_num, block_size // 16, block_size // 16, 16, 16)
# global_prob: (bs, num_heads, block_num, global_size // 16, block_size // 16, 16, 16)
# x: (bs*seq_len, num_heads*size_per_head)
class Net(nn.Cell):
def __init__(self, batch_size, num_heads, dp, mp, shard=True):
super(Net, self).__init__()
self.batch_size = batch_size
self.num_heads = num_heads
self.size_per_head = 128
self.seq_len = 1024
self.block_size = 64
self.block_num = self.seq_len // self.block_size
self.global_size = 256
self.embedding_size = num_heads * self.size_per_head
self.cus_matmul = MatmulDDS(batch_size, num_heads)
self.reduce_sum = P.ReduceSum()
self.global_mask = Tensor(np.ones((batch_size * self.global_size // 16, self.seq_len // 16, 16, 16)))
self.local_mask = Tensor(np.ones((self.block_num * self.block_size // 16,
batch_size * self.block_size // 16, 16, 16)))
self.dense1 = nn.Dense(self.embedding_size, self.embedding_size, has_bias=False)
self.dense2 = nn.Dense(self.embedding_size, self.embedding_size, has_bias=False)
self.reshape = P.Reshape()
self.transpose = P.Transpose()
self.add = P.Add()
if shard:
self.cus_matmul.shard(((mp, dp, 1, 1), (mp, dp, 1, 1), (1, dp, 1, 1), (dp, 1, 1, 1)))
self.dense1.matmul.shard(((dp, 1), (mp, 1)))
self.dense2.matmul.shard(((dp, 1), (mp, 1)))
self.transpose.shard(((dp, 1, mp, 1),))
def construct(self, x):
q = self.dense1(x)
k = self.dense2(x)
q = self.transpose(self.reshape(q, (-1, 16, self.embedding_size // 16, 16)), (2, 0, 1, 3))
k = self.transpose(self.reshape(k, (-1, 16, self.embedding_size // 16, 16)), (2, 0, 1, 3))
local_prob, global_prob = self.cus_matmul(q, k, self.local_mask, self.global_mask)
local_prob = self.reshape(local_prob, (self.batch_size, self.num_heads, -1))
global_prob = self.reshape(global_prob, (self.batch_size, self.num_heads, -1))
local_prob_reduce = self.reduce_sum(local_prob, 2)
global_prob_reduce = self.reduce_sum(global_prob, 2)
result = self.add(local_prob_reduce, global_prob_reduce)
return result
class GradWrap(nn.Cell):
def __init__(self, network):
super(GradWrap, self).__init__()
self.network = network
def construct(self, x):
return grad_all(self.network)(x)
class NetWithLoss(nn.Cell):
def __init__(self, network):
super(NetWithLoss, self).__init__()
self.network = network
self.loss = VirtualLoss()
def construct(self, x):
predict = self.network(x)
return self.loss(predict)
def compile_graph(batch_size, num_heads, dp, mp, auto=False, shard=True):
if auto:
context.set_auto_parallel_context(parallel_mode="auto_parallel")
else:
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
x = Tensor(np.ones((batch_size * 1024, num_heads * 128)), ms.float32)
net = GradWrap(NetWithLoss(Net(batch_size, num_heads, dp, mp, shard=shard)))
net.set_auto_parallel()
net.set_train()
_cell_graph_executor.compile(net, x)
def test_cus_matmul_dds_model_parallel_mix():
context.set_auto_parallel_context(device_num=16, global_rank=0)
batch_size = 128
num_heads = 32
dp = 2
mp = 8
compile_graph(batch_size, num_heads, dp, mp)
def test_cus_matmul_dds_model_parallel_dp():
context.set_auto_parallel_context(device_num=16, global_rank=0)
batch_size = 128
num_heads = 32
dp = 16
mp = 1
compile_graph(batch_size, num_heads, dp, mp)
def test_cus_matmul_dds_model_parallel_mp():
context.set_auto_parallel_context(device_num=16, global_rank=0)
batch_size = 128
num_heads = 32
dp = 1
mp = 16
compile_graph(batch_size, num_heads, dp, mp)
def test_cus_matmul_dds_model_parallel_mix_auto():
set_algo_parameters(fully_use_devices=False)
context.set_auto_parallel_context(device_num=16, global_rank=0)
batch_size = 128
num_heads = 32
dp = 2
mp = 8
compile_graph(batch_size, num_heads, dp, mp, auto=True)
def test_cus_matmul_dds_model_parallel_dp_auto():
context.set_auto_parallel_context(device_num=16, global_rank=0)
batch_size = 128
num_heads = 32
dp = 16
mp = 1
compile_graph(batch_size, num_heads, dp, mp, auto=True)
def test_cus_matmul_dds_model_parallel_mp_auto():
context.set_auto_parallel_context(device_num=16, global_rank=0)
batch_size = 128
num_heads = 32
dp = 1
mp = 16
compile_graph(batch_size, num_heads, dp, mp, auto=True)
def test_cus_matmul_dds_model_parallel_auto():
set_algo_parameters(fully_use_devices=False)
context.set_auto_parallel_context(device_num=16, global_rank=0)
batch_size = 128
num_heads = 32
dp = 1
mp = 16
compile_graph(batch_size, num_heads, dp, mp, auto=True, shard=False)