forked from mindspore-Ecosystem/mindspore
clean up magic numbers and unsuitable input arguments
This commit is contained in:
parent 9720bab9c9
commit 4dd4d3989a
@@ -195,6 +195,9 @@ int PackAttentionBias(Matrix *matrix, int tile) {
   int size = matrix->col_;
   float *src = matrix->data_;
   int size_align = UP_ROUND(size, tile);
+  if (size_align <= 0) {
+    return NNACL_ERR;
+  }
   matrix->packed_data_ = (float *)malloc(size_align * sizeof(float));
   if (matrix->packed_data_ == NULL) {
     return NNACL_NULL_PTR;
@@ -287,21 +290,28 @@ static void ElementOptAddDiv(const float *input0, const float *input1, const flo
   }
 }
 
-static void GetTransposeParameter(TransposeParameter *param, const int in_shape[], const int out_shape[],
-                                  const int perm[]) {
-  param->num_axes_ = 4;
+static bool GetTransposeParameter(TransposeParameter *param, const int in_shape[], int in_shape_len,
+                                  const int out_shape[], int out_shape_len, const int perm[], int perm_len) {
+  param->num_axes_ = perm_len;
   size_t shape_size = 1;
-  for (int i = 0; i < 4; i++) {
+  for (int i = 0; i < perm_len; i++) {
     param->perm_[i] = perm[i];
     shape_size *= perm[i];  // check overflow
   }
   param->data_num_ = (int)shape_size;  // check overflow
   param->strides_[param->num_axes_ - 1] = 1;
   param->out_strides_[param->num_axes_ - 1] = 1;
+  if (param->num_axes_ - 1 >= in_shape_len) {
+    return false;
+  }
+  if (param->num_axes_ - 1 >= out_shape_len) {
+    return false;
+  }
   for (int i = param->num_axes_ - 2; i >= 0; i--) {
     param->strides_[i] = in_shape[i + 1] * param->strides_[i + 1];
     param->out_strides_[i] = out_shape[i + 1] * param->out_strides_[i + 1];
   }
+  return true;
 }
 
 void QWithPosition(RelativePositionAttentionParameter *param, Matrix *q_mat, Matrix *wq_mat, Matrix *bq_mat,
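The new signature follows a defensive pattern: array lengths travel with the arrays, and failure is reported through the return value instead of assuming four axes. A minimal, self-contained C sketch of that pattern (ComputeStrides, MAX_AXES and the sample shape are illustrative names, not part of nnacl):

/* Stride computation guarded by an explicit length check, mirroring the
 * bounds test added to GetTransposeParameter above. */
#include <stdbool.h>
#include <stdio.h>

#define MAX_AXES 8

static bool ComputeStrides(const int shape[], int shape_len, int strides[], int num_axes) {
  if (num_axes <= 0 || num_axes > MAX_AXES || num_axes > shape_len) {
    return false; /* reject unsuitable input instead of reading past the array */
  }
  strides[num_axes - 1] = 1;
  for (int i = num_axes - 2; i >= 0; i--) {
    strides[i] = shape[i + 1] * strides[i + 1];
  }
  return true;
}

int main(void) {
  int shape[] = {2, 8, 4, 16};
  int strides[MAX_AXES];
  if (!ComputeStrides(shape, 4, strides, 4)) {
    fprintf(stderr, "invalid transpose shape\n");
    return 1;
  }
  printf("%d %d %d %d\n", strides[0], strides[1], strides[2], strides[3]); /* 512 64 16 1 */
  return 0;
}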
@@ -329,8 +339,8 @@ void QWithPosition(RelativePositionAttentionParameter *param, Matrix *q_mat, Mat
   int q_with_pos_trans_in_shape[] = {batch, param->q_seq_, num_heads, depth};
   int q_with_pos_trans_out_shape[] = {batch, num_heads, param->q_seq_, depth};
   int q_with_pos_perm[] = {0, 2, 1, 3};
-  GetTransposeParameter(&q_with_pos_trans_param, q_with_pos_trans_in_shape, q_with_pos_trans_out_shape,
-                        q_with_pos_perm);
+  (void)GetTransposeParameter(&q_with_pos_trans_param, q_with_pos_trans_in_shape, 4, q_with_pos_trans_out_shape, 4,
+                              q_with_pos_perm, 4);
   int q2wq_reshaped_area = q2wq_mat->row_ * q2wq_mat->col_;
   // Q_WQ + POS_U
   {
@@ -396,7 +406,7 @@ void KMulWeightK(RelativePositionAttentionParameter *param, Matrix *k_mat, Matri
   int k2wk_in_shape[] = {batch, param->k_seq_, num_heads, depth};
   int k2wk_out_shape[] = {batch, num_heads, depth, param->k_seq_};
   int k2wk_perm[] = {0, 2, 3, 1};
-  GetTransposeParameter(&k2wk_trans_param, k2wk_in_shape, k2wk_out_shape, k2wk_perm);
+  (void)GetTransposeParameter(&k2wk_trans_param, k2wk_in_shape, 4, k2wk_out_shape, 4, k2wk_perm, 4);
   TransposeDimsFp32(k2wk, k2wk_trans_data, k2wk_out_shape, &k2wk_trans_param, 0, 1);
 }
 
@@ -427,7 +437,7 @@ void VMulWeightV(RelativePositionAttentionParameter *param, Matrix *v_mat, Matri
   int v2wv_in_shape[] = {batch, param->v_seq_, num_heads, depth};
   int v2wv_out_shape[] = {batch, num_heads, param->v_seq_, depth};
   int v2wv_perm[] = {0, 2, 1, 3};
-  GetTransposeParameter(&v2wv_trans_param, v2wv_in_shape, v2wv_out_shape, v2wv_perm);
+  (void)GetTransposeParameter(&v2wv_trans_param, v2wv_in_shape, 4, v2wv_out_shape, 4, v2wv_perm, 4);
   TransposeDimsFp32(v2wv, v2wv_trans_data, v2wv_out_shape, &v2wv_trans_param, 0, 1);
 }
 
@@ -459,7 +469,7 @@ void PMulWeightP(RelativePositionAttentionParameter *param, Matrix *p_mat, Matri
   int p2wp_in_shape[] = {batch, param->p_seq_, num_heads, depth};
   int p2wp_out_shape[] = {batch, num_heads, depth, param->p_seq_};
   int p2wp_perm[] = {0, 2, 3, 1};
-  GetTransposeParameter(&p2wp_trans_param, p2wp_in_shape, p2wp_out_shape, p2wp_perm);
+  (void)GetTransposeParameter(&p2wp_trans_param, p2wp_in_shape, 4, p2wp_out_shape, 4, p2wp_perm, 4);
   TransposeDimsFp32(p2wp_data, p2wp_trans_data, p2wp_out_shape, &p2wp_trans_param, 0, 1);
 }
 
@@ -534,7 +544,8 @@ void RelPosAttention(RelativePositionAttentionParameter *param, Matrix *logits_m
   int logits2v_trans_in_shape[] = {batch, num_heads, param->q_seq_, depth};
   int logits2v_trans_out_shape[] = {batch, param->q_seq_, num_heads, depth};
   int logits2v_trans_perm[] = {0, 2, 1, 3};
-  GetTransposeParameter(&logits2v_trans_param, logits2v_trans_in_shape, logits2v_trans_out_shape, logits2v_trans_perm);
+  (void)GetTransposeParameter(&logits2v_trans_param, logits2v_trans_in_shape, 4, logits2v_trans_out_shape, 4,
+                              logits2v_trans_perm, 4);
   TransposeDimsFp32(logits2v_data, logits2v_trans_data, logits2v_trans_out_shape, &logits2v_trans_param, 0, 1);
   // concat = reshape [batch, -1, d_model]
   logits2v_trans_mat->batch_ = batch;
@@ -17,7 +17,7 @@
 #ifndef MINDSPORE_NNACL_NNACL_COMMON_H_
 #define MINDSPORE_NNACL_NNACL_COMMON_H_
 
-#include "nnacl/op_base.h"
+#include <stdint.h>
 
 #ifdef __cplusplus
 extern "C" {
@@ -35,6 +35,7 @@ constexpr int kActivationTensorBatch = 1;
 constexpr int kTensorShapeBatchIndex = 0;
 constexpr int k3DimsLeftMatrixDeepIndex = 2;
 constexpr int kRightMatrixDeepIndex = 0;
+constexpr int kRelativePositionHasBiasInputSize = 15;
 
 bool AttentionActivationTensorCheck(lite::Tensor *tensor) {
   if (tensor == nullptr || tensor->data_type() != kNumberTypeFloat32 ||
@@ -161,7 +162,7 @@ bool AttentionBiasTensorCheck(lite::Tensor *tensor) {
 }  // namespace
 
 int RelativePositionAttentionCPUKernel::CheckBiases() {
-  if (this->in_tensors_.size() == 15) {
+  if (this->in_tensors_.size() == kRelativePositionHasBiasInputSize) {
     param_->use_bias_ = true;
   }
   if (!param_->use_bias_) {
@@ -252,6 +253,9 @@ int RelativePositionAttentionCPUKernel::PrepareParam() {
 }
 
 namespace {
+constexpr int kLeftMatrixBatchDimIndex = 0;
+constexpr int kLeftMatrixRowDimIndex = 1;
+constexpr int kLeftMatrixColDimIndex = 2;
 inline int PackLeftTensor(const lite::Tensor &tensor, Matrix *matrix, int row_tile, const AllocatorPtr &allocator) {
   MS_ASSERT(matrix != nullptr);
   MS_ASSERT(allocator != nullptr);
@@ -259,9 +263,9 @@ inline int PackLeftTensor(const lite::Tensor &tensor, Matrix *matrix, int row_ti
   matrix->data_ = reinterpret_cast<float *>(tensor.data_c());
   matrix->is_transpose_ = false;
   // Left tensor is in [batch, row, col] shape
-  matrix->batch_ = tensor.shape().at(0);
-  matrix->row_ = tensor.shape().at(1);
-  matrix->col_ = tensor.shape().at(2);
+  matrix->batch_ = tensor.shape().at(kLeftMatrixBatchDimIndex);
+  matrix->row_ = tensor.shape().at(kLeftMatrixRowDimIndex);
+  matrix->col_ = tensor.shape().at(kLeftMatrixColDimIndex);
   auto size = LeftMatrixPackElementSize(matrix, row_tile) * sizeof(float);
   MS_ASSERT(size != 0);
   matrix->packed_data_ = reinterpret_cast<float *>(allocator->Malloc(size));
@@ -1,8 +1,11 @@
 # add shared link library
+include_directories(${CCSRC_DIR}/backend/kernel_compiler/cpu)
+
 set(COMMON_SRC
         ${CMAKE_CURRENT_SOURCE_DIR}/../common/flag_parser.cc
         ${CMAKE_CURRENT_SOURCE_DIR}/../../src/common/file_utils.cc
         ${CMAKE_CURRENT_SOURCE_DIR}/../../src/common/utils.cc
+        ${CMAKE_CURRENT_SOURCE_DIR}/../../../ccsrc/backend/kernel_compiler/cpu/nnacl/nnacl_common.c
         )
 
 add_executable(benchmark
@@ -18,7 +18,6 @@
 #define __STDC_FORMAT_MACROS
 #include <cinttypes>
 #undef __STDC_FORMAT_MACROS
-#include <algorithm>
 #include <utility>
 #include <functional>
 #include "include/context.h"
@@ -27,6 +26,7 @@
 #include "schema/model_generated.h"
 #include "src/common/common.h"
 #include "src/tensor.h"
+#include "nnacl/nnacl_common.h"
 #ifdef ENABLE_ARM64
 #include <linux/perf_event.h>
 #include <sys/ioctl.h>
@@ -450,6 +450,10 @@ int Benchmark::PrintInputData() {
   return RET_OK;
 }
 
+namespace {
+constexpr float kNumUsPerMs = 1000.;
+}
+
 int Benchmark::RunBenchmark() {
   auto start_prepare_time = GetTimeUs();
   // Load graph
@@ -506,8 +510,9 @@ int Benchmark::RunBenchmark() {
 
   ms_inputs_ = session_->GetInputs();
   auto end_prepare_time = GetTimeUs();
-  MS_LOG(INFO) << "PrepareTime = " << (end_prepare_time - start_prepare_time) / 1000 << " ms";
-  std::cout << "PrepareTime = " << (end_prepare_time - start_prepare_time) / 1000 << " ms" << std::endl;
+  MS_LOG(INFO) << "PrepareTime = " << static_cast<float>(end_prepare_time - start_prepare_time) / kNumUsPerMs << " ms";
+  std::cout << "PrepareTime = " << static_cast<float>(end_prepare_time - start_prepare_time) / kNumUsPerMs << " ms"
+            << std::endl;
 
   // Load input
   MS_LOG(INFO) << "start generate input data";
@@ -580,7 +585,7 @@ int Benchmark::InitTimeProfilingCallbackParameter() {
       MS_LOG(INFO) << "The num of after outputs is empty";
     }
 
-    float cost = static_cast<float>(opEnd - op_begin_) / 1000.0f;
+    float cost = static_cast<float>(opEnd - op_begin_) / kNumUsPerMs;
     if (flags_->device_ == "GPU") {
       auto gpu_param = reinterpret_cast<const GPUCallBackParam &>(call_param);
       cost = static_cast<float>(gpu_param.execute_time);
@@ -725,6 +725,10 @@ int NetTrain::Init() {
   return RET_OK;
 }
 
+namespace {
+constexpr int kNumToPrint = 5;
+}
+
 int NetTrain::PrintResult(const std::vector<std::string> &title,
                           const std::map<std::string, std::pair<int, float>> &result) {
   std::vector<size_t> columnLenMax(kFieldsToPrint);
@@ -773,7 +777,7 @@ int NetTrain::PrintResult(const std::vector<std::string> &title,
   }
 
   printf("-------------------------------------------------------------------------\n");
-  for (int i = 0; i < 5; i++) {
+  for (int i = 0; i < kNumToPrint; i++) {
     auto printBuf = title[i];
     if (printBuf.size() > columnLenMax.at(i)) {
       columnLenMax.at(i) = printBuf.size();
@@ -783,7 +787,7 @@
   }
   printf("\n");
   for (size_t i = 0; i < rows.size(); i++) {
-    for (int j = 0; j < 5; j++) {
+    for (int j = 0; j < kNumToPrint; j++) {
       auto printBuf = rows[i][j];
       printBuf.resize(columnLenMax.at(j), ' ');
       printf("%s\t", printBuf.c_str());
@@ -698,16 +698,21 @@ std::vector<int> GetTransposePerm(MetaGraphT *graph, const std::unique_ptr<CNode
   }
   return perm;
 }
+
+namespace {
+constexpr size_t kBitNumPerByte = 8;
+}
+
 std::string BoolVectorToString(const std::vector<bool> &bool_vec) {
-  size_t size_in_byte = ceil(bool_vec.size() / 8.0);
+  size_t size_in_byte = ceil(bool_vec.size() / kBitNumPerByte);
   std::string str(size_in_byte, '\0');
   auto iter = str.begin();
-  size_t shift = 8;
+  size_t shift = kBitNumPerByte;
   for (bool bit : bool_vec) {
     *iter |= bit << (shift - 1);
     if (--shift == 0) {
       iter++;
-      shift = 8;
+      shift = kBitNumPerByte;
     }
   }
   return str;
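For reference, a standalone C sketch of the MSB-first packing that BoolVectorToString performs: bit i of the input lands in bit (7 - i % 8) of byte i / 8. The sample values and the ceil-style byte count below are illustrative, not the converter's code.

/* Pack an array of bools into bytes, most significant bit first. */
#include <stdbool.h>
#include <stdio.h>

#define BIT_NUM_PER_BYTE 8

int main(void) {
  bool bool_vec[] = {true, false, true, true};
  size_t bit_count = sizeof(bool_vec) / sizeof(bool_vec[0]);
  size_t size_in_byte = (bit_count + BIT_NUM_PER_BYTE - 1) / BIT_NUM_PER_BYTE; /* ceil(bits / 8) */
  unsigned char packed[8] = {0};
  unsigned char *iter = packed;
  size_t shift = BIT_NUM_PER_BYTE;
  for (size_t i = 0; i < bit_count; ++i) {
    *iter |= (unsigned char)(bool_vec[i] << (shift - 1));
    if (--shift == 0) {
      ++iter;
      shift = BIT_NUM_PER_BYTE;
    }
  }
  printf("%zu byte(s), first byte = 0x%02X\n", size_in_byte, packed[0]); /* 1 byte(s), first byte = 0xB0 */
  return 0;
}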