!31544 [MSLite][OnDeviceTraining] Fix random initization for labels of TOD

Merge pull request !31544 from lz/tod_fix
This commit is contained in:
i-robot 2022-03-21 03:07:02 +00:00 committed by Gitee
commit 9ef9d646ac
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 12 additions and 13 deletions

View File

@ -85,11 +85,16 @@ float *ReadFileBuf(const char *file, size_t *size) {
}
} // namespace
int NetTrain::GenerateRandomData(size_t size, void *data) {
MS_ASSERT(data != nullptr);
char *casted_data = static_cast<char *>(data);
for (size_t i = 0; i < size; i++) {
casted_data[i] = static_cast<char>(i);
int NetTrain::GenerateRandomData(mindspore::tensor::MSTensor *tensor) {
auto input_data = tensor->MutableData();
if (input_data == nullptr) {
MS_LOG(ERROR) << "MallocData for inTensor failed";
return RET_ERROR;
}
auto tensor_byte_size = tensor->Size();
char *casted_data = static_cast<char *>(input_data);
for (size_t i = 0; i < tensor_byte_size; i++) {
casted_data[i] = (tensor->data_type() == kNumberTypeFloat32) ? static_cast<char>(i) : static_cast<char>(0);
}
return RET_OK;
}
@ -97,13 +102,7 @@ int NetTrain::GenerateRandomData(size_t size, void *data) {
int NetTrain::GenerateInputData(std::vector<mindspore::tensor::MSTensor *> *ms_inputs) {
for (auto tensor : *ms_inputs) {
MS_ASSERT(tensor != nullptr);
auto input_data = tensor->MutableData();
if (input_data == nullptr) {
MS_LOG(ERROR) << "MallocData for inTensor failed";
return RET_ERROR;
}
auto tensor_byte_size = tensor->Size();
auto status = GenerateRandomData(tensor_byte_size, input_data);
auto status = GenerateRandomData(tensor);
if (status != RET_OK) {
std::cerr << "GenerateRandomData for inTensor failed: " << status << std::endl;
MS_LOG(ERROR) << "GenerateRandomData for inTensor failed: " << status;

View File

@ -134,7 +134,7 @@ class MS_API NetTrain {
// call GenerateRandomData to fill inputTensors
int GenerateInputData(std::vector<mindspore::tensor::MSTensor *> *ms_inputs);
int GenerateRandomData(size_t size, void *data);
int GenerateRandomData(mindspore::tensor::MSTensor *tensor);
int ReadInputFile(std::vector<mindspore::tensor::MSTensor *> *ms_inputs);
int CreateAndRunNetwork(const std::string &filename, const std::string &bb_filename, int train_session, int epochs,