!26742 统一运行时场景支持输入数据shape动态时shape推导

Merge pull request !26742 from chengbin/master
This commit is contained in:
i-robot 2021-12-02 10:59:31 +00:00 committed by Gitee
commit 00c8e9a964
2 changed files with 31 additions and 0 deletions

View File

@ -14,6 +14,8 @@
* limitations under the License.
*/
#include <algorithm>
#include "runtime/framework/actor/data_prepare_actor.h"
#include "runtime/framework/actor/memory_manager_actor.h"
#include "runtime/framework/actor/kernel_actor.h"
@ -120,6 +122,28 @@ void DataPrepareActor::Init() {
}
}
void DataPrepareActor::UpdateDynamicShape(const AnfNodePtr &input_node, const TensorPtr &input_tensor) {
MS_EXCEPTION_IF_NULL(input_node);
if (input_tensor == nullptr) {
return;
}
if (!input_node->isa<Parameter>()) {
return;
}
auto input_param = input_node->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(input_param);
if (!input_param->has_dynamic_shape()) {
return;
}
auto shape = input_tensor->shape();
std::vector<size_t> shape_tmp;
std::transform(shape.begin(), shape.end(), std::back_inserter(shape_tmp), IntToSize);
AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(input_node, 0)}, {shape_tmp}, input_node.get());
}
void DataPrepareActor::PrepareData(const std::vector<std::vector<TensorPtr>> &input_tensors,
OpContext<DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(context);
@ -229,6 +253,9 @@ void DataPrepareActor::PrepareDataForHostTensorQueue(const std::vector<std::vect
if (!IsHostQueueDSActor(input_node, graph, graph_compiler_info_->origin_parameters_order_, strategy_)) {
continue;
}
UpdateDynamicShape(input_node, input_tensor);
auto tensor_position = host_data_source_actor_->FetchNodePosition(input_node);
if (tensor_position >= host_tensors.size()) {
std::string error_info = "The position of tensor is out of range: " + std::to_string(tensor_position);
@ -277,6 +304,8 @@ void DataPrepareActor::PrepareDataForStepMode(const std::vector<std::vector<Tens
continue;
}
UpdateDynamicShape(input_node, input_tensor);
if ((host_data_source_actor_ != nullptr) && (host_tensor_queue_ != nullptr)) {
auto tensor_position = host_data_source_actor_->FetchNodePosition(input_node);
if (tensor_position >= host_tensors.size()) {

View File

@ -68,6 +68,8 @@ class DataPrepareActor : public DebugAwareActor {
private:
friend class GraphScheduler;
void UpdateDynamicShape(const AnfNodePtr &input_node, const TensorPtr &input_tensor);
void PrepareDataForDeviceTensorStore(const std::vector<std::vector<TensorPtr>> &input_tensors,
OpContext<DeviceTensor> *const context);
void PrepareDataForHostTensorQueue(const std::vector<std::vector<TensorPtr>> &input_tensors,