forked from mindspore-Ecosystem/mindspore
!7165 [MSLITE][Develop] optimize arm fp32 cpu op lstm: add neon
Merge pull request !7165 from yangruoqi713/lite
This commit is contained in:
commit
478200d2fe
|
@ -37,16 +37,44 @@ void MatMulAcc(float *output, const float *input, const float *weight, int rows,
|
|||
for (int r = 0; r < rows; r++) {
|
||||
for (int c = 0; c < cols; c++) {
|
||||
float res = 0;
|
||||
for (int i = 0; i < inner_size; i++) {
|
||||
res += input[r * inner_size + i] * weight[c * inner_size + i];
|
||||
const float *input_col = input + r * inner_size;
|
||||
const float *weight_col = weight + c * inner_size;
|
||||
int index = 0;
|
||||
#ifdef ENABLE_ARM
|
||||
float32x4_t out = vdupq_n_f32(0.0f);
|
||||
for (; index < inner_size - 4; index += 4) {
|
||||
float32x4_t in_0 = vld1q_f32(input_col + index);
|
||||
float32x4_t in_1 = vld1q_f32(weight_col + index);
|
||||
out = vmlaq_f32(out, in_1, in_0);
|
||||
}
|
||||
#ifdef ENABLE_ARM64
|
||||
res += vaddvq_f32(out);
|
||||
#else
|
||||
float32x2_t add2 = vadd_f32(vget_low_f32(out), vget_high_f32(out));
|
||||
float32x2_t add4 = vpadd_f32(add2, add2);
|
||||
res += vget_lane_f32(add4, 0);
|
||||
#endif
|
||||
#endif
|
||||
for (; index < inner_size; index++) {
|
||||
res += input_col[index] * weight_col[index];
|
||||
}
|
||||
output[r * cols + c] += res;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ElementMulAcc(float *input0, float *input1, float *output, int element_size) {
|
||||
for (int index = 0; index < element_size; index++) {
|
||||
// Element-wise multiply-accumulate: output[i] += input0[i] * input1[i]
// for every i in [0, element_size).
//
// input0, input1: source arrays, each read at indices [0, element_size).
// output:         accumulator array, read and written at the same indices.
// element_size:   number of floats to process; values <= 0 leave output untouched.
void ElementMulAcc(const float *input0, const float *input1, float *output, int element_size) {
  int index = 0;
#ifdef ENABLE_ARM
  // NEON path: handle four floats per iteration. The bound is inclusive
  // (<=) so that the last full group of four is also vectorized; a strict
  // `<` would always push that group into the scalar tail below.
  for (; index <= element_size - 4; index += 4) {
    float32x4_t in_0 = vld1q_f32(input0 + index);
    float32x4_t in_1 = vld1q_f32(input1 + index);
    float32x4_t out = vld1q_f32(output + index);
    out = vmlaq_f32(out, in_1, in_0);
    vst1q_f32(output + index, out);
  }
#endif
  // Scalar tail: the remaining (fewer than four) elements on ARM builds,
  // or every element when ENABLE_ARM is not defined.
  for (; index < element_size; index++) {
    output[index] += input0[index] * input1[index];
  }
}
|
||||
|
|
|
@ -67,7 +67,11 @@ int ConstantOfShape::InferShape(std::vector<Tensor *> inputs_, std::vector<Tenso
|
|||
if (!GetInferFlag()) {
|
||||
return RET_OK;
|
||||
}
|
||||
auto in_data = reinterpret_cast<int *>(in_tensor->MutableData());
|
||||
auto in_data = reinterpret_cast<int *>(in_tensor->data_c());
|
||||
if (in_data == nullptr) {
|
||||
MS_LOG(ERROR) << "Input data is nullptr";
|
||||
return RET_INFER_INVALID;
|
||||
}
|
||||
int size = in_tensor->ElementsNum();
|
||||
std::vector<int> out_shape(size);
|
||||
for (int i = 0; i < size; ++i) {
|
||||
|
|
|
@ -51,7 +51,7 @@ int Lstm::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> output
|
|||
}
|
||||
auto input = inputs_.front();
|
||||
MS_ASSERT(input != nullptr);
|
||||
auto weight_i = inputs_.front();
|
||||
auto weight_i = inputs_[1];
|
||||
MS_ASSERT(input0 != nullptr);
|
||||
auto output = outputs_.front();
|
||||
MS_ASSERT(output != nullptr);
|
||||
|
|
Loading…
Reference in New Issue