forked from mindspore-Ecosystem/mindspore
!19310 Bert base cluner支持SDK、mxbase推理
Merge pull request !19310 from chenshushu/master
This commit is contained in:
commit
68ddcab992
|
@ -0,0 +1,34 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
ARG FROM_IMAGE_NAME
|
||||
FROM ${FROM_IMAGE_NAME}
|
||||
|
||||
ARG SDK_PKG
|
||||
|
||||
RUN ln -s /usr/local/python3.7.5/bin/python3.7 /usr/bin/python
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install libglib2.0-dev -y || \
|
||||
rm -rf /var/lib/dpkg/info && \
|
||||
mkdir /var/lib/dpkg/info && \
|
||||
apt-get install libglib2.0-dev -y && \
|
||||
pip install pytest-runner==5.3.0
|
||||
|
||||
# pip install sdk_run
|
||||
COPY $SDK_PKG .
|
||||
RUN ls -hrlt
|
||||
RUN chmod +x ${SDK_PKG} && \
|
||||
./${SDK_PKG} --install-path=/home/run --install && \
|
||||
bash -c "source ~/.bashrc"
|
|
@ -0,0 +1,26 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
air_path=$1
|
||||
om_path=$2
|
||||
|
||||
echo "Input AIR file path: ${air_path}"
|
||||
echo "Output OM file path: ${om_path}"
|
||||
|
||||
atc --framework=1 --model="${air_path}" \
|
||||
--output="${om_path}" \
|
||||
--soc_version=Ascend310 \
|
||||
--op_select_implmode="high_precision"
|
|
@ -0,0 +1,46 @@
|
|||
{
|
||||
"im_bertbase": {
|
||||
"stream_config": {
|
||||
"deviceId": "0"
|
||||
},
|
||||
"appsrc0": {
|
||||
"props": {
|
||||
"blocksize": "409600"
|
||||
},
|
||||
"factory": "appsrc",
|
||||
"next": "mxpi_tensorinfer0:0"
|
||||
},
|
||||
"appsrc1": {
|
||||
"props": {
|
||||
"blocksize": "409600"
|
||||
},
|
||||
"factory": "appsrc",
|
||||
"next": "mxpi_tensorinfer0:1"
|
||||
},
|
||||
"appsrc2": {
|
||||
"props": {
|
||||
"blocksize": "409600"
|
||||
},
|
||||
"factory": "appsrc",
|
||||
"next": "mxpi_tensorinfer0:2"
|
||||
},
|
||||
"mxpi_tensorinfer0": {
|
||||
"props": {
|
||||
"dataSource": "appsrc0,appsrc1,appsrc2",
|
||||
"modelPath": "../data/model/cluner.om"
|
||||
},
|
||||
"factory": "mxpi_tensorinfer",
|
||||
"next": "mxpi_dataserialize0"
|
||||
},
|
||||
"mxpi_dataserialize0": {
|
||||
"props": {
|
||||
"outputDataKeys": "mxpi_tensorinfer0"
|
||||
},
|
||||
"factory": "mxpi_dataserialize",
|
||||
"next": "appsink0"
|
||||
},
|
||||
"appsink0": {
|
||||
"factory": "appsink"
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,10 @@
|
|||
address
|
||||
book
|
||||
company
|
||||
game
|
||||
government
|
||||
movie
|
||||
name
|
||||
organization
|
||||
position
|
||||
scene
|
|
@ -0,0 +1,51 @@
|
|||
cmake_minimum_required(VERSION 3.10.0)
|
||||
project(bert)
|
||||
|
||||
set(TARGET bert)
|
||||
|
||||
add_definitions(-DENABLE_DVPP_INTERFACE)
|
||||
add_definitions(-D_GLIBCXX_USE_CXX11_ABI=0)
|
||||
add_definitions(-Dgoogle=mindxsdk_private)
|
||||
add_compile_options(-std=c++11 -fPIE -fstack-protector-all -fPIC -Wall)
|
||||
add_link_options(-Wl,-z,relro,-z,now,-z,noexecstack -s -pie)
|
||||
|
||||
# Check environment variable
|
||||
if(NOT DEFINED ENV{ASCEND_HOME})
|
||||
message(FATAL_ERROR "please define environment variable:ASCEND_HOME")
|
||||
endif()
|
||||
if(NOT DEFINED ENV{ASCEND_VERSION})
|
||||
message(WARNING "please define environment variable:ASCEND_VERSION")
|
||||
endif()
|
||||
if(NOT DEFINED ENV{ARCH_PATTERN})
|
||||
message(WARNING "please define environment variable:ARCH_PATTERN")
|
||||
endif()
|
||||
set(ACL_INC_DIR $ENV{ASCEND_HOME}/$ENV{ASCEND_VERSION}/$ENV{ARCH_PATTERN}/acllib/include)
|
||||
set(ACL_LIB_DIR $ENV{ASCEND_HOME}/$ENV{ASCEND_VERSION}/$ENV{ARCH_PATTERN}/acllib/lib64)
|
||||
|
||||
set(MXBASE_ROOT_DIR $ENV{MX_SDK_HOME})
|
||||
set(MXBASE_INC ${MXBASE_ROOT_DIR}/include)
|
||||
set(MXBASE_LIB_DIR ${MXBASE_ROOT_DIR}/lib)
|
||||
set(MXBASE_POST_LIB_DIR ${MXBASE_ROOT_DIR}/lib/modelpostprocessors)
|
||||
set(MXBASE_POST_PROCESS_DIR ${MXBASE_ROOT_DIR}/include/MxBase/postprocess/include)
|
||||
if(DEFINED ENV{MXSDK_OPENSOURCE_DIR})
|
||||
set(OPENSOURCE_DIR $ENV{MXSDK_OPENSOURCE_DIR})
|
||||
else()
|
||||
set(OPENSOURCE_DIR ${MXBASE_ROOT_DIR}/opensource)
|
||||
endif()
|
||||
|
||||
include_directories(${ACL_INC_DIR})
|
||||
include_directories(${OPENSOURCE_DIR}/include)
|
||||
include_directories(${OPENSOURCE_DIR}/include/opencv4)
|
||||
|
||||
include_directories(${MXBASE_INC})
|
||||
include_directories(${MXBASE_POST_PROCESS_DIR})
|
||||
|
||||
link_directories(${ACL_LIB_DIR})
|
||||
link_directories(${OPENSOURCE_DIR}/lib)
|
||||
link_directories(${MXBASE_LIB_DIR})
|
||||
link_directories(${MXBASE_POST_LIB_DIR})
|
||||
|
||||
add_executable(${TARGET} src/main.cpp src/BertNerBase.cpp)
|
||||
target_link_libraries(${TARGET} glog cpprest mxbase opencv_world stdc++fs)
|
||||
|
||||
install(TARGETS ${TARGET} RUNTIME DESTINATION ${PROJECT_SOURCE_DIR}/)
|
|
@ -0,0 +1,55 @@
|
|||
#!/bin/bash
|
||||
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
path_cur=$(dirname $0)
|
||||
|
||||
function check_env()
|
||||
{
|
||||
# set ASCEND_VERSION to ascend-toolkit/latest when it was not specified by user
|
||||
if [ ! "${ASCEND_VERSION}" ]; then
|
||||
export ASCEND_VERSION=ascend-toolkit/latest
|
||||
echo "Set ASCEND_VERSION to the default value: ${ASCEND_VERSION}"
|
||||
else
|
||||
echo "ASCEND_VERSION is set to ${ASCEND_VERSION} by user"
|
||||
fi
|
||||
|
||||
if [ ! "${ARCH_PATTERN}" ]; then
|
||||
# set ARCH_PATTERN to ./ when it was not specified by user
|
||||
export ARCH_PATTERN=./
|
||||
echo "ARCH_PATTERN is set to the default value: ${ARCH_PATTERN}"
|
||||
else
|
||||
echo "ARCH_PATTERN is set to ${ARCH_PATTERN} by user"
|
||||
fi
|
||||
}
|
||||
|
||||
function build_bert()
|
||||
{
|
||||
cd $path_cur
|
||||
rm -rf build
|
||||
mkdir -p build
|
||||
cd build
|
||||
cmake ..
|
||||
make
|
||||
ret=$?
|
||||
if [ ${ret} -ne 0 ]; then
|
||||
echo "Failed to build bert."
|
||||
exit ${ret}
|
||||
fi
|
||||
make install
|
||||
}
|
||||
|
||||
check_env
|
||||
build_bert
|
|
@ -0,0 +1,345 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "BertNerBase.h"
|
||||
#include <unistd.h>
|
||||
#include <sys/stat.h>
|
||||
#include <map>
|
||||
#include <fstream>
|
||||
#include "MxBase/DeviceManager/DeviceManager.h"
|
||||
#include "MxBase/Log/Log.h"
|
||||
|
||||
const uint32_t EACH_LABEL_LENGTH = 4;
|
||||
const uint32_t MAX_LENGTH = 128;
|
||||
const uint32_t CLASS_NUM = 41;
|
||||
|
||||
APP_ERROR BertNerBase::LoadLabels(const std::string &labelPath, std::vector<std::string> *labelMap) {
|
||||
std::ifstream infile;
|
||||
// open label file
|
||||
infile.open(labelPath, std::ios_base::in);
|
||||
std::string s;
|
||||
// check label file validity
|
||||
if (infile.fail()) {
|
||||
LogError << "Failed to open label file: " << labelPath << ".";
|
||||
return APP_ERR_COMM_OPEN_FAIL;
|
||||
}
|
||||
labelMap->clear();
|
||||
// construct label vector
|
||||
while (std::getline(infile, s)) {
|
||||
if (s.size() == 0 || s[0] == '#') {
|
||||
continue;
|
||||
}
|
||||
size_t eraseIndex = s.find_last_not_of("\r\n\t");
|
||||
if (eraseIndex != std::string::npos) {
|
||||
s.erase(eraseIndex + 1, s.size() - eraseIndex);
|
||||
}
|
||||
labelMap->push_back(s);
|
||||
}
|
||||
infile.close();
|
||||
return APP_ERR_OK;
|
||||
}
|
||||
|
||||
APP_ERROR BertNerBase::Init(const InitParam &initParam) {
|
||||
deviceId_ = initParam.deviceId;
|
||||
APP_ERROR ret = MxBase::DeviceManager::GetInstance()->InitDevices();
|
||||
if (ret != APP_ERR_OK) {
|
||||
LogError << "Init devices failed, ret=" << ret << ".";
|
||||
return ret;
|
||||
}
|
||||
ret = MxBase::TensorContext::GetInstance()->SetContext(initParam.deviceId);
|
||||
if (ret != APP_ERR_OK) {
|
||||
LogError << "Set context failed, ret=" << ret << ".";
|
||||
return ret;
|
||||
}
|
||||
dvppWrapper_ = std::make_shared<MxBase::DvppWrapper>();
|
||||
ret = dvppWrapper_->Init();
|
||||
if (ret != APP_ERR_OK) {
|
||||
LogError << "DvppWrapper init failed, ret=" << ret << ".";
|
||||
return ret;
|
||||
}
|
||||
model_ = std::make_shared<MxBase::ModelInferenceProcessor>();
|
||||
ret = model_->Init(initParam.modelPath, modelDesc_);
|
||||
if (ret != APP_ERR_OK) {
|
||||
LogError << "ModelInferenceProcessor init failed, ret=" << ret << ".";
|
||||
return ret;
|
||||
}
|
||||
classNum_ = initParam.classNum;
|
||||
// load labels from file
|
||||
ret = LoadLabels(initParam.labelPath, &labelMap_);
|
||||
if (ret != APP_ERR_OK) {
|
||||
LogError << "Failed to load labels, ret=" << ret << ".";
|
||||
return ret;
|
||||
}
|
||||
return APP_ERR_OK;
|
||||
}
|
||||
|
||||
APP_ERROR BertNerBase::DeInit() {
|
||||
dvppWrapper_->DeInit();
|
||||
model_->DeInit();
|
||||
MxBase::DeviceManager::GetInstance()->DestroyDevices();
|
||||
return APP_ERR_OK;
|
||||
}
|
||||
|
||||
APP_ERROR BertNerBase::ReadTensorFromFile(const std::string &file, uint32_t *data) {
|
||||
if (data == NULL) {
|
||||
LogError << "input data is invalid.";
|
||||
return APP_ERR_COMM_INVALID_POINTER;
|
||||
}
|
||||
std::ifstream infile;
|
||||
// open label file
|
||||
infile.open(file, std::ios_base::in | std::ios_base::binary);
|
||||
// check label file validity
|
||||
if (infile.fail()) {
|
||||
LogError << "Failed to open label file: " << file << ".";
|
||||
return APP_ERR_COMM_OPEN_FAIL;
|
||||
}
|
||||
infile.read(reinterpret_cast<char*>(data), sizeof(uint32_t) * MAX_LENGTH);
|
||||
infile.close();
|
||||
return APP_ERR_OK;
|
||||
}
|
||||
|
||||
APP_ERROR BertNerBase::ReadInputTensor(const std::string &fileName, uint32_t index,
|
||||
std::vector<MxBase::TensorBase> *inputs) {
|
||||
uint32_t data[MAX_LENGTH] = {0};
|
||||
APP_ERROR ret = ReadTensorFromFile(fileName, data);
|
||||
if (ret != APP_ERR_OK) {
|
||||
LogError << "ReadTensorFromFile failed.";
|
||||
return ret;
|
||||
}
|
||||
|
||||
const uint32_t dataSize = modelDesc_.inputTensors[index].tensorSize;
|
||||
MxBase::MemoryData memoryDataDst(dataSize, MxBase::MemoryData::MEMORY_DEVICE, deviceId_);
|
||||
MxBase::MemoryData memoryDataSrc(reinterpret_cast<void*>(data), dataSize, MxBase::MemoryData::MEMORY_HOST_MALLOC);
|
||||
ret = MxBase::MemoryHelper::MxbsMallocAndCopy(memoryDataDst, memoryDataSrc);
|
||||
if (ret != APP_ERR_OK) {
|
||||
LogError << GetError(ret) << "Memory malloc and copy failed.";
|
||||
return ret;
|
||||
}
|
||||
|
||||
std::vector<uint32_t> shape = {1, MAX_LENGTH};
|
||||
inputs->push_back(MxBase::TensorBase(memoryDataDst, false, shape, MxBase::TENSOR_DTYPE_UINT32));
|
||||
return APP_ERR_OK;
|
||||
}
|
||||
|
||||
APP_ERROR BertNerBase::Inference(const std::vector<MxBase::TensorBase> &inputs,
|
||||
std::vector<MxBase::TensorBase> *outputs) {
|
||||
auto dtypes = model_->GetOutputDataType();
|
||||
for (size_t i = 0; i < modelDesc_.outputTensors.size(); ++i) {
|
||||
std::vector<uint32_t> shape = {};
|
||||
for (size_t j = 0; j < modelDesc_.outputTensors[i].tensorDims.size(); ++j) {
|
||||
shape.push_back((uint32_t)modelDesc_.outputTensors[i].tensorDims[j]);
|
||||
}
|
||||
MxBase::TensorBase tensor(shape, dtypes[i], MxBase::MemoryData::MemoryType::MEMORY_DEVICE, deviceId_);
|
||||
APP_ERROR ret = MxBase::TensorBase::TensorBaseMalloc(tensor);
|
||||
if (ret != APP_ERR_OK) {
|
||||
LogError << "TensorBaseMalloc failed, ret=" << ret << ".";
|
||||
return ret;
|
||||
}
|
||||
outputs->push_back(tensor);
|
||||
}
|
||||
|
||||
MxBase::DynamicInfo dynamicInfo = {};
|
||||
dynamicInfo.dynamicType = MxBase::DynamicType::STATIC_BATCH;
|
||||
auto startTime = std::chrono::high_resolution_clock::now();
|
||||
APP_ERROR ret = model_->ModelInference(inputs, *outputs, dynamicInfo);
|
||||
auto endTime = std::chrono::high_resolution_clock::now();
|
||||
double costMs = std::chrono::duration<double, std::milli>(endTime - startTime).count();
|
||||
g_inferCost.push_back(costMs);
|
||||
if (ret != APP_ERR_OK) {
|
||||
LogError << "ModelInference failed, ret=" << ret << ".";
|
||||
return ret;
|
||||
}
|
||||
return APP_ERR_OK;
|
||||
}
|
||||
|
||||
APP_ERROR BertNerBase::PostProcess(std::vector<MxBase::TensorBase> *outputs, std::vector<uint32_t> *argmax) {
|
||||
MxBase::TensorBase &tensor = outputs->at(0);
|
||||
APP_ERROR ret = tensor.ToHost();
|
||||
if (ret != APP_ERR_OK) {
|
||||
LogError << GetError(ret) << "Tensor deploy to host failed.";
|
||||
return ret;
|
||||
}
|
||||
// check tensor is available
|
||||
auto outputShape = tensor.GetShape();
|
||||
uint32_t length = outputShape[0];
|
||||
uint32_t classNum = outputShape[1];
|
||||
LogInfo << "output shape is: " << outputShape[0] << " "<< outputShape[1] << std::endl;
|
||||
|
||||
void* data = tensor.GetBuffer();
|
||||
for (uint32_t i = 0; i < length; i++) {
|
||||
std::vector<float> result = {};
|
||||
for (uint32_t j = 0; j < classNum; j++) {
|
||||
float value = *(reinterpret_cast<float*>(data) + i * classNum + j);
|
||||
result.push_back(value);
|
||||
}
|
||||
// argmax and get the class id
|
||||
std::vector<float>::iterator maxElement = std::max_element(std::begin(result), std::end(result));
|
||||
uint32_t argmaxIndex = maxElement - std::begin(result);
|
||||
argmax->push_back(argmaxIndex);
|
||||
}
|
||||
|
||||
return APP_ERR_OK;
|
||||
}
|
||||
|
||||
APP_ERROR BertNerBase::CountPredictResult(const std::string &labelFile, const std::vector<uint32_t> &argmax) {
|
||||
uint32_t data[MAX_LENGTH] = {0};
|
||||
APP_ERROR ret = ReadTensorFromFile(labelFile, data);
|
||||
if (ret != APP_ERR_OK) {
|
||||
LogError << "ReadTensorFromFile failed.";
|
||||
return ret;
|
||||
}
|
||||
uint32_t target[CLASS_NUM][MAX_LENGTH] = {0};
|
||||
uint32_t pred[CLASS_NUM][MAX_LENGTH] = {0};
|
||||
for (uint32_t i = 0; i < MAX_LENGTH; i++) {
|
||||
if (data[i] > 0) {
|
||||
target[data[i]][i] = 1;
|
||||
}
|
||||
if (argmax[i] > 0) {
|
||||
pred[argmax[i]][i] = 1;
|
||||
}
|
||||
}
|
||||
for (uint32_t i = 0; i < CLASS_NUM; i++) {
|
||||
for (uint32_t j = 0; j < MAX_LENGTH; j++) {
|
||||
// count True Positive and False Positive
|
||||
if (pred[i][j] == 1) {
|
||||
if (target[i][j] == 1) {
|
||||
g_TP += 1;
|
||||
} else {
|
||||
g_FP += 1;
|
||||
}
|
||||
}
|
||||
// count False Negative
|
||||
if (target[i][j] == 1 && pred[i][j] != 1) {
|
||||
g_FN += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
LogInfo << "TP: " << g_TP << ", FP: " << g_FP << ", FN: " << g_FN;
|
||||
return APP_ERR_OK;
|
||||
}
|
||||
|
||||
void BertNerBase::GetClunerLabel(const std::vector<uint32_t> &argmax, std::multimap<std::string,
|
||||
std::vector<uint32_t>> *clunerMap) {
|
||||
bool findCluner = false;
|
||||
uint32_t start = 0;
|
||||
std::string clunerName;
|
||||
for (uint32_t i = 0; i < argmax.size(); i++) {
|
||||
if (argmax[i] > 0) {
|
||||
if (!findCluner) {
|
||||
start = i;
|
||||
clunerName = labelMap_[(argmax[i] - 1) / EACH_LABEL_LENGTH];
|
||||
findCluner = true;
|
||||
} else {
|
||||
if (labelMap_[(argmax[i] - 1) / EACH_LABEL_LENGTH] != clunerName) {
|
||||
std::vector<uint32_t> position = {start - 1, i - 2};
|
||||
clunerMap->insert(std::pair<std::string, std::vector<uint32_t>>(clunerName, position));
|
||||
start = i;
|
||||
clunerName = labelMap_[(argmax[i] - 1) / EACH_LABEL_LENGTH];
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if (findCluner) {
|
||||
std::vector<uint32_t> position = {start - 1, i - 2};
|
||||
clunerMap->insert(std::pair<std::string, std::vector<uint32_t>>(clunerName, position));
|
||||
findCluner = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
APP_ERROR BertNerBase::WriteResult(const std::string &fileName, const std::vector<uint32_t> &argmax) {
|
||||
std::string resultPathName = "result";
|
||||
// create result directory when it does not exit
|
||||
if (access(resultPathName.c_str(), 0) != 0) {
|
||||
int ret = mkdir(resultPathName.c_str(), S_IRUSR | S_IWUSR | S_IXUSR);
|
||||
if (ret != 0) {
|
||||
LogError << "Failed to create result directory: " << resultPathName << ", ret = " << ret;
|
||||
return APP_ERR_COMM_OPEN_FAIL;
|
||||
}
|
||||
}
|
||||
// create result file under result directory
|
||||
resultPathName = resultPathName + "/result.txt";
|
||||
std::ofstream tfile(resultPathName, std::ofstream::app);
|
||||
if (tfile.fail()) {
|
||||
LogError << "Failed to open result file: " << resultPathName;
|
||||
return APP_ERR_COMM_OPEN_FAIL;
|
||||
}
|
||||
// write inference result into file
|
||||
LogInfo << "==============================================================";
|
||||
LogInfo << "infer result of " << fileName << " is: ";
|
||||
tfile << "file name is: " << fileName << std::endl;
|
||||
std::multimap<std::string, std::vector<uint32_t>> clunerMap;
|
||||
GetClunerLabel(argmax, &clunerMap);
|
||||
for (auto &item : clunerMap) {
|
||||
LogInfo << item.first << ": " << item.second[0] << ", " << item.second[1];
|
||||
tfile << item.first << ": " << item.second[0] << ", " << item.second[1] << std::endl;
|
||||
}
|
||||
LogInfo << "==============================================================";
|
||||
tfile.close();
|
||||
return APP_ERR_OK;
|
||||
}
|
||||
|
||||
APP_ERROR BertNerBase::Process(const std::string &inferPath, const std::string &fileName, bool eval) {
|
||||
std::vector<MxBase::TensorBase> inputs = {};
|
||||
std::string inputIdsFile = inferPath + "00_data/" + fileName;
|
||||
APP_ERROR ret = ReadInputTensor(inputIdsFile, INPUT_IDS, &inputs);
|
||||
if (ret != APP_ERR_OK) {
|
||||
LogError << "Read input ids failed, ret=" << ret << ".";
|
||||
return ret;
|
||||
}
|
||||
std::string inputMaskFile = inferPath + "01_data/" + fileName;
|
||||
ret = ReadInputTensor(inputMaskFile, INPUT_MASK, &inputs);
|
||||
if (ret != APP_ERR_OK) {
|
||||
LogError << "Read input mask file failed, ret=" << ret << ".";
|
||||
return ret;
|
||||
}
|
||||
std::string tokenTypeIdFile = inferPath + "02_data/" + fileName;
|
||||
ret = ReadInputTensor(tokenTypeIdFile, TOKEN_TYPE, &inputs);
|
||||
if (ret != APP_ERR_OK) {
|
||||
LogError << "Read token typeId file failed, ret=" << ret << ".";
|
||||
return ret;
|
||||
}
|
||||
|
||||
std::vector<MxBase::TensorBase> outputs = {};
|
||||
ret = Inference(inputs, &outputs);
|
||||
if (ret != APP_ERR_OK) {
|
||||
LogError << "Inference failed, ret=" << ret << ".";
|
||||
return ret;
|
||||
}
|
||||
|
||||
std::vector<uint32_t> argmax;
|
||||
ret = PostProcess(&outputs, &argmax);
|
||||
if (ret != APP_ERR_OK) {
|
||||
LogError << "PostProcess failed, ret=" << ret << ".";
|
||||
return ret;
|
||||
}
|
||||
|
||||
ret = WriteResult(fileName, argmax);
|
||||
if (ret != APP_ERR_OK) {
|
||||
LogError << "save result failed, ret=" << ret << ".";
|
||||
return ret;
|
||||
}
|
||||
|
||||
if (eval) {
|
||||
std::string labelFile = inferPath + "03_data/" + fileName;
|
||||
ret = CountPredictResult(labelFile, argmax);
|
||||
if (ret != APP_ERR_OK) {
|
||||
LogError << "CalcF1Score read label failed, ret=" << ret << ".";
|
||||
return ret;
|
||||
}
|
||||
}
|
||||
|
||||
return APP_ERR_OK;
|
||||
}
|
|
@ -0,0 +1,72 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MXBASE_BERTBASE_H
|
||||
#define MXBASE_BERTBASE_H
|
||||
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <map>
|
||||
#include <opencv2/opencv.hpp>
|
||||
#include "MxBase/DvppWrapper/DvppWrapper.h"
|
||||
#include "MxBase/ModelInfer/ModelInferenceProcessor.h"
|
||||
#include "MxBase/Tensor/TensorContext/TensorContext.h"
|
||||
|
||||
extern std::vector<double> g_inferCost;
|
||||
extern uint32_t g_TP;
|
||||
extern uint32_t g_FP;
|
||||
extern uint32_t g_FN;
|
||||
|
||||
struct InitParam {
|
||||
uint32_t deviceId;
|
||||
std::string labelPath;
|
||||
std::string modelPath;
|
||||
uint32_t classNum;
|
||||
};
|
||||
|
||||
enum DataIndex {
|
||||
INPUT_IDS = 0,
|
||||
INPUT_MASK = 1,
|
||||
TOKEN_TYPE = 2,
|
||||
};
|
||||
|
||||
class BertNerBase {
|
||||
public:
|
||||
APP_ERROR Init(const InitParam &initParam);
|
||||
APP_ERROR DeInit();
|
||||
APP_ERROR Inference(const std::vector<MxBase::TensorBase> &inputs, std::vector<MxBase::TensorBase> *outputs);
|
||||
APP_ERROR Process(const std::string &inferPath, const std::string &fileName, bool eval);
|
||||
APP_ERROR PostProcess(std::vector<MxBase::TensorBase> *outputs, std::vector<uint32_t> *argmax);
|
||||
protected:
|
||||
APP_ERROR ReadTensorFromFile(const std::string &file, uint32_t *data);
|
||||
APP_ERROR ReadInputTensor(const std::string &fileName, uint32_t index, std::vector<MxBase::TensorBase> *inputs);
|
||||
APP_ERROR LoadLabels(const std::string &labelPath, std::vector<std::string> *labelMap);
|
||||
APP_ERROR ReadInputTensor(const std::string &fileName, const std::vector<uint32_t> &argmax);
|
||||
APP_ERROR WriteResult(const std::string &fileName, const std::vector<uint32_t> &argmax);
|
||||
APP_ERROR CountPredictResult(const std::string &labelFile, const std::vector<uint32_t> &argmax);
|
||||
void GetClunerLabel(const std::vector<uint32_t> &argmax,
|
||||
std::multimap<std::string, std::vector<uint32_t>> *clunerMap);
|
||||
private:
|
||||
std::shared_ptr<MxBase::DvppWrapper> dvppWrapper_;
|
||||
std::shared_ptr<MxBase::ModelInferenceProcessor> model_;
|
||||
MxBase::ModelDesc modelDesc_ = {};
|
||||
std::vector<std::string> labelMap_ = {};
|
||||
uint32_t deviceId_ = 0;
|
||||
uint32_t classNum_ = 0;
|
||||
};
|
||||
#endif
|
|
@ -0,0 +1,109 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <unistd.h>
|
||||
#include <dirent.h>
|
||||
#include <iostream>
|
||||
#include <fstream>
|
||||
#include <vector>
|
||||
#include "BertNerBase.h"
|
||||
#include "MxBase/Log/Log.h"
|
||||
|
||||
std::vector<double> g_inferCost;
|
||||
uint32_t g_TP = 0;
|
||||
uint32_t g_FP = 0;
|
||||
uint32_t g_FN = 0;
|
||||
|
||||
void InitBertParam(InitParam* initParam) {
|
||||
initParam->deviceId = 0;
|
||||
initParam->labelPath = "../data/config/infer_label.txt";
|
||||
initParam->modelPath = "../data/model/cluner.om";
|
||||
initParam->classNum = 41;
|
||||
}
|
||||
|
||||
APP_ERROR ReadFilesFromPath(const std::string &path, std::vector<std::string> *files) {
|
||||
DIR *dir = NULL;
|
||||
struct dirent *ptr = NULL;
|
||||
|
||||
if ((dir=opendir(path.c_str())) == NULL) {
|
||||
LogError << "Open dir error: " << path;
|
||||
return APP_ERR_COMM_OPEN_FAIL;
|
||||
}
|
||||
|
||||
while ((ptr=readdir(dir)) != NULL) {
|
||||
// d_type == 8 is file
|
||||
if (ptr->d_type == 8) {
|
||||
files->push_back(ptr->d_name);
|
||||
}
|
||||
}
|
||||
closedir(dir);
|
||||
// sort ascending order
|
||||
sort(files->begin(), files->end());
|
||||
return APP_ERR_OK;
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[]) {
|
||||
if (argc <= 1) {
|
||||
LogWarn << "Please input image path, such as './bert /input/data 0'.";
|
||||
return APP_ERR_OK;
|
||||
}
|
||||
|
||||
InitParam initParam;
|
||||
InitBertParam(&initParam);
|
||||
auto bertBase = std::make_shared<BertNerBase>();
|
||||
APP_ERROR ret = bertBase->Init(initParam);
|
||||
if (ret != APP_ERR_OK) {
|
||||
LogError << "Bertbase init failed, ret=" << ret << ".";
|
||||
return ret;
|
||||
}
|
||||
|
||||
std::string inferPath = argv[1];
|
||||
std::vector<std::string> files;
|
||||
ret = ReadFilesFromPath(inferPath + "00_data", &files);
|
||||
if (ret != APP_ERR_OK) {
|
||||
LogError << "Read files from path failed, ret=" << ret << ".";
|
||||
return ret;
|
||||
}
|
||||
// do eval and calc the f1 score
|
||||
bool eval = atoi(argv[2]);
|
||||
for (uint32_t i = 0; i < files.size(); i++) {
|
||||
LogInfo << "read file name: " << files[i];
|
||||
ret = bertBase->Process(inferPath, files[i], eval);
|
||||
if (ret != APP_ERR_OK) {
|
||||
LogError << "Bertbase process failed, ret=" << ret << ".";
|
||||
bertBase->DeInit();
|
||||
return ret;
|
||||
}
|
||||
}
|
||||
|
||||
if (eval) {
|
||||
LogInfo << "==============================================================";
|
||||
float precision = g_TP * 1.0 / (g_TP + g_FP);
|
||||
LogInfo << "Precision: " << precision;
|
||||
float recall = g_TP * 1.0 / (g_TP + g_FN);
|
||||
LogInfo << "recall: " << recall;
|
||||
LogInfo << "F1 Score: " << 2 * precision * recall / (precision + recall);
|
||||
LogInfo << "==============================================================";
|
||||
}
|
||||
bertBase->DeInit();
|
||||
double costSum = 0;
|
||||
for (uint32_t i = 0; i < g_inferCost.size(); i++) {
|
||||
costSum += g_inferCost[i];
|
||||
}
|
||||
LogInfo << "Infer images sum " << g_inferCost.size() << ", cost total time: " << costSum << " ms.";
|
||||
LogInfo << "The throughput: " << g_inferCost.size() * 1000 / costSum << " bin/sec.";
|
||||
return APP_ERR_OK;
|
||||
}
|
|
@ -0,0 +1,284 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""
|
||||
sample script of CLUE infer using SDK run in docker
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import glob
|
||||
import os
|
||||
|
||||
import MxpiDataType_pb2 as MxpiDataType
|
||||
import numpy as np
|
||||
from StreamManagerApi import StreamManagerApi, MxDataInput, InProtobufVector, \
|
||||
MxProtobufIn, StringVector
|
||||
|
||||
TP = 0
|
||||
FP = 0
|
||||
FN = 0
|
||||
|
||||
|
||||
def parse_args():
|
||||
"""set and check parameters."""
|
||||
parser = argparse.ArgumentParser(description="bert process")
|
||||
parser.add_argument("--pipeline", type=str, default="", help="SDK infer pipeline")
|
||||
parser.add_argument("--data_dir", type=str, default="",
|
||||
help="Dataset contain input_ids, input_mask, segment_ids, label_ids")
|
||||
parser.add_argument("--label_file", type=str, default="", help="label ids to name")
|
||||
parser.add_argument("--output_file", type=str, default="", help="save result to file")
|
||||
parser.add_argument("--f1_method", type=str, default="BF1", help="calc F1 use the number label,(BF1, MF1)")
|
||||
parser.add_argument("--do_eval", type=bool, default=False, help="eval the accuracy of model ")
|
||||
args_opt = parser.parse_args()
|
||||
return args_opt
|
||||
|
||||
|
||||
def send_source_data(appsrc_id, filename, stream_name, stream_manager):
|
||||
"""
|
||||
Construct the input of the stream,
|
||||
send inputs data to a specified stream based on streamName.
|
||||
|
||||
Returns:
|
||||
bool: send data success or not
|
||||
"""
|
||||
tensor = np.fromfile(filename, dtype=np.int32)
|
||||
tensor = np.expand_dims(tensor, 0)
|
||||
tensor_package_list = MxpiDataType.MxpiTensorPackageList()
|
||||
tensor_package = tensor_package_list.tensorPackageVec.add()
|
||||
array_bytes = tensor.tobytes()
|
||||
data_input = MxDataInput()
|
||||
data_input.data = array_bytes
|
||||
tensor_vec = tensor_package.tensorVec.add()
|
||||
tensor_vec.deviceId = 0
|
||||
tensor_vec.memType = 0
|
||||
for i in tensor.shape:
|
||||
tensor_vec.tensorShape.append(i)
|
||||
tensor_vec.dataStr = data_input.data
|
||||
tensor_vec.tensorDataSize = len(array_bytes)
|
||||
|
||||
key = "appsrc{}".format(appsrc_id).encode('utf-8')
|
||||
protobuf_vec = InProtobufVector()
|
||||
protobuf = MxProtobufIn()
|
||||
protobuf.key = key
|
||||
protobuf.type = b'MxTools.MxpiTensorPackageList'
|
||||
protobuf.protobuf = tensor_package_list.SerializeToString()
|
||||
protobuf_vec.push_back(protobuf)
|
||||
|
||||
ret = stream_manager.SendProtobuf(stream_name, appsrc_id, protobuf_vec)
|
||||
if ret < 0:
|
||||
print("Failed to send data to stream.")
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def send_appsrc_data(args, file_name, stream_name, stream_manager):
|
||||
"""
|
||||
send three stream to infer model, include input ids, input mask and token type_id.
|
||||
|
||||
Returns:
|
||||
bool: send data success or not
|
||||
"""
|
||||
input_ids = os.path.realpath(os.path.join(args.data_dir, "00_data", file_name))
|
||||
if not send_source_data(0, input_ids, stream_name, stream_manager):
|
||||
return False
|
||||
input_mask = os.path.realpath(os.path.join(args.data_dir, "01_data", file_name))
|
||||
if not send_source_data(1, input_mask, stream_name, stream_manager):
|
||||
return False
|
||||
token_type_id = os.path.realpath(os.path.join(args.data_dir, "02_data", file_name))
|
||||
if not send_source_data(2, token_type_id, stream_name, stream_manager):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def read_label_file(label_file):
|
||||
"""
|
||||
Args:
|
||||
label_file:
|
||||
"address"
|
||||
"book"
|
||||
...
|
||||
Returns:
|
||||
label list
|
||||
"""
|
||||
return open(label_file).readlines()
|
||||
|
||||
|
||||
def process_infer_to_cluner(args, logit_id, each_label_length=4):
|
||||
"""
|
||||
find label and position from the logit_id tensor.
|
||||
|
||||
Args:
|
||||
args: param of config.
|
||||
logit_id: shape is [128], example: [0..32.34..0].
|
||||
each_label_length: each label have 4 prefix, ["S_", "B_", "M_", "E_"].
|
||||
|
||||
Returns:
|
||||
dict of visualization result, as 'position': [9, 10]
|
||||
"""
|
||||
label_list = read_label_file(os.path.realpath(args.label_file))
|
||||
find_cluner = False
|
||||
result_list = []
|
||||
for i, value in enumerate(logit_id):
|
||||
if value > 0:
|
||||
if not find_cluner:
|
||||
start = i
|
||||
cluner_name = label_list[(value - 1) // each_label_length]
|
||||
find_cluner = True
|
||||
else:
|
||||
if label_list[(value - 1) // each_label_length] != cluner_name:
|
||||
item = {}
|
||||
item[cluner_name] = [start - 1, i - 2]
|
||||
result_list.append(item)
|
||||
start = i
|
||||
cluner_name = label_list[(value - 1) // each_label_length]
|
||||
else:
|
||||
if find_cluner:
|
||||
item = {}
|
||||
item[cluner_name] = [start - 1, i - 2]
|
||||
result_list.append(item)
|
||||
find_cluner = False
|
||||
|
||||
return result_list
|
||||
|
||||
|
||||
def count_pred_result(args, file_name, logit_id, class_num=41, max_seq_length=128):
|
||||
"""
|
||||
support two method to calc f1 sore, if dataset has two class, suggest using BF1,
|
||||
else more than two class, suggest using MF1.
|
||||
Args:
|
||||
args: param of config.
|
||||
file_name: label file name.
|
||||
logit_id: output tensor of infer.
|
||||
class_num: cluner data default is 41.
|
||||
max_seq_length: sentence input length default is 128.
|
||||
|
||||
global:
|
||||
TP: pred == target
|
||||
FP: in pred but not in target
|
||||
FN: in target but not in pred
|
||||
"""
|
||||
label_file = os.path.realpath(os.path.join(args.data_dir, "03_data", file_name))
|
||||
label_ids = np.fromfile(label_file, np.int32)
|
||||
label_ids.reshape(max_seq_length, -1)
|
||||
global TP, FP, FN
|
||||
if args.f1_method == "BF1":
|
||||
pos_eva = np.isin(logit_id, [i for i in range(1, class_num)])
|
||||
pos_label = np.isin(label_ids, [i for i in range(1, class_num)])
|
||||
TP += np.sum(pos_eva & pos_label)
|
||||
FP += np.sum(pos_eva & (~pos_label))
|
||||
FN += np.sum((~pos_eva) & pos_label)
|
||||
else:
|
||||
target = np.zeros((len(label_ids), class_num), dtype=np.int32)
|
||||
pred = np.zeros((len(logit_id), class_num), dtype=np.int32)
|
||||
for i, label in enumerate(label_ids):
|
||||
if label > 0:
|
||||
target[i][label] = 1
|
||||
for i, label in enumerate(logit_id):
|
||||
if label > 0:
|
||||
pred[i][label] = 1
|
||||
target = target.reshape(class_num, -1)
|
||||
pred = pred.reshape(class_num, -1)
|
||||
for i in range(0, class_num):
|
||||
for j in range(0, max_seq_length):
|
||||
if pred[i][j] == 1:
|
||||
if target[i][j] == 1:
|
||||
TP += 1
|
||||
else:
|
||||
FP += 1
|
||||
if target[i][j] == 1 and pred[i][j] != 1:
|
||||
FN += 1
|
||||
|
||||
|
||||
def post_process(args, file_name, infer_result, max_seq_length=128):
|
||||
"""
|
||||
process the result of infer tensor to Visualization results.
|
||||
Args:
|
||||
args: param of config.
|
||||
file_name: label file name.
|
||||
infer_result: get logit from infer result
|
||||
max_seq_length: sentence input length default is 128.
|
||||
"""
|
||||
# print the infer result
|
||||
print("==============================================================")
|
||||
result = MxpiDataType.MxpiTensorPackageList()
|
||||
result.ParseFromString(infer_result[0].messageBuf)
|
||||
res = np.frombuffer(result.tensorPackageVec[0].tensorVec[0].dataStr, dtype='<f4')
|
||||
res = res.reshape(max_seq_length, -1)
|
||||
print("output tensor is: ", res.shape)
|
||||
|
||||
logit_id = np.argmax(res, axis=-1)
|
||||
logit_id = np.reshape(logit_id, -1)
|
||||
cluner_list = process_infer_to_cluner(args, logit_id)
|
||||
print(cluner_list)
|
||||
with open(args.output_file, "a") as file:
|
||||
file.write("{}: {}\n".format(file_name, str(cluner_list)))
|
||||
|
||||
if args.do_eval:
|
||||
count_pred_result(args, file_name, logit_id)
|
||||
|
||||
|
||||
def run():
|
||||
"""
|
||||
read pipeline and do infer
|
||||
"""
|
||||
args = parse_args()
|
||||
# init stream manager
|
||||
stream_manager_api = StreamManagerApi()
|
||||
ret = stream_manager_api.InitManager()
|
||||
if ret != 0:
|
||||
print("Failed to init Stream manager, ret=%s" % str(ret))
|
||||
return
|
||||
|
||||
# create streams by pipeline config file
|
||||
with open(os.path.realpath(args.pipeline), 'rb') as f:
|
||||
pipeline_str = f.read()
|
||||
ret = stream_manager_api.CreateMultipleStreams(pipeline_str)
|
||||
if ret != 0:
|
||||
print("Failed to create Stream, ret=%s" % str(ret))
|
||||
return
|
||||
|
||||
stream_name = b'im_bertbase'
|
||||
# input_ids file list, every file content a tensor[1,128]
|
||||
file_list = glob.glob(os.path.join(os.path.realpath(args.data_dir), "00_data", "*.bin"))
|
||||
for input_ids in file_list:
|
||||
file_name = input_ids.split('/')[-1]
|
||||
if not send_appsrc_data(args, file_name, stream_name, stream_manager_api):
|
||||
return
|
||||
# Obtain the inference result by specifying streamName and uniqueId.
|
||||
key_vec = StringVector()
|
||||
key_vec.push_back(b'mxpi_tensorinfer0')
|
||||
infer_result = stream_manager_api.GetProtobuf(stream_name, 0, key_vec)
|
||||
if infer_result.size() == 0:
|
||||
print("inferResult is null")
|
||||
return
|
||||
if infer_result[0].errorCode != 0:
|
||||
print("GetProtobuf error. errorCode=%d" % (infer_result[0].errorCode))
|
||||
return
|
||||
post_process(args, file_name, infer_result)
|
||||
|
||||
if args.do_eval:
|
||||
print("==============================================================")
|
||||
precision = TP / (TP + FP)
|
||||
print("Precision {:.6f} ".format(precision))
|
||||
recall = TP / (TP + FN)
|
||||
print("Recall {:.6f} ".format(recall))
|
||||
print("F1 {:.6f} ".format(2 * precision * recall / (precision + recall)))
|
||||
print("==============================================================")
|
||||
# destroy streams
|
||||
stream_manager_api.DestroyAllStreams()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
run()
|
|
@ -0,0 +1,32 @@
|
|||
#!/bin/bash
|
||||
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
set -e
|
||||
|
||||
# Simple log helper functions
|
||||
info() { echo -e "\033[1;34m[INFO ][MxStream] $1\033[1;37m" ; }
|
||||
warn() { echo >&2 -e "\033[1;31m[WARN ][MxStream] $1\033[1;37m" ; }
|
||||
|
||||
#export MX_SDK_HOME=/home/work/mxVision
|
||||
export LD_LIBRARY_PATH=${MX_SDK_HOME}/lib:${MX_SDK_HOME}/opensource/lib:${MX_SDK_HOME}/opensource/lib64:/usr/local/Ascend/ascend-toolkit/latest/acllib/lib64:${LD_LIBRARY_PATH}
|
||||
export GST_PLUGIN_SCANNER=${MX_SDK_HOME}/opensource/libexec/gstreamer-1.0/gst-plugin-scanner
|
||||
export GST_PLUGIN_PATH=${MX_SDK_HOME}/opensource/lib/gstreamer-1.0:${MX_SDK_HOME}/lib/plugins
|
||||
|
||||
#to set PYTHONPATH, import the StreamManagerApi.py
|
||||
export PYTHONPATH=$PYTHONPATH:${MX_SDK_HOME}/python
|
||||
|
||||
python3.7 main.py --pipeline=../data/config/bert_base.pipeline --data_dir=../data/input --label_file=../data/config/infer_label.txt --output_file=./output.txt --do_eval=True --f1_method=MF1
|
||||
exit 0
|
|
@ -0,0 +1,37 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
docker_image=$1
|
||||
data_dir=$2
|
||||
model_dir=$3
|
||||
|
||||
docker run -it --ipc=host \
|
||||
--device=/dev/davinci0 \
|
||||
--device=/dev/davinci1 \
|
||||
--device=/dev/davinci2 \
|
||||
--device=/dev/davinci3 \
|
||||
--device=/dev/davinci4 \
|
||||
--device=/dev/davinci5 \
|
||||
--device=/dev/davinci6 \
|
||||
--device=/dev/davinci7 \
|
||||
--device=/dev/davinci_manager \
|
||||
--device=/dev/devmm_svm \
|
||||
--device=/dev/hisi_hdc \
|
||||
-v /usr/local/Ascend/driver:/usr/local/Ascend/driver \
|
||||
-v /usr/local/Ascend/add-ons/:/usr/local/Ascend/add-ons \
|
||||
-v ${data_dir}:${data_dir} \
|
||||
-v ${model_dir}:${model_dir} \
|
||||
-v /root/ascend/log:/root/ascend/log ${docker_image} /bin/bash
|
Loading…
Reference in New Issue