[MS][LITE] clean codex

cjh9368 2021-08-03 19:00:30 +08:00
parent 93af13f332
commit 98abc333de
13 changed files with 34 additions and 25 deletions

View File

@@ -158,5 +158,4 @@ void SaveDataToNet(const std::map<std::string, Tensor *> &saved_weights, const s
}
net.close();
}
} // namespace mindspore::lite::micro

View File

@@ -147,15 +147,16 @@ void Conv2DInt8Coder::CheckSupportOptimize() {
}
int Conv2DInt8Coder::InitTmpBuffer() {
const size_t kPartial = 2;
switch (opt_) {
case Basic:
buffer_size_ =
static_cast<size_t>(2 * input_tensor_->Channel() * filter_tensor_->Width() * filter_tensor_->Height()) *
static_cast<size_t>(kPartial * input_tensor_->Channel() * filter_tensor_->Width() * filter_tensor_->Height()) *
sizeof(int16_t);
break;
case Convolve_1_x_n:
buffer_size_ =
static_cast<size_t>(2 * input_tensor_->Channel() * filter_tensor_->Width() * filter_tensor_->Height()) *
static_cast<size_t>(kPartial * input_tensor_->Channel() * filter_tensor_->Width() * filter_tensor_->Height()) *
sizeof(int16_t);
break;
case Convolve_1x1_fast:
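
Both the Basic and Convolve_1_x_n paths size the same int16 scratch buffer; the change only names the factor 2 as kPartial. A minimal standalone sketch of that sizing computation, with a hypothetical helper name and plain integer arguments standing in for the tensor accessors:

    #include <cstddef>
    #include <cstdint>

    constexpr size_t kPartial = 2;  // the factor of 2 keeps two int16 im2col columns in flight

    // Hypothetical helper: bytes needed for the int8 conv temporary buffer.
    size_t ConvInt8ScratchBytes(size_t in_channel, size_t kernel_h, size_t kernel_w) {
      return kPartial * in_channel * kernel_w * kernel_h * sizeof(int16_t);
    }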

View File

@@ -83,5 +83,4 @@ int ConvolutionDepthwiseFP32Coder::DoCode(CoderContext *const context) {
context->AppendCode(code.str());
return RET_OK;
}
} // namespace mindspore::lite::micro::nnacl

View File

@@ -30,6 +30,7 @@ using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_GLU;
namespace mindspore::kernel {
const int kGluBranchNum = 2;
int GluCPUKernel::MallocTmpBuffer() {
FreeTmpBuffer();
auto in_tensor = in_tensors_.front();
@@ -115,7 +116,7 @@ int GluCPUKernel::Split(int task_id) {
int GluCPUKernel::Sigmoid(int task_id) {
auto input_addr = reinterpret_cast<float *>(split_ptr_.at(1));
auto output_addr = reinterpret_cast<float *>(sigmoid_ptr_);
auto length = in_tensors_.at(0)->ElementsNum() / 2;
auto length = in_tensors_.at(0)->ElementsNum() / kGluBranchNum;
int stride = UP_DIV(length, op_parameter_->thread_num_);
int count = MSMIN(stride, length - stride * task_id);
@@ -129,7 +130,7 @@ int GluCPUKernel::Mul(int task_id) {
auto input_addr0 = reinterpret_cast<float *>(split_ptr_.at(0));
auto input_addr1 = reinterpret_cast<float *>(sigmoid_ptr_);
auto output_addr = reinterpret_cast<float *>(out_tensors_.at(0)->data_c());
auto length = in_tensors_.at(0)->ElementsNum() / 2;
auto length = in_tensors_.at(0)->ElementsNum() / kGluBranchNum;
int stride = UP_DIV(length, op_parameter_->thread_num_);
int count = MSMIN(stride, length - stride * task_id);
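
Both Sigmoid and Mul operate on one GLU branch, i.e. half of the input elements, and chunk that branch across worker threads. A small self-contained sketch of the per-thread partitioning, using local stand-ins for the UP_DIV and MSMIN macros:

    #include <algorithm>

    constexpr int kGluBranchNum = 2;

    inline int UpDiv(int a, int b) { return (a + b - 1) / b; }  // mirrors UP_DIV

    // Number of elements thread `task_id` should process for one GLU branch.
    int GluCountForTask(int elements_num, int thread_num, int task_id) {
      int length = elements_num / kGluBranchNum;                 // one branch of the gated input
      int stride = UpDiv(length, thread_num);                    // ceiling split across threads
      int count = std::min(stride, length - stride * task_id);   // tail thread gets the remainder
      return count > 0 ? count : 0;                              // surplus threads get nothing
    }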

View File

@@ -463,6 +463,5 @@ size_t GetCNodeOutputsSize(const std::shared_ptr<AnfNode> &anf_node, bool train_
return 1;
}
}
} // namespace lite
} // namespace mindspore

View File

@@ -19,14 +19,19 @@
#include "mindspore/core/utils/log_adapter.h"
#include "mindspore/core/ir/dtype/type_id.h"
namespace mindspore::lite {
const size_t kWeightQueryIndex = 4;
const size_t kWeightKeyIndex = 5;
const size_t kWeightValueIndex = 6;
const size_t kWeightOutputIndex = 10;
bool AttentionQuantTypeDeterminer::DetermineQuantWeight(const mindspore::schema::MetaGraphT &graph,
mindspore::schema::CNodeT *node) {
MS_ASSERT(node->inputIndex.size() >= 2);
auto &input_tensor = graph.allTensors.at(node->inputIndex.at(kInputIndex));
auto &weight_query_tensor = graph.allTensors.at(node->inputIndex.at(4));
auto &weight_key_tensor = graph.allTensors.at(node->inputIndex.at(5));
auto &weight_value_tensor = graph.allTensors.at(node->inputIndex.at(6));
auto &weight_output_tensor = graph.allTensors.at(node->inputIndex.at(10));
auto &weight_query_tensor = graph.allTensors.at(node->inputIndex.at(kWeightQueryIndex));
auto &weight_key_tensor = graph.allTensors.at(node->inputIndex.at(kWeightKeyIndex));
auto &weight_value_tensor = graph.allTensors.at(node->inputIndex.at(kWeightValueIndex));
auto &weight_output_tensor = graph.allTensors.at(node->inputIndex.at(kWeightOutputIndex));
if (!quant::TensorQuantParamsInited(*input_tensor) && quant::TensorQuantParamsInited(*weight_query_tensor) &&
quant::TensorQuantParamsInited(*weight_key_tensor) && quant::TensorQuantParamsInited(*weight_value_tensor) &&
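
The check relies on the Attention node's fixed input layout, which the new constants document: query/key/value weights at inputs 4, 5 and 6 and the output projection weight at input 10; weight-only quantization is chosen when those weights carry quant params while the activation input does not. A hedged sketch of the weight-side test over a plain index list (helper name and simplified signature are assumptions):

    #include <cstddef>
    #include <initializer_list>
    #include <vector>

    constexpr size_t kWeightQueryIndex = 4;
    constexpr size_t kWeightKeyIndex = 5;
    constexpr size_t kWeightValueIndex = 6;
    constexpr size_t kWeightOutputIndex = 10;

    // Hypothetical helper: true when every named weight slot exists and its
    // tensor already has initialized quant params (checked via `inited`).
    bool AllAttentionWeightsQuantInited(const std::vector<size_t> &input_index,
                                        bool (*inited)(size_t tensor_idx)) {
      for (size_t slot : {kWeightQueryIndex, kWeightKeyIndex, kWeightValueIndex, kWeightOutputIndex}) {
        if (slot >= input_index.size() || !inited(input_index.at(slot))) {
          return false;  // missing slot or un-quantized weight
        }
      }
      return true;
    }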

View File

@@ -58,5 +58,4 @@ STATUS ConvQuantParamPropogator::PropogateQuantParams(mindspore::schema::MetaGra
}
return RET_OK;
}
} // namespace mindspore::lite

View File

@@ -16,7 +16,6 @@
#include "tools/converter/quantizer/quant_helper/default_quant_all_quant_type_determiner.h"
namespace mindspore::lite {
bool DefaultQuantAllQuantTypeDeterminer::DetermineQuantAll(const schema::MetaGraphT &graph, schema::CNodeT *node) {
return true;
}

View File

@@ -16,7 +16,6 @@
#include "tools/converter/quantizer/quant_helper/only_need_inputs_quant_type_determiner.h"
namespace mindspore::lite {
bool OnlyNeedInputsQuantTypeDeterminer::DetermineQuantAll(const schema::MetaGraphT &graph, schema::CNodeT *node) {
UpdateQuantParamsNum(graph, *node);
if (input_inited_quant_params_ == node->inputIndex.size()) {

View File

@@ -142,5 +142,4 @@ QuantHelperRegister::~QuantHelperRegister() {
}
this->register_map_.clear();
}
} // namespace mindspore::lite

View File

@@ -148,7 +148,6 @@ std::shared_ptr<ops::MatMul> BuildMatMulPrim(const CNodePtr &stack_cnode) {
matmul_cvalue->AddAttr("quant_params", quant_params_holder);
return matmul_cvalue;
}
} // namespace
const BaseRef BatchMatMulFusion::DefinePattern() const {
auto pack_var = std::make_shared<CondVar>(IsStackNode);

View File

@@ -21,6 +21,7 @@
namespace mindspore::opt {
namespace {
const auto &p1 = std::placeholders::_1;
const size_t kWeightShapeSize = 2;
} // namespace
MultiHeadAttentionFusion::MultiHeadAttentionFusion(const string &name, bool multigraph)
@@ -244,7 +245,8 @@ std::shared_ptr<ops::Attention> MultiHeadAttentionFusion::BuildAttentionPrim(con
MS_LOG(ERROR) << "Get reshape k data failed";
return nullptr;
}
if (shape_k.size() < 2 || shape_v.size() < 2 || shape_k.at(shape_k.size() - 2) != shape_v.at(shape_v.size() - 2)) {
if (shape_k.size() < kWeightShapeSize || shape_v.size() < kWeightShapeSize ||
shape_k.at(shape_k.size() - kWeightShapeSize) != shape_v.at(shape_v.size() - kWeightShapeSize)) {
MS_LOG(ERROR) << "Shape k or shape v is invalid.";
return nullptr;
}
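
The guard now spells out the assumption: both weight shapes must have rank at least 2, and K and V must agree on their second-to-last dimension. A minimal standalone version of the same check over plain shape vectors:

    #include <cstdint>
    #include <vector>

    constexpr size_t kWeightShapeSize = 2;

    // True when both shapes have rank >= 2 and share the second-to-last dim.
    bool KvShapesCompatible(const std::vector<int64_t> &shape_k, const std::vector<int64_t> &shape_v) {
      if (shape_k.size() < kWeightShapeSize || shape_v.size() < kWeightShapeSize) {
        return false;
      }
      return shape_k.at(shape_k.size() - kWeightShapeSize) ==
             shape_v.at(shape_v.size() - kWeightShapeSize);
    }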

View File

@@ -23,6 +23,14 @@
namespace mindspore::opt {
namespace {
const auto &p1 = std::placeholders::_1;
const size_t kWeightQueryIndex = 4;
const size_t kWeightKeyIndex = 5;
const size_t kWeightValueIndex = 6;
const size_t kWeightPosIndex = 7;
const size_t kWeightOutputIndex = 10;
const size_t kStackParamSize = 2;
const size_t kInputSize = 16;
const size_t kOutputSize = 2;
} // namespace
TfliteRelPosMultiHeadAttentionFusion::TfliteRelPosMultiHeadAttentionFusion(const string &name, bool multigraph)
@@ -37,7 +45,7 @@ TfliteRelPosMultiHeadAttentionFusion::TfliteRelPosMultiHeadAttentionFusion(const
output_prim_ = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimFullConnection));
pos_prim_ = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimFullConnection));
for (size_t i = 0; i < 2; i++) {
for (size_t i = 0; i < kStackParamSize; i++) {
query_stack_params_.emplace_back(std::make_shared<Var>());
key_stack_params_.emplace_back(std::make_shared<Var>());
value_stack_params_.emplace_back(std::make_shared<Var>());
@@ -157,38 +165,38 @@ CNodePtr TfliteRelPosMultiHeadAttentionFusion::CreateRelPosMultiHeadAttentionNod
MS_LOG(ERROR) << "Build attention primitive failed.";
return nullptr;
}
auto quant_params_holder = std::make_shared<lite::QuantParamHolder>(16, 1);
auto quant_params_holder = std::make_shared<lite::QuantParamHolder>(kInputSize, kOutputSize);
auto query_prim = GetValueNode<PrimitivePtr>(utils::cast<AnfNodePtr>((*equiv)[query_prim_]));
auto query_quant_param_holder = query_prim->GetAttr("quant_params");
if (query_quant_param_holder != nullptr) {
quant_params_holder->set_input_quant_param(
4, query_quant_param_holder->cast<lite::QuantParamHolderPtr>()->get_input_quant_params().at(1));
kWeightQueryIndex, query_quant_param_holder->cast<lite::QuantParamHolderPtr>()->get_input_quant_params().at(1));
}
auto key_prim = GetValueNode<PrimitivePtr>(utils::cast<AnfNodePtr>((*equiv)[key_prim_]));
auto key_quant_param_holder = key_prim->GetAttr("quant_params");
if (key_quant_param_holder != nullptr) {
quant_params_holder->set_input_quant_param(
5, key_quant_param_holder->cast<lite::QuantParamHolderPtr>()->get_input_quant_params().at(1));
kWeightKeyIndex, key_quant_param_holder->cast<lite::QuantParamHolderPtr>()->get_input_quant_params().at(1));
}
auto value_prim = GetValueNode<PrimitivePtr>(utils::cast<AnfNodePtr>((*equiv)[value_prim_]));
auto value_quant_param_holder = value_prim->GetAttr("quant_params");
if (value_quant_param_holder != nullptr) {
quant_params_holder->set_input_quant_param(
6, value_quant_param_holder->cast<lite::QuantParamHolderPtr>()->get_input_quant_params().at(1));
kWeightValueIndex, value_quant_param_holder->cast<lite::QuantParamHolderPtr>()->get_input_quant_params().at(1));
}
auto pos_prim = GetValueNode<PrimitivePtr>(utils::cast<AnfNodePtr>((*equiv)[pos_prim_]));
auto pos_quant_param_holder = pos_prim->GetAttr("quant_params");
if (pos_quant_param_holder != nullptr) {
quant_params_holder->set_input_quant_param(
7, pos_quant_param_holder->cast<lite::QuantParamHolderPtr>()->get_input_quant_params().at(1));
kWeightPosIndex, pos_quant_param_holder->cast<lite::QuantParamHolderPtr>()->get_input_quant_params().at(1));
}
auto output_prim = GetValueNode<PrimitivePtr>(utils::cast<AnfNodePtr>((*equiv)[output_prim_]));
auto output_quant_param_holder = output_prim->GetAttr("quant_params");
if (output_quant_param_holder != nullptr) {
quant_params_holder->set_input_quant_param(
10, output_quant_param_holder->cast<lite::QuantParamHolderPtr>()->get_input_quant_params().at(1));
kWeightOutputIndex, output_quant_param_holder->cast<lite::QuantParamHolderPtr>()->get_input_quant_params().at(1));
}
attention_prim->AddAttr("quant_params", quant_params_holder);
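
The fused node gets a QuantParamHolder sized for its 16 inputs and 2 outputs, and each matched FullConnection's weight quant params (taken from its input 1) are copied into the named weight slot of the fusion. A hedged sketch of that copy step with simplified container types (the names here are illustrative, not the lite:: API):

    #include <cstddef>
    #include <map>
    #include <vector>

    struct QuantParamSketch { double scale; int zero_point; };
    // slot index -> quant params for that input of the fused Attention node
    using QuantSlots = std::map<size_t, std::vector<QuantParamSketch>>;

    constexpr size_t kFcWeightInput = 1;  // FullConnection: input 0 is data, input 1 is the weight

    // Copy the weight-side quant params of a matched FullConnection into the
    // fused node's slot (e.g. kWeightQueryIndex = 4), if the source had any.
    void CopyWeightQuant(QuantSlots *fused, size_t fused_slot, const QuantSlots *source) {
      if (fused == nullptr || source == nullptr) {
        return;  // matched primitive carried no quant_params attribute
      }
      auto it = source->find(kFcWeightInput);
      if (it != source->end()) {
        (*fused)[fused_slot] = it->second;
      }
    }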
@@ -273,7 +281,7 @@ const VectorRef TfliteRelPosMultiHeadAttentionFusion::DefineProcessInputPattern(
result = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimAddFusion)), result, bias});
}
MS_ASSERT(stack_params.size() == 2);
MS_ASSERT(stack_params.size() == kStackParamSize);
auto stack = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimStack)), std::make_shared<Var>(),
std::make_shared<Var>(), stack_params.at(0), stack_params.at(1)});
result = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape)), result, stack});