benchmark_train: support multi-batch inputs

zhangzhaoju 2022-11-10 11:12:10 +08:00
parent 8edb949d35
commit 1850ab3eab
3 changed files with 69 additions and 35 deletions
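In short: instead of requiring each input binary to match its tensor exactly once, benchmark_train now accepts a file holding N concatenated batches, caches the whole file, and feeds one batch per training step. A rough sketch of the expected layout (names are descriptive, not from the commit):

    input.bin = [batch 0][batch 1] ... [batch N-1], each slice exactly tensor_data_size bytes,
    so batch_num_ = file_size / tensor_data_size.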

tools/benchmark_train/CMakeLists.txt

@@ -11,6 +11,9 @@ set(TEST_SRC
     ${CMAKE_CURRENT_SOURCE_DIR}/net_train.cc
 )
+# add static securec link library
+include_directories(${TOP_DIR}/cmake/dependency_securec.cmake)
+
 if(MSLITE_MINDDATA_IMPLEMENT STREQUAL "full")
     set(TEST_SRC
         ${TEST_SRC}
@@ -30,15 +33,15 @@ endif()
 if(PLATFORM_ARM32 OR PLATFORM_ARM64)
     if(SUPPORT_NPU AND ANDROID_STL STREQUAL "c++_static")
-        target_link_libraries(benchmark_train mindspore-lite mindspore-lite-train c++_shared)
+        target_link_libraries(benchmark_train mindspore-lite mindspore-lite-train c++_shared securec)
     else()
-        target_link_libraries(benchmark_train mindspore-lite mindspore-lite-train)
+        target_link_libraries(benchmark_train mindspore-lite mindspore-lite-train securec)
     endif()
 else()
     if(WIN32)
         target_link_libraries(benchmark_train mindspore-lite_static mindspore-lite-train_static pthread cpu_kernel_mid
-                              nnacl_mid train_cpu_kernel_mid)
+                              nnacl_mid train_cpu_kernel_mid securec)
     else()
-        target_link_libraries(benchmark_train mindspore-lite mindspore-lite-train pthread)
+        target_link_libraries(benchmark_train mindspore-lite mindspore-lite-train pthread securec)
     endif()
 endif()

tools/benchmark_train/net_train.cc

@@ -27,6 +27,7 @@
 #include "tools/benchmark_train/net_runner.h"
 #include "src/common/common.h"
 #include "include/api/serialization.h"
+#include "securec/include/securec.h"
 
 namespace mindspore {
 namespace lite {
@@ -99,34 +100,30 @@ float *NetTrain::ReadFileBuf(const std::string file, size_t *size) {
   return buf.release();
 }
 
-int NetTrain::GenerateRandomData(mindspore::MSTensor *tensor) {
-  auto input_data = tensor->MutableData();
-  if (input_data == nullptr) {
-    MS_LOG(ERROR) << "MallocData for inTensor failed";
-    return RET_ERROR;
-  }
-  auto tensor_byte_size = tensor->DataSize();
-  char *casted_data = static_cast<char *>(input_data);
-  for (size_t i = 0; i < tensor_byte_size; i++) {
-    casted_data[i] =
-      (tensor->DataType() == mindspore::DataType::kNumberTypeFloat32) ? static_cast<char>(i) : static_cast<char>(0);
-  }
-  return RET_OK;
-}
-
 int NetTrain::GenerateInputData() {
   for (auto tensor : ms_inputs_for_api_) {
-    auto status = GenerateRandomData(&tensor);
-    if (status != RET_OK) {
-      std::cerr << "GenerateRandomData for inTensor failed: " << status << std::endl;
-      MS_LOG(ERROR) << "GenerateRandomData for inTensor failed: " << status;
-      return status;
+    auto tensor_byte_size = tensor.DataSize();
+    MS_ASSERT(tensor_byte_size != 0);
+    auto data_ptr = new (std::nothrow) char[tensor_byte_size];
+    if (data_ptr == nullptr) {
+      MS_LOG(ERROR) << "Malloc input data buffer failed, data_size: " << tensor_byte_size;
+      return RET_ERROR;
+    }
+    inputs_buf_.emplace_back(data_ptr);
+    inputs_size_.emplace_back(tensor_byte_size);
+    for (size_t i = 0; i < tensor_byte_size; i++) {
+      data_ptr[i] =
+        (tensor.DataType() == mindspore::DataType::kNumberTypeFloat32) ? static_cast<char>(i) : static_cast<char>(0);
     }
   }
+  batch_num_ = 1;
   return RET_OK;
 }
 
 int NetTrain::LoadInput() {
+  inputs_buf_.clear();
+  inputs_size_.clear();
+  batch_num_ = 0;
   if (flags_->in_data_file_.empty()) {
     auto status = GenerateInputData();
     if (status != RET_OK) {
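One caveat around the hunk above: inputs_buf_ is declared in net_train.h (below) as std::vector<std::shared_ptr<char>>, and emplace_back(data_ptr) adopts a new[] allocation with shared_ptr's default deleter, which calls delete rather than delete[]. A minimal array-safe alternative, sketched under that assumption (MakeArrayBuffer is a hypothetical helper, not part of the commit):

#include <cstddef>
#include <memory>
#include <new>

// Hypothetical helper: wrap a new[] char buffer so the shared_ptr
// releases it with delete[], matching the allocation.
std::shared_ptr<char> MakeArrayBuffer(std::size_t byte_size) {
  return std::shared_ptr<char>(new (std::nothrow) char[byte_size],
                               std::default_delete<char[]>());
}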
@@ -145,6 +142,23 @@ int NetTrain::LoadInput() {
   return RET_OK;
 }
 
+int NetTrain::LoadStepInput(size_t step) {
+  if (step >= batch_num_) {
+    auto cur_batch = step + 1;
+    MS_LOG(ERROR) << "Max input Batch is:" << batch_num_ << " but got batch :" << cur_batch;
+    return RET_ERROR;
+  }
+  for (size_t i = 0; i < ms_inputs_for_api_.size(); i++) {
+    auto cur_tensor = ms_inputs_for_api_.at(i);
+    MS_ASSERT(cur_tensor != nullptr);
+    auto tensor_data_size = cur_tensor.DataSize();
+    auto input_data = cur_tensor.MutableData();
+    MS_ASSERT(input_data != nullptr);
+    memcpy_s(input_data, tensor_data_size, inputs_buf_[i].get() + step * tensor_data_size, tensor_data_size);
+  }
+  return RET_OK;
+}
+
 int NetTrain::ReadInputFile() {
   if (ms_inputs_for_api_.empty()) {
     return RET_OK;
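LoadStepInput copies the step-th slice of the cached buffer into the tensor with memcpy_s but discards the return code. securec's memcpy_s returns EOK on success, so a caller that wants to surface failures could wrap the call as below (CopyStepData is a hypothetical sketch, not part of the commit):

#include <cstddef>
#include "securec/include/securec.h"

// Hypothetical sketch: propagate a memcpy_s failure instead of ignoring it.
int CopyStepData(void *dst, std::size_t dst_size, const char *buf,
                 std::size_t step, std::size_t step_size) {
  auto err = memcpy_s(dst, dst_size, buf + step * step_size, step_size);
  return (err == EOK) ? 0 : -1;  // RET_OK / RET_ERROR in the tool's terms
}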
@@ -165,17 +179,18 @@ int NetTrain::ReadInputFile() {
         return RET_ERROR;
       }
       auto tensor_data_size = cur_tensor.DataSize();
-      if (size != tensor_data_size) {
-        std::cerr << "Input binary file size error, required: " << tensor_data_size << ", in fact: " << size
+      MS_ASSERT(tensor_byte_size != 0);
+      if (size == 0 || size % tensor_data_size != 0 || (batch_num_ != 0 && size / tensor_data_size != batch_num_)) {
+        std::cerr << "Input binary file size error, required :N * " << tensor_data_size << ", in fact: " << size
                   << " ,file_name: " << file_name.c_str() << std::endl;
-        MS_LOG(ERROR) << "Input binary file size error, required: " << tensor_data_size << ", in fact: " << size
+        MS_LOG(ERROR) << "Input binary file size error, required: N * " << tensor_data_size << ", in fact: " << size
                       << " ,file_name: " << file_name.c_str();
         delete bin_buf;
         return RET_ERROR;
       }
-      auto input_data = cur_tensor.MutableData();
-      memcpy(input_data, bin_buf, tensor_data_size);
-      delete[](bin_buf);
+      inputs_buf_.emplace_back(bin_buf);
+      inputs_size_.emplace_back(size);
+      batch_num_ = size / tensor_data_size;
     }
   }
   return RET_OK;
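The new size check accepts any file holding a whole number of batches and requires every input file to imply the same batch count (note that the added MS_ASSERT references tensor_byte_size, a name that does not exist in this function; tensor_data_size is presumably intended). Reduced to a standalone predicate, the rule looks roughly like this (BatchCountFor is a hypothetical sketch):

#include <cstddef>

// Returns the batch count implied by a file, or 0 if the size is invalid:
// the file must be a non-empty multiple of the tensor size and must agree
// with the count derived from earlier input files (0 = none derived yet).
std::size_t BatchCountFor(std::size_t file_size, std::size_t tensor_data_size,
                          std::size_t known_batch_num) {
  if (file_size == 0 || tensor_data_size == 0 || file_size % tensor_data_size != 0) {
    return 0;
  }
  std::size_t n = file_size / tensor_data_size;
  return (known_batch_num == 0 || n == known_batch_num) ? n : 0;
}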
@@ -287,11 +302,18 @@ int NetTrain::MarkPerformance() {
   for (int i = 0; i < flags_->epochs_; i++) {
     auto start = GetTimeUs();
-    auto status = ms_model_.RunStep(before_call_back_, after_call_back_);
-    if (status != mindspore::kSuccess) {
-      MS_LOG(ERROR) << "Inference error " << status;
-      std::cerr << "Inference error " << status;
-      return RET_ERROR;
+    for (size_t step = 0; step < batch_num_; step++) {
+      MS_LOG(INFO) << "Run for epoch:" << i << " step:" << step;
+      auto ret = LoadStepInput(step);
+      if (ret != RET_OK) {
+        return ret;
+      }
+      auto status = ms_model_.RunStep(before_call_back_, after_call_back_);
+      if (status != mindspore::kSuccess) {
+        MS_LOG(ERROR) << "Inference error " << status;
+        std::cerr << "Inference error " << status;
+        return RET_ERROR;
+      }
     }
 
     auto end = GetTimeUs();
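Worth noting: the start/end timestamps still bracket the whole epoch, so each measured epoch now spans batch_num_ RunStep calls (plus the LoadStepInput copies); per-step latency is the epoch time divided by batch_num_.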
@@ -322,6 +344,10 @@ int NetTrain::MarkPerformance() {
 int NetTrain::MarkAccuracy(bool enforce_accuracy) {
   MS_LOG(INFO) << "MarkAccuracy";
+  auto load_ret = LoadStepInput(0);
+  if (load_ret != RET_OK) {
+    return load_ret;
+  }
   for (auto &msInput : ms_model_.GetInputs()) {
     switch (msInput.DataType()) {
       case mindspore::DataType::kNumberTypeFloat32:

tools/benchmark_train/net_train.h

@@ -220,6 +220,8 @@ class MS_API NetTrain {
   int ReadInputFile();
 
+  int LoadStepInput(size_t step);
+
   void InitMSContext(const std::shared_ptr<Context> &context);
   void InitTrainCfg(const std::shared_ptr<TrainCfg> &train_cfg);
@@ -301,6 +303,9 @@ class MS_API NetTrain {
   mindspore::MSKernelCallBack after_call_back_{nullptr};
   nlohmann::json dump_cfg_json_;
   std::string dump_file_output_dir_;
+  std::vector<std::shared_ptr<char>> inputs_buf_;
+  std::vector<size_t> inputs_size_;
+  size_t batch_num_ = 0;
 };
 
 int MS_API RunNetTrain(int argc, const char **argv);
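Taken together, the three new members mirror the runtime flow: inputs_buf_ owns one whole multi-batch buffer per model input, inputs_size_ records each buffer's byte size, and batch_num_ is the per-epoch step count consumed by MarkPerformance and MarkAccuracy.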