forked from mindspore-Ecosystem/mindspore
[MS][LITE] clean codex
This commit is contained in:
parent
93af13f332
commit
98abc333de
|
@ -158,5 +158,4 @@ void SaveDataToNet(const std::map<std::string, Tensor *> &saved_weights, const s
|
|||
}
|
||||
net.close();
|
||||
}
|
||||
|
||||
} // namespace mindspore::lite::micro
|
||||
|
|
|
@ -147,15 +147,16 @@ void Conv2DInt8Coder::CheckSupportOptimize() {
|
|||
}
|
||||
|
||||
int Conv2DInt8Coder::InitTmpBuffer() {
|
||||
const size_t kPartial = 2;
|
||||
switch (opt_) {
|
||||
case Basic:
|
||||
buffer_size_ =
|
||||
static_cast<size_t>(2 * input_tensor_->Channel() * filter_tensor_->Width() * filter_tensor_->Height()) *
|
||||
static_cast<size_t>(kPartial * input_tensor_->Channel() * filter_tensor_->Width() * filter_tensor_->Height()) *
|
||||
sizeof(int16_t);
|
||||
break;
|
||||
case Convolve_1_x_n:
|
||||
buffer_size_ =
|
||||
static_cast<size_t>(2 * input_tensor_->Channel() * filter_tensor_->Width() * filter_tensor_->Height()) *
|
||||
static_cast<size_t>(kPartial * input_tensor_->Channel() * filter_tensor_->Width() * filter_tensor_->Height()) *
|
||||
sizeof(int16_t);
|
||||
break;
|
||||
case Convolve_1x1_fast:
|
||||
|
|
|
@ -83,5 +83,4 @@ int ConvolutionDepthwiseFP32Coder::DoCode(CoderContext *const context) {
|
|||
context->AppendCode(code.str());
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
} // namespace mindspore::lite::micro::nnacl
|
||||
|
|
|
@ -30,6 +30,7 @@ using mindspore::lite::RET_OK;
|
|||
using mindspore::schema::PrimitiveType_GLU;
|
||||
|
||||
namespace mindspore::kernel {
|
||||
const int kGluBranchNum = 2;
|
||||
int GluCPUKernel::MallocTmpBuffer() {
|
||||
FreeTmpBuffer();
|
||||
auto in_tensor = in_tensors_.front();
|
||||
|
@ -115,7 +116,7 @@ int GluCPUKernel::Split(int task_id) {
|
|||
int GluCPUKernel::Sigmoid(int task_id) {
|
||||
auto input_addr = reinterpret_cast<float *>(split_ptr_.at(1));
|
||||
auto output_addr = reinterpret_cast<float *>(sigmoid_ptr_);
|
||||
auto length = in_tensors_.at(0)->ElementsNum() / 2;
|
||||
auto length = in_tensors_.at(0)->ElementsNum() / kGluBranchNum;
|
||||
|
||||
int stride = UP_DIV(length, op_parameter_->thread_num_);
|
||||
int count = MSMIN(stride, length - stride * task_id);
|
||||
|
@ -129,7 +130,7 @@ int GluCPUKernel::Mul(int task_id) {
|
|||
auto input_addr0 = reinterpret_cast<float *>(split_ptr_.at(0));
|
||||
auto input_addr1 = reinterpret_cast<float *>(sigmoid_ptr_);
|
||||
auto output_addr = reinterpret_cast<float *>(out_tensors_.at(0)->data_c());
|
||||
auto length = in_tensors_.at(0)->ElementsNum() / 2;
|
||||
auto length = in_tensors_.at(0)->ElementsNum() / kGluBranchNum;
|
||||
|
||||
int stride = UP_DIV(length, op_parameter_->thread_num_);
|
||||
int count = MSMIN(stride, length - stride * task_id);
|
||||
|
|
|
@ -463,6 +463,5 @@ size_t GetCNodeOutputsSize(const std::shared_ptr<AnfNode> &anf_node, bool train_
|
|||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -19,14 +19,19 @@
|
|||
#include "mindspore/core/utils/log_adapter.h"
|
||||
#include "mindspore/core/ir/dtype/type_id.h"
|
||||
namespace mindspore::lite {
|
||||
const size_t kWeightQueryIndex = 4;
|
||||
const size_t kWeightKeyIndex = 5;
|
||||
const size_t kWeightValueIndex = 6;
|
||||
const size_t kWeightOutputIndex = 10;
|
||||
|
||||
bool AttentionQuantTypeDeterminer::DetermineQuantWeight(const mindspore::schema::MetaGraphT &graph,
|
||||
mindspore::schema::CNodeT *node) {
|
||||
MS_ASSERT(node->inputIndex.size() >= 2);
|
||||
auto &input_tensor = graph.allTensors.at(node->inputIndex.at(kInputIndex));
|
||||
auto &weight_query_tensor = graph.allTensors.at(node->inputIndex.at(4));
|
||||
auto &weight_key_tensor = graph.allTensors.at(node->inputIndex.at(5));
|
||||
auto &weight_value_tensor = graph.allTensors.at(node->inputIndex.at(6));
|
||||
auto &weight_output_tensor = graph.allTensors.at(node->inputIndex.at(10));
|
||||
auto &weight_query_tensor = graph.allTensors.at(node->inputIndex.at(kWeightQueryIndex));
|
||||
auto &weight_key_tensor = graph.allTensors.at(node->inputIndex.at(kWeightKeyIndex));
|
||||
auto &weight_value_tensor = graph.allTensors.at(node->inputIndex.at(kWeightValueIndex));
|
||||
auto &weight_output_tensor = graph.allTensors.at(node->inputIndex.at(kWeightOutputIndex));
|
||||
|
||||
if (!quant::TensorQuantParamsInited(*input_tensor) && quant::TensorQuantParamsInited(*weight_query_tensor) &&
|
||||
quant::TensorQuantParamsInited(*weight_key_tensor) && quant::TensorQuantParamsInited(*weight_value_tensor) &&
|
||||
|
|
|
@ -58,5 +58,4 @@ STATUS ConvQuantParamPropogator::PropogateQuantParams(mindspore::schema::MetaGra
|
|||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
} // namespace mindspore::lite
|
||||
|
|
|
@ -16,7 +16,6 @@
|
|||
#include "tools/converter/quantizer/quant_helper/default_quant_all_quant_type_determiner.h"
|
||||
|
||||
namespace mindspore::lite {
|
||||
|
||||
bool DefaultQuantAllQuantTypeDeterminer::DetermineQuantAll(const schema::MetaGraphT &graph, schema::CNodeT *node) {
|
||||
return true;
|
||||
}
|
||||
|
|
|
@ -16,7 +16,6 @@
|
|||
#include "tools/converter/quantizer/quant_helper/only_need_inputs_quant_type_determiner.h"
|
||||
|
||||
namespace mindspore::lite {
|
||||
|
||||
bool OnlyNeedInputsQuantTypeDeterminer::DetermineQuantAll(const schema::MetaGraphT &graph, schema::CNodeT *node) {
|
||||
UpdateQuantParamsNum(graph, *node);
|
||||
if (input_inited_quant_params_ == node->inputIndex.size()) {
|
||||
|
|
|
@ -142,5 +142,4 @@ QuantHelperRegister::~QuantHelperRegister() {
|
|||
}
|
||||
this->register_map_.clear();
|
||||
}
|
||||
|
||||
} // namespace mindspore::lite
|
||||
|
|
|
@ -148,7 +148,6 @@ std::shared_ptr<ops::MatMul> BuildMatMulPrim(const CNodePtr &stack_cnode) {
|
|||
matmul_cvalue->AddAttr("quant_params", quant_params_holder);
|
||||
return matmul_cvalue;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
const BaseRef BatchMatMulFusion::DefinePattern() const {
|
||||
auto pack_var = std::make_shared<CondVar>(IsStackNode);
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
namespace mindspore::opt {
|
||||
namespace {
|
||||
const auto &p1 = std::placeholders::_1;
|
||||
const size_t kWeightShapeSize = 2;
|
||||
} // namespace
|
||||
|
||||
MultiHeadAttentionFusion::MultiHeadAttentionFusion(const string &name, bool multigraph)
|
||||
|
@ -244,7 +245,8 @@ std::shared_ptr<ops::Attention> MultiHeadAttentionFusion::BuildAttentionPrim(con
|
|||
MS_LOG(ERROR) << "Get reshape k data failed";
|
||||
return nullptr;
|
||||
}
|
||||
if (shape_k.size() < 2 || shape_v.size() < 2 || shape_k.at(shape_k.size() - 2) != shape_v.at(shape_v.size() - 2)) {
|
||||
if (shape_k.size() < kWeightShapeSize || shape_v.size() < kWeightShapeSize ||
|
||||
shape_k.at(shape_k.size() - kWeightShapeSize) != shape_v.at(shape_v.size() - kWeightShapeSize)) {
|
||||
MS_LOG(ERROR) << "Shape k or shape v is invalid.";
|
||||
return nullptr;
|
||||
}
|
||||
|
|
|
@ -23,6 +23,14 @@
|
|||
namespace mindspore::opt {
|
||||
namespace {
|
||||
const auto &p1 = std::placeholders::_1;
|
||||
const size_t kWeightQueryIndex = 4;
|
||||
const size_t kWeightKeyIndex = 5;
|
||||
const size_t kWeightValueIndex = 6;
|
||||
const size_t kWeightPosIndex = 7;
|
||||
const size_t kWeightOutputIndex = 10;
|
||||
const size_t kStackParamSize = 2;
|
||||
const size_t kInputSize = 16;
|
||||
const size_t kOutputSize = 2;
|
||||
} // namespace
|
||||
|
||||
TfliteRelPosMultiHeadAttentionFusion::TfliteRelPosMultiHeadAttentionFusion(const string &name, bool multigraph)
|
||||
|
@ -37,7 +45,7 @@ TfliteRelPosMultiHeadAttentionFusion::TfliteRelPosMultiHeadAttentionFusion(const
|
|||
output_prim_ = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimFullConnection));
|
||||
pos_prim_ = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimFullConnection));
|
||||
|
||||
for (size_t i = 0; i < 2; i++) {
|
||||
for (size_t i = 0; i < kStackParamSize; i++) {
|
||||
query_stack_params_.emplace_back(std::make_shared<Var>());
|
||||
key_stack_params_.emplace_back(std::make_shared<Var>());
|
||||
value_stack_params_.emplace_back(std::make_shared<Var>());
|
||||
|
@ -157,38 +165,38 @@ CNodePtr TfliteRelPosMultiHeadAttentionFusion::CreateRelPosMultiHeadAttentionNod
|
|||
MS_LOG(ERROR) << "Build attention primitive failed.";
|
||||
return nullptr;
|
||||
}
|
||||
auto quant_params_holder = std::make_shared<lite::QuantParamHolder>(16, 1);
|
||||
auto quant_params_holder = std::make_shared<lite::QuantParamHolder>(kInputSize, kOutputSize);
|
||||
auto query_prim = GetValueNode<PrimitivePtr>(utils::cast<AnfNodePtr>((*equiv)[query_prim_]));
|
||||
auto query_quant_param_holder = query_prim->GetAttr("quant_params");
|
||||
if (query_quant_param_holder != nullptr) {
|
||||
quant_params_holder->set_input_quant_param(
|
||||
4, query_quant_param_holder->cast<lite::QuantParamHolderPtr>()->get_input_quant_params().at(1));
|
||||
kWeightQueryIndex, query_quant_param_holder->cast<lite::QuantParamHolderPtr>()->get_input_quant_params().at(1));
|
||||
}
|
||||
auto key_prim = GetValueNode<PrimitivePtr>(utils::cast<AnfNodePtr>((*equiv)[key_prim_]));
|
||||
auto key_quant_param_holder = key_prim->GetAttr("quant_params");
|
||||
if (key_quant_param_holder != nullptr) {
|
||||
quant_params_holder->set_input_quant_param(
|
||||
5, key_quant_param_holder->cast<lite::QuantParamHolderPtr>()->get_input_quant_params().at(1));
|
||||
kWeightKeyIndex, key_quant_param_holder->cast<lite::QuantParamHolderPtr>()->get_input_quant_params().at(1));
|
||||
}
|
||||
auto value_prim = GetValueNode<PrimitivePtr>(utils::cast<AnfNodePtr>((*equiv)[value_prim_]));
|
||||
auto value_quant_param_holder = value_prim->GetAttr("quant_params");
|
||||
if (value_quant_param_holder != nullptr) {
|
||||
quant_params_holder->set_input_quant_param(
|
||||
6, value_quant_param_holder->cast<lite::QuantParamHolderPtr>()->get_input_quant_params().at(1));
|
||||
kWeightValueIndex, value_quant_param_holder->cast<lite::QuantParamHolderPtr>()->get_input_quant_params().at(1));
|
||||
}
|
||||
|
||||
auto pos_prim = GetValueNode<PrimitivePtr>(utils::cast<AnfNodePtr>((*equiv)[pos_prim_]));
|
||||
auto pos_quant_param_holder = pos_prim->GetAttr("quant_params");
|
||||
if (pos_quant_param_holder != nullptr) {
|
||||
quant_params_holder->set_input_quant_param(
|
||||
7, pos_quant_param_holder->cast<lite::QuantParamHolderPtr>()->get_input_quant_params().at(1));
|
||||
kWeightPosIndex, pos_quant_param_holder->cast<lite::QuantParamHolderPtr>()->get_input_quant_params().at(1));
|
||||
}
|
||||
|
||||
auto output_prim = GetValueNode<PrimitivePtr>(utils::cast<AnfNodePtr>((*equiv)[output_prim_]));
|
||||
auto output_quant_param_holder = output_prim->GetAttr("quant_params");
|
||||
if (output_quant_param_holder != nullptr) {
|
||||
quant_params_holder->set_input_quant_param(
|
||||
10, output_quant_param_holder->cast<lite::QuantParamHolderPtr>()->get_input_quant_params().at(1));
|
||||
kWeightOutputIndex, output_quant_param_holder->cast<lite::QuantParamHolderPtr>()->get_input_quant_params().at(1));
|
||||
}
|
||||
|
||||
attention_prim->AddAttr("quant_params", quant_params_holder);
|
||||
|
@ -273,7 +281,7 @@ const VectorRef TfliteRelPosMultiHeadAttentionFusion::DefineProcessInputPattern(
|
|||
result = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimAddFusion)), result, bias});
|
||||
}
|
||||
|
||||
MS_ASSERT(stack_params.size() == 2);
|
||||
MS_ASSERT(stack_params.size() == kStackParamSize);
|
||||
auto stack = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimStack)), std::make_shared<Var>(),
|
||||
std::make_shared<Var>(), stack_params.at(0), stack_params.at(1)});
|
||||
result = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape)), result, stack});
|
||||
|
|
Loading…
Reference in New Issue