!13087 fix bug for dynamic_batch_size
From: @zhupuxu Reviewed-by: @zhoufeng54,@jjfeing Signed-off-by: @jjfeing
This commit is contained in:
commit
afc4fe8326
|
@ -53,15 +53,18 @@ inline static void PushbackIfNotNull(U *vec, T &&item) {
|
|||
}
|
||||
|
||||
static void ConstructTensorDesc(const std::vector<AclTensorInfo> &acl_tensor_list, std::vector<std::string> *names,
|
||||
std::vector<std::vector<int64_t>> *shapes, std::vector<enum DataType> *data_types) {
|
||||
std::vector<std::vector<int64_t>> *shapes, std::vector<enum DataType> *data_types,
|
||||
std::vector<size_t> *mem_sizes) {
|
||||
ClearIfNotNull(names);
|
||||
ClearIfNotNull(shapes);
|
||||
ClearIfNotNull(data_types);
|
||||
ClearIfNotNull(mem_sizes);
|
||||
for (size_t i = 0; i < acl_tensor_list.size(); ++i) {
|
||||
const auto &info = acl_tensor_list[i];
|
||||
PushbackIfNotNull(names, info.name);
|
||||
PushbackIfNotNull(shapes, info.dims);
|
||||
PushbackIfNotNull(data_types, TransToApiType(info.data_type));
|
||||
PushbackIfNotNull(mem_sizes, info.buffer_size);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -78,17 +81,20 @@ static std::string ShapeToString(const std::vector<int64_t> &shape) {
|
|||
}
|
||||
|
||||
Status ModelProcess::ConstructTensors(const std::vector<AclTensorInfo> &acl_tensor_list,
|
||||
std::vector<MSTensor> *tensor_list, const std::vector<size_t> &mem_sizes) {
|
||||
std::vector<MSTensor> *tensor_list) {
|
||||
MS_EXCEPTION_IF_NULL(tensor_list);
|
||||
std::vector<std::string> names;
|
||||
std::vector<std::vector<int64_t>> shapes;
|
||||
std::vector<enum DataType> data_types;
|
||||
ConstructTensorDesc(acl_tensor_list, &names, &shapes, &data_types);
|
||||
std::vector<size_t> mem_sizes;
|
||||
|
||||
ConstructTensorDesc(acl_tensor_list, &names, &shapes, &data_types, &mem_sizes);
|
||||
tensor_list->clear();
|
||||
if (names.size() != acl_tensor_list.size() || shapes.size() != acl_tensor_list.size() ||
|
||||
data_types.size() != acl_tensor_list.size()) {
|
||||
data_types.size() != acl_tensor_list.size() || mem_sizes.size() != acl_tensor_list.size()) {
|
||||
MS_LOG(ERROR) << "Inner error, size do not match: names size " << names.size() << " shapes size " << shapes.size()
|
||||
<< " data types size " << data_types.size() << " acl_tensor_list size " << acl_tensor_list.size();
|
||||
<< " data types size " << data_types.size() << " mem sizes size " << mem_sizes.size()
|
||||
<< " acl_tensor_list size " << acl_tensor_list.size();
|
||||
return kMCFailed;
|
||||
}
|
||||
|
||||
|
@ -96,7 +102,7 @@ Status ModelProcess::ConstructTensors(const std::vector<AclTensorInfo> &acl_tens
|
|||
for (size_t i = 0; i < acl_tensor_list.size(); ++i) {
|
||||
tensor_list->emplace_back(names[i], data_types[i], shapes[i], nullptr, mem_sizes[i]);
|
||||
auto ret = aclrtMemcpy((*tensor_list)[i].MutableData(), (*tensor_list)[i].DataSize(),
|
||||
acl_tensor_list[i].device_data, mem_sizes[i], kind);
|
||||
acl_tensor_list[i].device_data, acl_tensor_list[i].buffer_size, kind);
|
||||
if (ret != ACL_ERROR_NONE) {
|
||||
MS_LOG(ERROR) << "Memcpy input " << i << " from " << (is_run_on_device_ ? "host" : "device")
|
||||
<< " to host failed, memory size " << acl_tensor_list[i].buffer_size;
|
||||
|
@ -159,7 +165,6 @@ Status ModelProcess::InitInputsBuffer() {
|
|||
}
|
||||
MS_LOG(INFO) << "Name of input " << i << " is " << input_name;
|
||||
input_infos_.emplace_back(AclTensorInfo{data_mem_buffer, buffer_size, data_type, shape, input_name});
|
||||
input_size_.push_back(buffer_size);
|
||||
}
|
||||
MS_LOG(INFO) << "Create model inputs success";
|
||||
return kSuccess;
|
||||
|
@ -328,9 +333,8 @@ size_t ModelProcess::GetDynamicDims(const std::vector<AclTensorInfo> &inputs) {
|
|||
Status ModelProcess::SetBatchSize(const std::vector<MSTensor> &inputs) {
|
||||
size_t index;
|
||||
aclError ret;
|
||||
input_size_.clear();
|
||||
for (auto input : inputs) {
|
||||
input_size_.push_back(input.DataSize());
|
||||
for (size_t i = 0; i < inputs.size(); i++) {
|
||||
input_infos_[i].buffer_size = inputs[i].DataSize();
|
||||
}
|
||||
auto *p = reinterpret_cast<const float *>(inputs[inputs.size() - 1].Data().get());
|
||||
MS_EXCEPTION_IF_NULL(p);
|
||||
|
@ -353,12 +357,12 @@ Status ModelProcess::CheckAndInitInput(const std::vector<MSTensor> &inputs) {
|
|||
inputs_ = aclmdlCreateDataset();
|
||||
size_t dynamic_nums = GetDynamicDims(input_infos_);
|
||||
// check inputs
|
||||
if (inputs.size() != input_infos_.size()) {
|
||||
MS_LOG(ERROR) << "Inputs count not match, required count " << input_infos_.size() << ", given count "
|
||||
<< inputs.size();
|
||||
return kMCInvalidInput;
|
||||
}
|
||||
if (dynamic_nums == 0) {
|
||||
if (inputs.size() != input_infos_.size()) {
|
||||
MS_LOG(ERROR) << "Inputs count not match, required count " << input_infos_.size() << ", given count "
|
||||
<< inputs.size();
|
||||
return kMCInvalidInput;
|
||||
}
|
||||
for (size_t i = 0; i < input_infos_.size(); ++i) {
|
||||
if (inputs[i].Shape() != input_infos_[i].dims) {
|
||||
MS_LOG(INFO) << "Note: input " << i << " shape not match, required " << ShapeToString(input_infos_[i].dims)
|
||||
|
@ -404,18 +408,18 @@ Status ModelProcess::CheckAndInitInput(const std::vector<MSTensor> &inputs) {
|
|||
}
|
||||
}
|
||||
if (dynamic_nums == 1) {
|
||||
if (SetBatchSize(inputs) == kMCDeviceError) {
|
||||
if (SetBatchSize(inputs) != kSuccess) {
|
||||
MS_LOG(ERROR) << "failed to convert dynamic batch size";
|
||||
return kMCDeviceError;
|
||||
}
|
||||
if (ResetOutputSize() != kSuccess) {
|
||||
MS_LOG(ERROR) << "reset output size failed";
|
||||
return kMCDeviceError;
|
||||
}
|
||||
} else if (dynamic_nums == 2) {
|
||||
MS_LOG(ERROR) << "only dynamic batch size is supported";
|
||||
return kMCInvalidInput;
|
||||
}
|
||||
if (ResetOutputSize() == kMCDeviceError) {
|
||||
MS_LOG(ERROR) << "reset output size failed";
|
||||
return kMCDeviceError;
|
||||
}
|
||||
return kSuccess;
|
||||
}
|
||||
|
||||
|
@ -435,7 +439,7 @@ Status ModelProcess::ResetOutputSize() {
|
|||
dims *= output_dims.dims[i];
|
||||
}
|
||||
output_type = aclmdlGetOutputDataType(model_desc_, index);
|
||||
output_size_.push_back(dims * aclDataTypeSize(output_type));
|
||||
output_infos_[index].buffer_size = dims * aclDataTypeSize(output_type);
|
||||
}
|
||||
return kSuccess;
|
||||
}
|
||||
|
@ -490,7 +494,7 @@ Status ModelProcess::BuildOutputs(std::vector<MSTensor> *outputs) {
|
|||
}
|
||||
|
||||
std::vector<MSTensor> ModelProcess::GetInputs() {
|
||||
Status ret = ConstructTensors(input_infos_, &input_tensors_, input_size_);
|
||||
Status ret = ConstructTensors(input_infos_, &input_tensors_);
|
||||
if (ret != kSuccess) {
|
||||
MS_LOG(ERROR) << "ConstructTensors failed.";
|
||||
input_tensors_.clear();
|
||||
|
@ -500,7 +504,7 @@ std::vector<MSTensor> ModelProcess::GetInputs() {
|
|||
}
|
||||
|
||||
std::vector<MSTensor> ModelProcess::GetOutputs() {
|
||||
Status ret = ConstructTensors(output_infos_, &output_tensors_, output_size_);
|
||||
Status ret = ConstructTensors(output_infos_, &output_tensors_);
|
||||
if (ret != kSuccess) {
|
||||
MS_LOG(ERROR) << "ConstructTensors failed.";
|
||||
output_tensors_.clear();
|
||||
|
|
|
@ -61,8 +61,7 @@ class ModelProcess {
|
|||
private:
|
||||
Status CreateDataBuffer(void **data_mem_buffer, size_t buffer_size, aclmdlDataset *dataset);
|
||||
Status CheckAndInitInput(const std::vector<MSTensor> &inputs);
|
||||
Status ConstructTensors(const std::vector<AclTensorInfo> &acl_tensor_list, std::vector<MSTensor> *tensor_list,
|
||||
const std::vector<size_t> &mem_sizes);
|
||||
Status ConstructTensors(const std::vector<AclTensorInfo> &acl_tensor_list, std::vector<MSTensor> *tensor_list);
|
||||
Status BuildOutputs(std::vector<MSTensor> *outputs);
|
||||
Status SetBatchSize(const std::vector<MSTensor> &inputs);
|
||||
Status InitInputsBuffer();
|
||||
|
@ -84,8 +83,6 @@ class ModelProcess {
|
|||
std::vector<AclTensorInfo> output_infos_;
|
||||
std::vector<MSTensor> input_tensors_;
|
||||
std::vector<MSTensor> output_tensors_;
|
||||
std::vector<size_t> output_size_;
|
||||
std::vector<size_t> input_size_;
|
||||
size_t GetDynamicDims(const std::vector<AclTensorInfo> &);
|
||||
};
|
||||
} // namespace mindspore
|
||||
|
|
Loading…
Reference in New Issue