llvm-project/mlir/test/mlir-cpu-runner/cblas_interface.cpp

64 lines
2.4 KiB
C++

//===- cblas_interface.cpp - Simple Blas subset interface -----------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// Simple Blas subset interface implementation.
//
//===----------------------------------------------------------------------===//
#include "include/cblas.h"
#include <assert.h>
/// Strided N-dimensional view descriptor, received by pointer in the
/// extern "C" entry points of this file.
///
/// NOTE(review): the field order and types presumably mirror the descriptor
/// emitted by the compiler for an N-D view. Treat the layout as ABI: do not
/// reorder or retype fields without updating the producer (TODO confirm).
///
/// Element addressing, as used by the functions in this file: `data + offset`
/// points at the view's first element, and `strides[d]` is the distance, in
/// elements, between consecutive indices along dimension d.
template <typename T, int N> struct ViewType {
T *data;                  // Base pointer of the underlying buffer.
unsigned long offset;     // Element offset from `data` to the first element.
unsigned long sizes[N];   // Extent of the view along each dimension.
unsigned long strides[N]; // Per-dimension stride, counted in elements.
};
// Rank-0 (scalar) specialization. It is separated out to avoid declaring
// `unsigned long sizes[0]`, which triggers:
//   warning: ISO C++ forbids zero-size array [-Wpedantic]
// A rank-0 view has no sizes or strides; its single element lives at
// `data + offset`.
template <typename T> struct ViewType<T, 0> {
T *data;
unsigned long offset;
};
/// BLAS-backed dot product of two strided 1-D float views, accumulated into
/// a rank-0 view: Z[] += dot(X, Y).
///
/// Element i of X is at X->data[X->offset + i * X->strides[0]], and likewise
/// for Y. Both views must have the same length; this is only checked by
/// assert, i.e. in builds without NDEBUG.
extern "C" void linalg_dot_impl(ViewType<float, 1> *X, ViewType<float, 1> *Y,
                                ViewType<float, 0> *Z) {
  assert(X->sizes[0] == Y->sizes[0] && "Expected X and Y of same size");
  float *xFirst = X->data + X->offset;
  float *yFirst = Y->data + Y->offset;
  // cblas_sdot walks each operand using its own element stride.
  const auto product =
      cblas_sdot(X->sizes[0], xFirst, X->strides[0], yFirst, Y->strides[0]);
  // Accumulate rather than overwrite: the scalar output keeps its prior value.
  Z->data[Z->offset] += product;
}
/// BLAS-backed matrix product of strided 2-D float views, accumulated into
/// the output: C += A * B, with A of shape MxK, B of shape KxN and C of
/// shape MxN.
///
/// All three views must have unit stride along their innermost dimension
/// (strides[1] == 1). Each one is then passed to BLAS as a row-major matrix
/// whose leading dimension is strides[0]. Shape and stride requirements are
/// only checked by assert, i.e. in builds without NDEBUG.
extern "C" void linalg_matmul_impl(ViewType<float, 2> *A, ViewType<float, 2> *B,
                                   ViewType<float, 2> *C) {
  // Row-major BLAS needs unit stride along the column dimension.
  assert(A->strides[1] == B->strides[1]);
  assert(A->strides[1] == C->strides[1]);
  assert(A->strides[1] == 1);
  // With strides[1] == 1 these reduce to "at least one row".
  assert(A->sizes[0] >= A->strides[1]);
  assert(B->sizes[0] >= B->strides[1]);
  assert(C->sizes[0] >= C->strides[1]);
  // A row-major leading dimension must span a full row (BLAS: ld >= #cols);
  // otherwise BLAS rejects the call as an illegal parameter.
  assert(A->strides[0] >= A->sizes[1]);
  assert(B->strides[0] >= B->sizes[1]);
  assert(C->strides[0] >= C->sizes[1]);
  // Shape compatibility: (M x K) * (K x N) -> (M x N).
  assert(C->sizes[0] == A->sizes[0]);
  assert(C->sizes[1] == B->sizes[1]);
  assert(A->sizes[1] == B->sizes[0]);
  // sgemm computes C = alpha * A * B + beta * C. Using alpha = 1 and beta = 1
  // yields the accumulating C += A * B, matching the `+=` in linalg_dot_impl.
  // (With alpha = 0 the product would be discarded and C left unchanged.)
  cblas_sgemm(CBLAS_ORDER::CblasRowMajor, CBLAS_TRANSPOSE::CblasNoTrans,
              CBLAS_TRANSPOSE::CblasNoTrans, C->sizes[0], C->sizes[1],
              A->sizes[1], /*alpha=*/1.0f, A->data + A->offset, A->strides[0],
              B->data + B->offset, B->strides[0], /*beta=*/1.0f,
              C->data + C->offset, C->strides[0]);
}