forked from mindspore-Ecosystem/mindspore
!18884 split-with-over-lap support fp16
Merge pull request !18884 from ling/sr
This commit is contained in:
commit
69cedc28f4
|
@ -26,44 +26,67 @@ using mindspore::schema::PrimitiveType_SplitWithOverlap;
|
|||
|
||||
namespace mindspore::kernel {
|
||||
|
||||
void SplitWithOverlapBaseCPUKernel::CalculateSplitedShapes(const SplitWithOverlapParameter *param,
|
||||
const std::vector<int> &shape) {
|
||||
void SplitWithOverlapBaseCPUKernel::CalculateSplitedShapes(const std::vector<int> &shape) {
|
||||
int total_block_count = 0;
|
||||
for (auto i = 0; i < param->num_split_; i++) {
|
||||
total_block_count += param->ratio_[i];
|
||||
for (auto i = 0; i < param_->num_split_; i++) {
|
||||
total_block_count += param_->ratio_[i];
|
||||
}
|
||||
|
||||
auto split_dim_size = shape[param->split_dim_];
|
||||
auto split_dim_size = shape[param_->split_dim_];
|
||||
|
||||
std::vector<int> borders;
|
||||
borders.emplace_back(0);
|
||||
int visited_block = 0;
|
||||
for (auto i = 0; i < param->num_split_ - 1; i++) {
|
||||
visited_block += param->ratio_[i];
|
||||
for (auto i = 0; i < param_->num_split_ - 1; i++) {
|
||||
visited_block += param_->ratio_[i];
|
||||
auto cur_border = UP_DIV(split_dim_size * visited_block, total_block_count);
|
||||
if (param->split_stride_ != 0) {
|
||||
if (param_->split_stride_ != 0) {
|
||||
// make sure border align with stride
|
||||
cur_border = UP_ROUND(cur_border + param->pad_top_, param->split_stride_);
|
||||
borders.emplace_back(cur_border - param->pad_top_);
|
||||
cur_border = UP_ROUND(cur_border + param_->pad_top_, param_->split_stride_);
|
||||
borders.emplace_back(cur_border - param_->pad_top_);
|
||||
} else {
|
||||
borders.emplace_back(cur_border);
|
||||
}
|
||||
}
|
||||
borders.emplace_back(split_dim_size);
|
||||
|
||||
for (auto i = 0; i < param->num_split_; i++) {
|
||||
for (auto i = 0; i < param_->num_split_; i++) {
|
||||
start_indices_.emplace_back(borders[i]);
|
||||
end_indices_.emplace_back(borders[i + 1]);
|
||||
|
||||
// overlap: calibrate start_indices and end_indices by adding extends
|
||||
start_indices_[i] -= param->extend_top_[i];
|
||||
end_indices_[i] += param->extend_bottom_[i];
|
||||
start_indices_[i] -= param_->extend_top_[i];
|
||||
end_indices_[i] += param_->extend_bottom_[i];
|
||||
}
|
||||
}
|
||||
|
||||
int SplitWithOverlapBaseCPUKernel::Init() { return RET_OK; }
|
||||
int SplitWithOverlapBaseCPUKernel::Init() {
|
||||
MS_ASSERT(param_->num_split_ > 1);
|
||||
return ReSize();
|
||||
}
|
||||
|
||||
int SplitWithOverlapBaseCPUKernel::ReSize() { return RET_OK; }
|
||||
int SplitWithOverlapBaseCPUKernel::ReSize() {
|
||||
auto in_tensor = in_tensors_.front();
|
||||
auto input_shape = in_tensor->shape();
|
||||
|
||||
start_indices_.clear();
|
||||
end_indices_.clear();
|
||||
|
||||
CalculateSplitedShapes(input_shape);
|
||||
|
||||
element_bytes_ = static_cast<int>(lite::DataTypeSize(in_tensor->data_type()));
|
||||
|
||||
outer_total_dim_ = 1;
|
||||
inner_stride_ = 1;
|
||||
|
||||
for (int i = 0; i < static_cast<int>(input_shape.size()); i++) {
|
||||
if (i < param_->split_dim_) outer_total_dim_ *= input_shape[i];
|
||||
if (i == param_->split_dim_) split_dim_size_ = input_shape[param_->split_dim_];
|
||||
if (i > param_->split_dim_) inner_stride_ *= input_shape[i];
|
||||
}
|
||||
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int SplitWithOverlapBaseCPUKernel::Split(int task_id) {
|
||||
DoSplitWithOverlapParallel(input_ptr_, output_ptr_.data(), task_id, split_dim_size_, element_bytes_, outer_total_dim_,
|
||||
|
@ -84,39 +107,12 @@ int SplitWithOverlapRun(void *cdata, int task_id, float lhs_scale, float rhs_sca
|
|||
}
|
||||
|
||||
int SplitWithOverlapBaseCPUKernel::Run() {
|
||||
auto prepare_ret = Prepare();
|
||||
if (prepare_ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Prepare fail! ret: " << prepare_ret;
|
||||
return prepare_ret;
|
||||
}
|
||||
|
||||
auto in_tensor = in_tensors_.front();
|
||||
input_ptr_ = reinterpret_cast<char *>(in_tensor->data_c());
|
||||
auto input_shape = in_tensor->shape();
|
||||
|
||||
start_indices_.clear();
|
||||
end_indices_.clear();
|
||||
input_ptr_ = reinterpret_cast<char *>(in_tensors_.front()->data_c());
|
||||
output_ptr_.clear();
|
||||
|
||||
for (int i = 0; i < param_->num_split_; i++) {
|
||||
output_ptr_.push_back(reinterpret_cast<char *>(out_tensors_.at(i)->data_c()));
|
||||
}
|
||||
|
||||
CalculateSplitedShapes(param_, input_shape);
|
||||
|
||||
outer_total_dim_ = 1;
|
||||
inner_stride_ = 1;
|
||||
split_dim_size_ = input_shape[param_->split_dim_];
|
||||
element_bytes_ = static_cast<int>(lite::DataTypeSize(in_tensor->data_type()));
|
||||
|
||||
for (auto i = 0; i < param_->split_dim_; i++) {
|
||||
outer_total_dim_ *= input_shape[i];
|
||||
}
|
||||
|
||||
for (int i = static_cast<int>(input_shape.size()) - 1; i > param_->split_dim_; i--) {
|
||||
inner_stride_ *= input_shape[i];
|
||||
}
|
||||
|
||||
auto ret = ParallelLaunch(this->context_, SplitWithOverlapRun, this, param_->num_split_);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "ParallelLaunch for SplitWIthOverlapRun run fail. errorcode:[" << ret << "]";
|
||||
|
@ -128,5 +124,4 @@ int SplitWithOverlapBaseCPUKernel::Run() {
|
|||
|
||||
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_SplitWithOverlap, LiteKernelCreator<SplitWithOverlapBaseCPUKernel>)
|
||||
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_SplitWithOverlap, LiteKernelCreator<SplitWithOverlapBaseCPUKernel>)
|
||||
REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_SplitWithOverlap, LiteKernelCreator<SplitWithOverlapBaseCPUKernel>)
|
||||
} // namespace mindspore::kernel
|
||||
|
|
|
@ -33,22 +33,25 @@ class SplitWithOverlapBaseCPUKernel : public InnerKernel {
|
|||
param_ = reinterpret_cast<SplitWithOverlapParameter *>(op_parameter_);
|
||||
}
|
||||
~SplitWithOverlapBaseCPUKernel() override = default;
|
||||
void CalculateSplitedShapes(const SplitWithOverlapParameter *param, const std::vector<int> &shape);
|
||||
int Init() override;
|
||||
int ReSize() override;
|
||||
int Run() override;
|
||||
int Split(int task_id);
|
||||
|
||||
protected:
|
||||
private:
|
||||
void CalculateSplitedShapes(const std::vector<int> &shape);
|
||||
|
||||
private:
|
||||
// range: [start, end)
|
||||
std::vector<int> start_indices_;
|
||||
std::vector<int> end_indices_;
|
||||
|
||||
SplitWithOverlapParameter *param_ = nullptr;
|
||||
|
||||
int outer_total_dim_{0};
|
||||
int inner_stride_{0};
|
||||
int element_bytes_{0};
|
||||
int split_dim_size_{0};
|
||||
SplitWithOverlapParameter *param_ = nullptr;
|
||||
char *input_ptr_{nullptr};
|
||||
std::vector<char *> output_ptr_;
|
||||
};
|
||||
|
|
|
@ -159,7 +159,7 @@ SearchSubGraph::CostModel SearchSubGraph::CalculateConv2DFusion(Model::Node *nod
|
|||
} else {
|
||||
int out_unit;
|
||||
if (CheckIfUseWinograd(&out_unit, param)) {
|
||||
size_t winograd_conv_cost = WinogradConvMul();
|
||||
size_t winograd_conv_cost = CommConvMul(weight_shape, output_shape);
|
||||
cost.mul_cost_ += winograd_conv_cost;
|
||||
} else {
|
||||
size_t comm_conv_mul_cost = CommConvMul(weight_shape, output_shape);
|
||||
|
@ -170,7 +170,7 @@ SearchSubGraph::CostModel SearchSubGraph::CalculateConv2DFusion(Model::Node *nod
|
|||
#if defined(ENABLE_ARM) || (defined(ENABLE_SSE) && !defined(ENABLE_AVX))
|
||||
if (CheckConvDw1DWinograd(param, context_->thread_num_)) {
|
||||
/* ConvolutionDepthwise3x3CPUKernel */
|
||||
size_t winograd_convdw_cost = WinogradConvDwMul();
|
||||
size_t winograd_convdw_cost = CommConvdwMul(weight_shape, output_shape);
|
||||
cost.mul_cost_ += winograd_convdw_cost;
|
||||
} else {
|
||||
/* ConvolutionDepthwiseIndirectCPUKernel */
|
||||
|
@ -450,18 +450,7 @@ void SearchSubGraph::OptimizeAfterFusion(std::vector<Subgraph> *sub_graphs, uint
|
|||
VectorErase(&sub.heads_, head_index);
|
||||
}
|
||||
|
||||
/* double check head-end node */
|
||||
/* head-end node may error after subgraph fusion */
|
||||
for (uint32_t head_node : sub.heads_) {
|
||||
if (std::find(sub.nodes_.begin(), sub.nodes_.end(), head_node) == sub.nodes_.end()) {
|
||||
VectorErase(&sub.nodes_, head_node);
|
||||
}
|
||||
}
|
||||
for (uint32_t end_node : sub.ends_) {
|
||||
if (std::find(sub.nodes_.begin(), sub.nodes_.end(), end_node) == sub.nodes_.end()) {
|
||||
VectorErase(&sub.ends_, end_node);
|
||||
}
|
||||
}
|
||||
CheckSubHeadEnd(&sub);
|
||||
|
||||
/* sort node index */
|
||||
std::sort(sub.nodes_.begin(), sub.nodes_.end());
|
||||
|
@ -579,9 +568,11 @@ void SearchSubGraph::InitMiddleSubgraph(std::vector<uint32_t> *multy_in_nodes) {
|
|||
Model::Node *node = node_list_[node_index];
|
||||
for (uint32_t input_tensor_index : node->input_indices_) {
|
||||
Tensor *tensor = &tensors_[input_tensor_index];
|
||||
if (tensor->type_ == CONST) continue;
|
||||
if (tensor->type_ == CONST || tensor->type_ == INPUT) continue;
|
||||
|
||||
std::vector<uint32_t> input_nodes = tensor->out_nodes_;
|
||||
if (input_nodes.empty()) continue;
|
||||
|
||||
Subgraph sub;
|
||||
sub.ends_.push_back(input_nodes[0]);
|
||||
InsertNodeByMid(input_nodes[0], &sub);
|
||||
|
@ -602,6 +593,9 @@ void SearchSubGraph::InitSearchSubGraphByMiddle() {
|
|||
|
||||
InitMiddleSubgraph(&multy_in_nodes);
|
||||
|
||||
if (node_sub_map_.size() > kMaxSubGraphCount) {
|
||||
node_sub_map_.clear();
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -763,6 +757,9 @@ void SearchSubGraph::SubGraphSplitByOutput() {
|
|||
CalculateCostModel(&sub_graphs_);
|
||||
InitSubgraphRuntimeInfo(&sub_graphs_);
|
||||
SubgraphFusion(&sub_graphs_);
|
||||
for (Subgraph &sub : sub_graphs_) {
|
||||
CheckSubHeadEnd(&sub);
|
||||
}
|
||||
ConvertSubGraphToModel(&sub_graphs_);
|
||||
}
|
||||
|
||||
|
@ -770,6 +767,9 @@ void SearchSubGraph::SubGraphSplitByMiddle() {
|
|||
InitSearchSubGraphByMiddle();
|
||||
for (auto map : node_sub_map_) {
|
||||
std::vector<Subgraph> &subgraphs = map.second;
|
||||
if (subgraphs.size() < kDefalutSubGraphSize) {
|
||||
continue;
|
||||
}
|
||||
|
||||
CalculateCostModel(&subgraphs);
|
||||
InitSubgraphRuntimeInfo(&subgraphs);
|
||||
|
@ -935,17 +935,63 @@ void SearchSubGraph::InsertParallelNode(uint32_t index, Subgraph *subgraph) {
|
|||
}
|
||||
}
|
||||
}
|
||||
void SearchSubGraph::CheckSubHeadEnd(Subgraph *sub) {
|
||||
/* head-end node may error after subgraph fusion */
|
||||
/* sub head node check */
|
||||
std::vector<uint32_t> delete_head;
|
||||
for (uint32_t head_node : sub->heads_) {
|
||||
if (std::find(sub->nodes_.begin(), sub->nodes_.end(), head_node) == sub->nodes_.end()) {
|
||||
delete_head.push_back(head_node);
|
||||
continue;
|
||||
}
|
||||
Model::Node *node = model_->all_nodes_.at(head_node);
|
||||
std::vector<uint32_t> in_tensors = node->input_indices_;
|
||||
std::vector<uint32_t> in_nodes;
|
||||
for (uint32_t in_t : in_tensors) {
|
||||
in_nodes.insert(in_nodes.begin(), tensors_.at(in_t).out_nodes_.begin(), tensors_.at(in_t).out_nodes_.end());
|
||||
}
|
||||
|
||||
void SearchSubGraph::InitSearchParallelSubGraph() {
|
||||
// for every graph output, find a parallel subgraph
|
||||
for (uint32_t output : *output_nodes_) {
|
||||
Subgraph subgraph;
|
||||
InsertParallelNode(output, &subgraph);
|
||||
sub_graphs_.push_back(std::move(subgraph));
|
||||
bool erase_head = true;
|
||||
for (uint32_t in_n : in_nodes) {
|
||||
if (std::find(sub->nodes_.begin(), sub->nodes_.end(), in_n) == sub->nodes_.end()) {
|
||||
erase_head = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (erase_head) {
|
||||
delete_head.push_back(head_node);
|
||||
}
|
||||
}
|
||||
for (uint32_t head : delete_head) {
|
||||
VectorErase(&sub->heads_, head);
|
||||
}
|
||||
|
||||
/* sub end node check */
|
||||
std::vector<uint32_t> delete_end;
|
||||
for (uint32_t end_node : sub->ends_) {
|
||||
if (std::find(sub->nodes_.begin(), sub->nodes_.end(), end_node) == sub->nodes_.end()) {
|
||||
delete_end.push_back(end_node);
|
||||
}
|
||||
}
|
||||
for (uint32_t end : delete_end) {
|
||||
VectorErase(&sub->ends_, end);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
bool SearchSubGraph::ValidInParallel() {
|
||||
Model::Node *front_node = model_->all_nodes_.at(0);
|
||||
if (front_node->quant_type_ != schema::QuantType_QUANT_NONE) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void SearchSubGraph::SubGraphSplit() {
|
||||
if (!ValidInParallel()) {
|
||||
return;
|
||||
}
|
||||
|
||||
UpdateOfflineParallelFlag();
|
||||
if (offline_parallel_enable_) {
|
||||
SubGraphSplitByOffLineParallel();
|
||||
|
|
|
@ -33,6 +33,7 @@ namespace mindspore::lite {
|
|||
constexpr int kDefaultDeviceType = -1;
|
||||
constexpr int kDefalutSubGraphSize = 2;
|
||||
constexpr int kDefaultInputs = 1;
|
||||
constexpr int kMaxSubGraphCount = 20;
|
||||
class SearchSubGraph {
|
||||
enum TensorType { NORMAL, CONST, INPUT };
|
||||
|
||||
|
@ -84,12 +85,12 @@ class SearchSubGraph {
|
|||
public:
|
||||
void SubGraphSplit();
|
||||
|
||||
private:
|
||||
private: /* split by output */
|
||||
void SubGraphSplitByOutput();
|
||||
void InitSearchSubGraphByOutput();
|
||||
void InsertNode(uint32_t index, Subgraph *subgraph);
|
||||
|
||||
private:
|
||||
private: /* split by middle */
|
||||
void SubGraphSplitByMiddle();
|
||||
void InitSearchSubGraphByMiddle();
|
||||
void SearchMultyInNodes(std::vector<uint32_t> *multy_in_nodes);
|
||||
|
@ -97,35 +98,35 @@ class SearchSubGraph {
|
|||
void InsertNodeByMid(uint32_t node_index, Subgraph *subgraph);
|
||||
void InsertHeadNode(uint32_t index, Subgraph *subgraph);
|
||||
void OptimizeAfterFusion(std::vector<Subgraph> *sub_graphs, uint32_t root_node_index);
|
||||
std::unordered_map<uint32_t, std::vector<Subgraph>> node_sub_map_;
|
||||
|
||||
private:
|
||||
private: /* split by offline */
|
||||
void SubGraphSplitByOffLineParallel();
|
||||
void UpdateOfflineParallelFlag();
|
||||
bool CheckIsParallelSubGraph(const std::vector<Subgraph> &subgraphs);
|
||||
|
||||
private:
|
||||
private: /* public graph func */
|
||||
void RemoveConstNode(std::vector<uint32_t> *nodes);
|
||||
void InitSearchTensor();
|
||||
void InitSearchParallelSubGraph();
|
||||
void InitMainGraphDevice(DeviceType dt = DT_CPU);
|
||||
|
||||
void InitSubgraphRuntimeInfo(std::vector<Subgraph> *sub_graphs);
|
||||
void SubgraphFusion(std::vector<Subgraph> *sub_graphs);
|
||||
void CalculateCostModel(std::vector<Subgraph> *sub_graphs);
|
||||
void ConvertSubGraphToModel(std::vector<Subgraph> *sub_graphs);
|
||||
bool ValidInParallel();
|
||||
void CheckSubHeadEnd(Subgraph *sub);
|
||||
|
||||
private:
|
||||
private: /* public schema func */
|
||||
void InsertParallelNode(uint32_t index, Subgraph *subgraph);
|
||||
bool IsNodeSubGraphHead(uint32_t node_index, const std::vector<uint32_t> &ready_nodes);
|
||||
bool IsNodeSubGraphHeadWithRoot(uint32_t node_index, const std::vector<uint32_t> &ready_nodes,
|
||||
uint32_t root_node_index);
|
||||
const schema::Primitive *CreatePartialPrimitive(int64_t subgraph_index);
|
||||
|
||||
private:
|
||||
private: /* public cost-model func */
|
||||
CostModel CalculateConv2DFusion(Model::Node *node);
|
||||
void dfs(int i, int n, int current_sum, int except_value, int *min_value, std::vector<bool> *tmp_group,
|
||||
std::vector<bool> *cor_group, std::vector<Subgraph> *sub_graphs);
|
||||
void UpdateOfflineParallelFlag();
|
||||
bool CheckIsParallelSubGraph(const std::vector<Subgraph> &subgraphs);
|
||||
|
||||
private:
|
||||
std::vector<size_t> *output_nodes_ = nullptr;
|
||||
|
@ -135,6 +136,7 @@ class SearchSubGraph {
|
|||
LiteModel *model_ = nullptr;
|
||||
std::vector<Tensor> tensors_;
|
||||
std::vector<Subgraph> sub_graphs_;
|
||||
std::unordered_map<uint32_t, std::vector<Subgraph>> node_sub_map_;
|
||||
std::vector<Model::Node *> node_list_;
|
||||
DeviceType major_dt_;
|
||||
DeviceType minor_dt_;
|
||||
|
|
Loading…
Reference in New Issue