forked from mindspore-Ecosystem/mindspore
!31544 [MSLite][OnDeviceTraining] Fix random initization for labels of TOD
Merge pull request !31544 from lz/tod_fix
This commit is contained in:
commit
9ef9d646ac
|
@ -85,11 +85,16 @@ float *ReadFileBuf(const char *file, size_t *size) {
|
|||
}
|
||||
} // namespace
|
||||
|
||||
int NetTrain::GenerateRandomData(size_t size, void *data) {
|
||||
MS_ASSERT(data != nullptr);
|
||||
char *casted_data = static_cast<char *>(data);
|
||||
for (size_t i = 0; i < size; i++) {
|
||||
casted_data[i] = static_cast<char>(i);
|
||||
int NetTrain::GenerateRandomData(mindspore::tensor::MSTensor *tensor) {
|
||||
auto input_data = tensor->MutableData();
|
||||
if (input_data == nullptr) {
|
||||
MS_LOG(ERROR) << "MallocData for inTensor failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto tensor_byte_size = tensor->Size();
|
||||
char *casted_data = static_cast<char *>(input_data);
|
||||
for (size_t i = 0; i < tensor_byte_size; i++) {
|
||||
casted_data[i] = (tensor->data_type() == kNumberTypeFloat32) ? static_cast<char>(i) : static_cast<char>(0);
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
@ -97,13 +102,7 @@ int NetTrain::GenerateRandomData(size_t size, void *data) {
|
|||
int NetTrain::GenerateInputData(std::vector<mindspore::tensor::MSTensor *> *ms_inputs) {
|
||||
for (auto tensor : *ms_inputs) {
|
||||
MS_ASSERT(tensor != nullptr);
|
||||
auto input_data = tensor->MutableData();
|
||||
if (input_data == nullptr) {
|
||||
MS_LOG(ERROR) << "MallocData for inTensor failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto tensor_byte_size = tensor->Size();
|
||||
auto status = GenerateRandomData(tensor_byte_size, input_data);
|
||||
auto status = GenerateRandomData(tensor);
|
||||
if (status != RET_OK) {
|
||||
std::cerr << "GenerateRandomData for inTensor failed: " << status << std::endl;
|
||||
MS_LOG(ERROR) << "GenerateRandomData for inTensor failed: " << status;
|
||||
|
|
|
@ -134,7 +134,7 @@ class MS_API NetTrain {
|
|||
// call GenerateRandomData to fill inputTensors
|
||||
int GenerateInputData(std::vector<mindspore::tensor::MSTensor *> *ms_inputs);
|
||||
|
||||
int GenerateRandomData(size_t size, void *data);
|
||||
int GenerateRandomData(mindspore::tensor::MSTensor *tensor);
|
||||
|
||||
int ReadInputFile(std::vector<mindspore::tensor::MSTensor *> *ms_inputs);
|
||||
int CreateAndRunNetwork(const std::string &filename, const std::string &bb_filename, int train_session, int epochs,
|
||||
|
|
Loading…
Reference in New Issue