!18884 split-with-over-lap support fp16

Merge pull request !18884 from ling/sr
i-robot 2021-06-26 01:55:08 +00:00 committed by Gitee
commit 69cedc28f4
4 changed files with 124 additions and 78 deletions


@@ -26,44 +26,67 @@ using mindspore::schema::PrimitiveType_SplitWithOverlap;
 namespace mindspore::kernel {
-void SplitWithOverlapBaseCPUKernel::CalculateSplitedShapes(const SplitWithOverlapParameter *param,
-                                                           const std::vector<int> &shape) {
+void SplitWithOverlapBaseCPUKernel::CalculateSplitedShapes(const std::vector<int> &shape) {
   int total_block_count = 0;
-  for (auto i = 0; i < param->num_split_; i++) {
-    total_block_count += param->ratio_[i];
+  for (auto i = 0; i < param_->num_split_; i++) {
+    total_block_count += param_->ratio_[i];
   }
-  auto split_dim_size = shape[param->split_dim_];
+  auto split_dim_size = shape[param_->split_dim_];
   std::vector<int> borders;
   borders.emplace_back(0);
   int visited_block = 0;
-  for (auto i = 0; i < param->num_split_ - 1; i++) {
-    visited_block += param->ratio_[i];
+  for (auto i = 0; i < param_->num_split_ - 1; i++) {
+    visited_block += param_->ratio_[i];
     auto cur_border = UP_DIV(split_dim_size * visited_block, total_block_count);
-    if (param->split_stride_ != 0) {
+    if (param_->split_stride_ != 0) {
       // make sure border align with stride
-      cur_border = UP_ROUND(cur_border + param->pad_top_, param->split_stride_);
-      borders.emplace_back(cur_border - param->pad_top_);
+      cur_border = UP_ROUND(cur_border + param_->pad_top_, param_->split_stride_);
+      borders.emplace_back(cur_border - param_->pad_top_);
     } else {
       borders.emplace_back(cur_border);
     }
   }
   borders.emplace_back(split_dim_size);
-  for (auto i = 0; i < param->num_split_; i++) {
+  for (auto i = 0; i < param_->num_split_; i++) {
     start_indices_.emplace_back(borders[i]);
     end_indices_.emplace_back(borders[i + 1]);
     // overlap: calibrate start_indices and end_indices by adding extends
-    start_indices_[i] -= param->extend_top_[i];
-    end_indices_[i] += param->extend_bottom_[i];
+    start_indices_[i] -= param_->extend_top_[i];
+    end_indices_[i] += param_->extend_bottom_[i];
   }
 }
-int SplitWithOverlapBaseCPUKernel::Init() { return RET_OK; }
+int SplitWithOverlapBaseCPUKernel::Init() {
+  MS_ASSERT(param_->num_split_ > 1);
+  return ReSize();
+}
-int SplitWithOverlapBaseCPUKernel::ReSize() { return RET_OK; }
+int SplitWithOverlapBaseCPUKernel::ReSize() {
+  auto in_tensor = in_tensors_.front();
+  auto input_shape = in_tensor->shape();
+  start_indices_.clear();
+  end_indices_.clear();
+  CalculateSplitedShapes(input_shape);
+  element_bytes_ = static_cast<int>(lite::DataTypeSize(in_tensor->data_type()));
+  outer_total_dim_ = 1;
+  inner_stride_ = 1;
+  for (int i = 0; i < static_cast<int>(input_shape.size()); i++) {
+    if (i < param_->split_dim_) outer_total_dim_ *= input_shape[i];
+    if (i == param_->split_dim_) split_dim_size_ = input_shape[param_->split_dim_];
+    if (i > param_->split_dim_) inner_stride_ *= input_shape[i];
+  }
+  return RET_OK;
+}
 int SplitWithOverlapBaseCPUKernel::Split(int task_id) {
   DoSplitWithOverlapParallel(input_ptr_, output_ptr_.data(), task_id, split_dim_size_, element_bytes_, outer_total_dim_,
@@ -84,39 +107,12 @@ int SplitWithOverlapRun(void *cdata, int task_id, float lhs_scale, float rhs_sca
 }
 int SplitWithOverlapBaseCPUKernel::Run() {
-  auto prepare_ret = Prepare();
-  if (prepare_ret != RET_OK) {
-    MS_LOG(ERROR) << "Prepare fail! ret: " << prepare_ret;
-    return prepare_ret;
-  }
-  auto in_tensor = in_tensors_.front();
-  input_ptr_ = reinterpret_cast<char *>(in_tensor->data_c());
-  auto input_shape = in_tensor->shape();
-  start_indices_.clear();
-  end_indices_.clear();
+  input_ptr_ = reinterpret_cast<char *>(in_tensors_.front()->data_c());
   output_ptr_.clear();
   for (int i = 0; i < param_->num_split_; i++) {
     output_ptr_.push_back(reinterpret_cast<char *>(out_tensors_.at(i)->data_c()));
   }
-  CalculateSplitedShapes(param_, input_shape);
-  outer_total_dim_ = 1;
-  inner_stride_ = 1;
-  split_dim_size_ = input_shape[param_->split_dim_];
-  element_bytes_ = static_cast<int>(lite::DataTypeSize(in_tensor->data_type()));
-  for (auto i = 0; i < param_->split_dim_; i++) {
-    outer_total_dim_ *= input_shape[i];
-  }
-  for (int i = static_cast<int>(input_shape.size()) - 1; i > param_->split_dim_; i--) {
-    inner_stride_ *= input_shape[i];
-  }
   auto ret = ParallelLaunch(this->context_, SplitWithOverlapRun, this, param_->num_split_);
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "ParallelLaunch for SplitWIthOverlapRun run fail. errorcode:[" << ret << "]";
@@ -128,5 +124,4 @@ int SplitWithOverlapBaseCPUKernel::Run() {
 REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_SplitWithOverlap, LiteKernelCreator<SplitWithOverlapBaseCPUKernel>)
+REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_SplitWithOverlap, LiteKernelCreator<SplitWithOverlapBaseCPUKernel>)
 REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_SplitWithOverlap, LiteKernelCreator<SplitWithOverlapBaseCPUKernel>)
 }  // namespace mindspore::kernel
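Note on the fp16 support in this file: the base kernel never interprets element values. ReSize() caches the split geometry (the per-output [start, end) ranges along split_dim_, plus outer_total_dim_, split_dim_size_ and inner_stride_) together with the element size from lite::DataTypeSize(), and the copy itself works on raw bytes, which is why the same LiteKernelCreator registration can be reused for kNumberTypeFloat32, kNumberTypeFloat16 and kNumberTypeInt8. For example, ratio_ = {1, 1, 2} over a 100-element split axis yields borders {0, 25, 50, 100}, i.e. ranges [0,25), [25,50), [50,100), before extend_top_/extend_bottom_ widen them. The sketch below shows the kind of byte-wise slice copy this relies on; it is not the nnacl DoSplitWithOverlapParallel implementation, and the function name and parameters are illustrative only.

// Minimal sketch of a dtype-agnostic split-with-overlap copy. Assumed layout:
// [outer_total_dim][split_dim_size][inner_stride] elements. Not library code.
#include <cstring>

void CopyOneSplit(const char *input, char *output, int start, int end, int split_dim_size,
                  int element_bytes, int outer_total_dim, int inner_stride) {
  const int row_bytes = inner_stride * element_bytes;      // one slice along the split axis, in bytes
  const int in_block_bytes = split_dim_size * row_bytes;   // one outer block of the input
  const int out_block_bytes = (end - start) * row_bytes;   // one outer block of this output
  for (int outer = 0; outer < outer_total_dim; ++outer) {
    const char *src = input + outer * in_block_bytes + start * row_bytes;
    char *dst = output + outer * out_block_bytes;
    std::memcpy(dst, src, out_block_bytes);                // fp32/fp16/int8 only change element_bytes
  }
}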


@@ -33,22 +33,25 @@ class SplitWithOverlapBaseCPUKernel : public InnerKernel {
     param_ = reinterpret_cast<SplitWithOverlapParameter *>(op_parameter_);
   }
   ~SplitWithOverlapBaseCPUKernel() override = default;
-  void CalculateSplitedShapes(const SplitWithOverlapParameter *param, const std::vector<int> &shape);
   int Init() override;
   int ReSize() override;
   int Run() override;
   int Split(int task_id);
- protected:
+ private:
+  void CalculateSplitedShapes(const std::vector<int> &shape);
+ private:
   // range: [start, end)
   std::vector<int> start_indices_;
   std::vector<int> end_indices_;
-  SplitWithOverlapParameter *param_ = nullptr;
   int outer_total_dim_{0};
   int inner_stride_{0};
   int element_bytes_{0};
   int split_dim_size_{0};
+  SplitWithOverlapParameter *param_ = nullptr;
   char *input_ptr_{nullptr};
   std::vector<char *> output_ptr_;
 };
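For reference, the cached members declared above describe the input layout around split_dim_. A standalone sketch with a made-up shape, mirroring the ReSize() loop in the source file:

#include <cstdio>
#include <vector>

int main() {
  const std::vector<int> shape = {2, 3, 100, 16};  // hypothetical input shape
  const int split_dim = 2;
  int outer_total_dim = 1, inner_stride = 1, split_dim_size = 0;
  for (int i = 0; i < static_cast<int>(shape.size()); ++i) {
    if (i < split_dim) outer_total_dim *= shape[i];
    if (i == split_dim) split_dim_size = shape[i];
    if (i > split_dim) inner_stride *= shape[i];
  }
  // Prints 6 100 16: the tensor is viewed as [6][100][16]; every output keeps the
  // outer 6 and inner 16 and takes only rows [start_indices_[i], end_indices_[i]).
  printf("%d %d %d\n", outer_total_dim, split_dim_size, inner_stride);
  return 0;
}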


@@ -159,7 +159,7 @@ SearchSubGraph::CostModel SearchSubGraph::CalculateConv2DFusion(Model::Node *nod
   } else {
     int out_unit;
     if (CheckIfUseWinograd(&out_unit, param)) {
-      size_t winograd_conv_cost = WinogradConvMul();
+      size_t winograd_conv_cost = CommConvMul(weight_shape, output_shape);
       cost.mul_cost_ += winograd_conv_cost;
     } else {
       size_t comm_conv_mul_cost = CommConvMul(weight_shape, output_shape);
@@ -170,7 +170,7 @@
 #if defined(ENABLE_ARM) || (defined(ENABLE_SSE) && !defined(ENABLE_AVX))
     if (CheckConvDw1DWinograd(param, context_->thread_num_)) {
       /* ConvolutionDepthwise3x3CPUKernel */
-      size_t winograd_convdw_cost = WinogradConvDwMul();
+      size_t winograd_convdw_cost = CommConvdwMul(weight_shape, output_shape);
       cost.mul_cost_ += winograd_convdw_cost;
     } else {
       /* ConvolutionDepthwiseIndirectCPUKernel */
@@ -450,18 +450,7 @@ void SearchSubGraph::OptimizeAfterFusion(std::vector<Subgraph> *sub_graphs, uint
       VectorErase(&sub.heads_, head_index);
     }
-    /* double check head-end node */
-    /* head-end node may error after subgraph fusion */
-    for (uint32_t head_node : sub.heads_) {
-      if (std::find(sub.nodes_.begin(), sub.nodes_.end(), head_node) == sub.nodes_.end()) {
-        VectorErase(&sub.nodes_, head_node);
-      }
-    }
-    for (uint32_t end_node : sub.ends_) {
-      if (std::find(sub.nodes_.begin(), sub.nodes_.end(), end_node) == sub.nodes_.end()) {
-        VectorErase(&sub.ends_, end_node);
-      }
-    }
+    CheckSubHeadEnd(&sub);
     /* sort node index */
     std::sort(sub.nodes_.begin(), sub.nodes_.end());
@@ -579,9 +568,11 @@ void SearchSubGraph::InitMiddleSubgraph(std::vector<uint32_t> *multy_in_nodes) {
     Model::Node *node = node_list_[node_index];
     for (uint32_t input_tensor_index : node->input_indices_) {
       Tensor *tensor = &tensors_[input_tensor_index];
-      if (tensor->type_ == CONST) continue;
+      if (tensor->type_ == CONST || tensor->type_ == INPUT) continue;
       std::vector<uint32_t> input_nodes = tensor->out_nodes_;
+      if (input_nodes.empty()) continue;
       Subgraph sub;
       sub.ends_.push_back(input_nodes[0]);
       InsertNodeByMid(input_nodes[0], &sub);
@@ -602,6 +593,9 @@ void SearchSubGraph::InitSearchSubGraphByMiddle() {
   InitMiddleSubgraph(&multy_in_nodes);
+  if (node_sub_map_.size() > kMaxSubGraphCount) {
+    node_sub_map_.clear();
+  }
   return;
 }
@@ -763,6 +757,9 @@ void SearchSubGraph::SubGraphSplitByOutput() {
   CalculateCostModel(&sub_graphs_);
   InitSubgraphRuntimeInfo(&sub_graphs_);
   SubgraphFusion(&sub_graphs_);
+  for (Subgraph &sub : sub_graphs_) {
+    CheckSubHeadEnd(&sub);
+  }
   ConvertSubGraphToModel(&sub_graphs_);
 }
@@ -770,6 +767,9 @@ void SearchSubGraph::SubGraphSplitByMiddle() {
   InitSearchSubGraphByMiddle();
   for (auto map : node_sub_map_) {
     std::vector<Subgraph> &subgraphs = map.second;
+    if (subgraphs.size() < kDefalutSubGraphSize) {
+      continue;
+    }
     CalculateCostModel(&subgraphs);
     InitSubgraphRuntimeInfo(&subgraphs);
@@ -935,17 +935,63 @@ void SearchSubGraph::InsertParallelNode(uint32_t index, Subgraph *subgraph) {
     }
   }
 }
-void SearchSubGraph::InitSearchParallelSubGraph() {
-  // for every graph output, find a parallel subgraph
-  for (uint32_t output : *output_nodes_) {
-    Subgraph subgraph;
-    InsertParallelNode(output, &subgraph);
-    sub_graphs_.push_back(std::move(subgraph));
-  }
-}
+void SearchSubGraph::CheckSubHeadEnd(Subgraph *sub) {
+  /* head-end node may error after subgraph fusion */
+  /* sub head node check */
+  std::vector<uint32_t> delete_head;
+  for (uint32_t head_node : sub->heads_) {
+    if (std::find(sub->nodes_.begin(), sub->nodes_.end(), head_node) == sub->nodes_.end()) {
+      delete_head.push_back(head_node);
+      continue;
+    }
+    Model::Node *node = model_->all_nodes_.at(head_node);
+    std::vector<uint32_t> in_tensors = node->input_indices_;
+    std::vector<uint32_t> in_nodes;
+    for (uint32_t in_t : in_tensors) {
+      in_nodes.insert(in_nodes.begin(), tensors_.at(in_t).out_nodes_.begin(), tensors_.at(in_t).out_nodes_.end());
+    }
+    bool erase_head = true;
+    for (uint32_t in_n : in_nodes) {
+      if (std::find(sub->nodes_.begin(), sub->nodes_.end(), in_n) == sub->nodes_.end()) {
+        erase_head = false;
+        break;
+      }
+    }
+    if (erase_head) {
+      delete_head.push_back(head_node);
+    }
+  }
+  for (uint32_t head : delete_head) {
+    VectorErase(&sub->heads_, head);
+  }
+  /* sub end node check */
+  std::vector<uint32_t> delete_end;
+  for (uint32_t end_node : sub->ends_) {
+    if (std::find(sub->nodes_.begin(), sub->nodes_.end(), end_node) == sub->nodes_.end()) {
+      delete_end.push_back(end_node);
+    }
+  }
+  for (uint32_t end : delete_end) {
+    VectorErase(&sub->ends_, end);
+  }
+  return;
+}
+bool SearchSubGraph::ValidInParallel() {
+  Model::Node *front_node = model_->all_nodes_.at(0);
+  if (front_node->quant_type_ != schema::QuantType_QUANT_NONE) {
+    return false;
+  }
+  return true;
+}
 void SearchSubGraph::SubGraphSplit() {
+  if (!ValidInParallel()) {
+    return;
+  }
   UpdateOfflineParallelFlag();
   if (offline_parallel_enable_) {
     SubGraphSplitByOffLineParallel();
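The CheckSubHeadEnd() added above centralizes the head/end validation that OptimizeAfterFusion() previously did inline, and SubGraphSplitByOutput() now also runs it after SubgraphFusion(). The rule for heads: a node stays in heads_ only if it is still in nodes_ and at least one producer of its inputs lies outside the subgraph; ends_ simply drop nodes that are no longer in nodes_. A toy, self-contained illustration of the head rule, using hypothetical node ids and plain vectors instead of the Tensor/Model::Node structures:

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

// True if `head` should remain a head: it is part of the subgraph and consumes
// at least one value produced outside of it.
bool KeepAsHead(uint32_t head, const std::vector<uint32_t> &sub_nodes,
                const std::vector<uint32_t> &producers_of_head_inputs) {
  if (std::find(sub_nodes.begin(), sub_nodes.end(), head) == sub_nodes.end()) return false;
  for (uint32_t p : producers_of_head_inputs) {
    if (std::find(sub_nodes.begin(), sub_nodes.end(), p) == sub_nodes.end()) return true;
  }
  return false;
}

int main() {
  const std::vector<uint32_t> sub_nodes = {2, 3, 4};
  printf("%d\n", KeepAsHead(2, sub_nodes, {1}));  // 1: node 1 feeds node 2 from outside the subgraph
  printf("%d\n", KeepAsHead(3, sub_nodes, {2}));  // 0: every input of node 3 is produced inside
  printf("%d\n", KeepAsHead(7, sub_nodes, {1}));  // 0: node 7 is no longer in nodes_ after fusion
  return 0;
}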


@@ -33,6 +33,7 @@ namespace mindspore::lite {
 constexpr int kDefaultDeviceType = -1;
 constexpr int kDefalutSubGraphSize = 2;
 constexpr int kDefaultInputs = 1;
+constexpr int kMaxSubGraphCount = 20;
 class SearchSubGraph {
   enum TensorType { NORMAL, CONST, INPUT };
@@ -84,12 +85,12 @@ class SearchSubGraph {
  public:
   void SubGraphSplit();
- private:
+ private: /* split by output */
   void SubGraphSplitByOutput();
   void InitSearchSubGraphByOutput();
   void InsertNode(uint32_t index, Subgraph *subgraph);
- private:
+ private: /* split by middle */
   void SubGraphSplitByMiddle();
   void InitSearchSubGraphByMiddle();
   void SearchMultyInNodes(std::vector<uint32_t> *multy_in_nodes);
@@ -97,35 +98,35 @@ class SearchSubGraph {
   void InsertNodeByMid(uint32_t node_index, Subgraph *subgraph);
   void InsertHeadNode(uint32_t index, Subgraph *subgraph);
   void OptimizeAfterFusion(std::vector<Subgraph> *sub_graphs, uint32_t root_node_index);
-  std::unordered_map<uint32_t, std::vector<Subgraph>> node_sub_map_;
- private:
+ private: /* split by offline */
   void SubGraphSplitByOffLineParallel();
+  void UpdateOfflineParallelFlag();
+  bool CheckIsParallelSubGraph(const std::vector<Subgraph> &subgraphs);
- private:
+ private: /* public graph func */
   void RemoveConstNode(std::vector<uint32_t> *nodes);
   void InitSearchTensor();
-  void InitSearchParallelSubGraph();
   void InitMainGraphDevice(DeviceType dt = DT_CPU);
   void InitSubgraphRuntimeInfo(std::vector<Subgraph> *sub_graphs);
   void SubgraphFusion(std::vector<Subgraph> *sub_graphs);
   void CalculateCostModel(std::vector<Subgraph> *sub_graphs);
   void ConvertSubGraphToModel(std::vector<Subgraph> *sub_graphs);
+  bool ValidInParallel();
+  void CheckSubHeadEnd(Subgraph *sub);
- private:
+ private: /* public schema func */
   void InsertParallelNode(uint32_t index, Subgraph *subgraph);
   bool IsNodeSubGraphHead(uint32_t node_index, const std::vector<uint32_t> &ready_nodes);
   bool IsNodeSubGraphHeadWithRoot(uint32_t node_index, const std::vector<uint32_t> &ready_nodes,
                                   uint32_t root_node_index);
   const schema::Primitive *CreatePartialPrimitive(int64_t subgraph_index);
- private:
+ private: /* public cost-model func */
   CostModel CalculateConv2DFusion(Model::Node *node);
   void dfs(int i, int n, int current_sum, int except_value, int *min_value, std::vector<bool> *tmp_group,
            std::vector<bool> *cor_group, std::vector<Subgraph> *sub_graphs);
-  void UpdateOfflineParallelFlag();
-  bool CheckIsParallelSubGraph(const std::vector<Subgraph> &subgraphs);
  private:
   std::vector<size_t> *output_nodes_ = nullptr;
@@ -135,6 +136,7 @@ class SearchSubGraph {
   LiteModel *model_ = nullptr;
   std::vector<Tensor> tensors_;
   std::vector<Subgraph> sub_graphs_;
+  std::unordered_map<uint32_t, std::vector<Subgraph>> node_sub_map_;
   std::vector<Model::Node *> node_list_;
   DeviceType major_dt_;
   DeviceType minor_dt_;