benchmark_train support multi batch inputs
This commit is contained in:
parent
8edb949d35
commit
1850ab3eab
|
@ -11,6 +11,9 @@ set(TEST_SRC
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/net_train.cc
|
${CMAKE_CURRENT_SOURCE_DIR}/net_train.cc
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# add static securec link library
|
||||||
|
include_directories(${TOP_DIR}/cmake/dependency_securec.cmake)
|
||||||
|
|
||||||
if(MSLITE_MINDDATA_IMPLEMENT STREQUAL "full")
|
if(MSLITE_MINDDATA_IMPLEMENT STREQUAL "full")
|
||||||
set(TEST_SRC
|
set(TEST_SRC
|
||||||
${TEST_SRC}
|
${TEST_SRC}
|
||||||
|
@ -30,15 +33,15 @@ endif()
|
||||||
|
|
||||||
if(PLATFORM_ARM32 OR PLATFORM_ARM64)
|
if(PLATFORM_ARM32 OR PLATFORM_ARM64)
|
||||||
if(SUPPORT_NPU AND ANDROID_STL STREQUAL "c++_static")
|
if(SUPPORT_NPU AND ANDROID_STL STREQUAL "c++_static")
|
||||||
target_link_libraries(benchmark_train mindspore-lite mindspore-lite-train c++_shared)
|
target_link_libraries(benchmark_train mindspore-lite mindspore-lite-train c++_shared securec)
|
||||||
else()
|
else()
|
||||||
target_link_libraries(benchmark_train mindspore-lite mindspore-lite-train)
|
target_link_libraries(benchmark_train mindspore-lite mindspore-lite-train securec)
|
||||||
endif()
|
endif()
|
||||||
else()
|
else()
|
||||||
if(WIN32)
|
if(WIN32)
|
||||||
target_link_libraries(benchmark_train mindspore-lite_static mindspore-lite-train_static pthread cpu_kernel_mid
|
target_link_libraries(benchmark_train mindspore-lite_static mindspore-lite-train_static pthread cpu_kernel_mid
|
||||||
nnacl_mid train_cpu_kernel_mid)
|
nnacl_mid train_cpu_kernel_mid securec)
|
||||||
else()
|
else()
|
||||||
target_link_libraries(benchmark_train mindspore-lite mindspore-lite-train pthread)
|
target_link_libraries(benchmark_train mindspore-lite mindspore-lite-train pthread securec)
|
||||||
endif()
|
endif()
|
||||||
endif()
|
endif()
|
||||||
|
|
|
@ -27,6 +27,7 @@
|
||||||
#include "tools/benchmark_train/net_runner.h"
|
#include "tools/benchmark_train/net_runner.h"
|
||||||
#include "src/common/common.h"
|
#include "src/common/common.h"
|
||||||
#include "include/api/serialization.h"
|
#include "include/api/serialization.h"
|
||||||
|
#include "securec/include/securec.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace lite {
|
namespace lite {
|
||||||
|
@ -99,34 +100,30 @@ float *NetTrain::ReadFileBuf(const std::string file, size_t *size) {
|
||||||
return buf.release();
|
return buf.release();
|
||||||
}
|
}
|
||||||
|
|
||||||
int NetTrain::GenerateRandomData(mindspore::MSTensor *tensor) {
|
|
||||||
auto input_data = tensor->MutableData();
|
|
||||||
if (input_data == nullptr) {
|
|
||||||
MS_LOG(ERROR) << "MallocData for inTensor failed";
|
|
||||||
return RET_ERROR;
|
|
||||||
}
|
|
||||||
auto tensor_byte_size = tensor->DataSize();
|
|
||||||
char *casted_data = static_cast<char *>(input_data);
|
|
||||||
for (size_t i = 0; i < tensor_byte_size; i++) {
|
|
||||||
casted_data[i] =
|
|
||||||
(tensor->DataType() == mindspore::DataType::kNumberTypeFloat32) ? static_cast<char>(i) : static_cast<char>(0);
|
|
||||||
}
|
|
||||||
return RET_OK;
|
|
||||||
}
|
|
||||||
|
|
||||||
int NetTrain::GenerateInputData() {
|
int NetTrain::GenerateInputData() {
|
||||||
for (auto tensor : ms_inputs_for_api_) {
|
for (auto tensor : ms_inputs_for_api_) {
|
||||||
auto status = GenerateRandomData(&tensor);
|
auto tensor_byte_size = tensor.DataSize();
|
||||||
if (status != RET_OK) {
|
MS_ASSERT(tensor_byte_size != 0);
|
||||||
std::cerr << "GenerateRandomData for inTensor failed: " << status << std::endl;
|
auto data_ptr = new (std::nothrow) char[tensor_byte_size];
|
||||||
MS_LOG(ERROR) << "GenerateRandomData for inTensor failed: " << status;
|
if (data_ptr == nullptr) {
|
||||||
return status;
|
MS_LOG(ERROR) << "Malloc input data buffer failed, data_size: " << tensor_byte_size;
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
inputs_buf_.emplace_back(data_ptr);
|
||||||
|
inputs_size_.emplace_back(tensor_byte_size);
|
||||||
|
for (size_t i = 0; i < tensor_byte_size; i++) {
|
||||||
|
data_ptr[i] =
|
||||||
|
(tensor.DataType() == mindspore::DataType::kNumberTypeFloat32) ? static_cast<char>(i) : static_cast<char>(0);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
batch_num_ = 1;
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
int NetTrain::LoadInput() {
|
int NetTrain::LoadInput() {
|
||||||
|
inputs_buf_.clear();
|
||||||
|
inputs_size_.clear();
|
||||||
|
batch_num_ = 0;
|
||||||
if (flags_->in_data_file_.empty()) {
|
if (flags_->in_data_file_.empty()) {
|
||||||
auto status = GenerateInputData();
|
auto status = GenerateInputData();
|
||||||
if (status != RET_OK) {
|
if (status != RET_OK) {
|
||||||
|
@ -145,6 +142,23 @@ int NetTrain::LoadInput() {
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int NetTrain::LoadStepInput(size_t step) {
|
||||||
|
if (step >= batch_num_) {
|
||||||
|
auto cur_batch = step + 1;
|
||||||
|
MS_LOG(ERROR) << "Max input Batch is:" << batch_num_ << " but got batch :" << cur_batch;
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
for (size_t i = 0; i < ms_inputs_for_api_.size(); i++) {
|
||||||
|
auto cur_tensor = ms_inputs_for_api_.at(i);
|
||||||
|
MS_ASSERT(cur_tensor != nullptr);
|
||||||
|
auto tensor_data_size = cur_tensor.DataSize();
|
||||||
|
auto input_data = cur_tensor.MutableData();
|
||||||
|
MS_ASSERT(input_data != nullptr);
|
||||||
|
memcpy_s(input_data, tensor_data_size, inputs_buf_[i].get() + step * tensor_data_size, tensor_data_size);
|
||||||
|
}
|
||||||
|
return RET_OK;
|
||||||
|
}
|
||||||
|
|
||||||
int NetTrain::ReadInputFile() {
|
int NetTrain::ReadInputFile() {
|
||||||
if (ms_inputs_for_api_.empty()) {
|
if (ms_inputs_for_api_.empty()) {
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
|
@ -165,17 +179,18 @@ int NetTrain::ReadInputFile() {
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
auto tensor_data_size = cur_tensor.DataSize();
|
auto tensor_data_size = cur_tensor.DataSize();
|
||||||
if (size != tensor_data_size) {
|
MS_ASSERT(tensor_byte_size != 0);
|
||||||
std::cerr << "Input binary file size error, required: " << tensor_data_size << ", in fact: " << size
|
if (size == 0 || size % tensor_data_size != 0 || (batch_num_ != 0 && size / tensor_data_size != batch_num_)) {
|
||||||
|
std::cerr << "Input binary file size error, required :N * " << tensor_data_size << ", in fact: " << size
|
||||||
<< " ,file_name: " << file_name.c_str() << std::endl;
|
<< " ,file_name: " << file_name.c_str() << std::endl;
|
||||||
MS_LOG(ERROR) << "Input binary file size error, required: " << tensor_data_size << ", in fact: " << size
|
MS_LOG(ERROR) << "Input binary file size error, required: N * " << tensor_data_size << ", in fact: " << size
|
||||||
<< " ,file_name: " << file_name.c_str();
|
<< " ,file_name: " << file_name.c_str();
|
||||||
delete bin_buf;
|
delete bin_buf;
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
auto input_data = cur_tensor.MutableData();
|
inputs_buf_.emplace_back(bin_buf);
|
||||||
memcpy(input_data, bin_buf, tensor_data_size);
|
inputs_size_.emplace_back(size);
|
||||||
delete[](bin_buf);
|
batch_num_ = size / tensor_data_size;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
|
@ -287,11 +302,18 @@ int NetTrain::MarkPerformance() {
|
||||||
|
|
||||||
for (int i = 0; i < flags_->epochs_; i++) {
|
for (int i = 0; i < flags_->epochs_; i++) {
|
||||||
auto start = GetTimeUs();
|
auto start = GetTimeUs();
|
||||||
auto status = ms_model_.RunStep(before_call_back_, after_call_back_);
|
for (size_t step = 0; step < batch_num_; step++) {
|
||||||
if (status != mindspore::kSuccess) {
|
MS_LOG(INFO) << "Run for epoch:" << i << " step:" << step;
|
||||||
MS_LOG(ERROR) << "Inference error " << status;
|
auto ret = LoadStepInput(step);
|
||||||
std::cerr << "Inference error " << status;
|
if (ret != RET_OK) {
|
||||||
return RET_ERROR;
|
return ret;
|
||||||
|
}
|
||||||
|
auto status = ms_model_.RunStep(before_call_back_, after_call_back_);
|
||||||
|
if (status != mindspore::kSuccess) {
|
||||||
|
MS_LOG(ERROR) << "Inference error " << status;
|
||||||
|
std::cerr << "Inference error " << status;
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
auto end = GetTimeUs();
|
auto end = GetTimeUs();
|
||||||
|
@ -322,6 +344,10 @@ int NetTrain::MarkPerformance() {
|
||||||
|
|
||||||
int NetTrain::MarkAccuracy(bool enforce_accuracy) {
|
int NetTrain::MarkAccuracy(bool enforce_accuracy) {
|
||||||
MS_LOG(INFO) << "MarkAccuracy";
|
MS_LOG(INFO) << "MarkAccuracy";
|
||||||
|
auto load_ret = LoadStepInput(0);
|
||||||
|
if (load_ret != RET_OK) {
|
||||||
|
return load_ret;
|
||||||
|
}
|
||||||
for (auto &msInput : ms_model_.GetInputs()) {
|
for (auto &msInput : ms_model_.GetInputs()) {
|
||||||
switch (msInput.DataType()) {
|
switch (msInput.DataType()) {
|
||||||
case mindspore::DataType::kNumberTypeFloat32:
|
case mindspore::DataType::kNumberTypeFloat32:
|
||||||
|
|
|
@ -220,6 +220,8 @@ class MS_API NetTrain {
|
||||||
|
|
||||||
int ReadInputFile();
|
int ReadInputFile();
|
||||||
|
|
||||||
|
int LoadStepInput(size_t step);
|
||||||
|
|
||||||
void InitMSContext(const std::shared_ptr<Context> &context);
|
void InitMSContext(const std::shared_ptr<Context> &context);
|
||||||
|
|
||||||
void InitTrainCfg(const std::shared_ptr<TrainCfg> &train_cfg);
|
void InitTrainCfg(const std::shared_ptr<TrainCfg> &train_cfg);
|
||||||
|
@ -301,6 +303,9 @@ class MS_API NetTrain {
|
||||||
mindspore::MSKernelCallBack after_call_back_{nullptr};
|
mindspore::MSKernelCallBack after_call_back_{nullptr};
|
||||||
nlohmann::json dump_cfg_json_;
|
nlohmann::json dump_cfg_json_;
|
||||||
std::string dump_file_output_dir_;
|
std::string dump_file_output_dir_;
|
||||||
|
std::vector<std::shared_ptr<char>> inputs_buf_;
|
||||||
|
std::vector<size_t> inputs_size_;
|
||||||
|
size_t batch_num_ = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
int MS_API RunNetTrain(int argc, const char **argv);
|
int MS_API RunNetTrain(int argc, const char **argv);
|
||||||
|
|
Loading…
Reference in New Issue