!26742 Unified runtime: support shape inference when the input data shape is dynamic
Merge pull request !26742 from chengbin/master
commit 00c8e9a964
@@ -14,6 +14,8 @@
  * limitations under the License.
  */
 
+#include <algorithm>
+
 #include "runtime/framework/actor/data_prepare_actor.h"
 #include "runtime/framework/actor/memory_manager_actor.h"
 #include "runtime/framework/actor/kernel_actor.h"
@@ -120,6 +122,28 @@ void DataPrepareActor::Init() {
   }
 }
 
+void DataPrepareActor::UpdateDynamicShape(const AnfNodePtr &input_node, const TensorPtr &input_tensor) {
+  MS_EXCEPTION_IF_NULL(input_node);
+  if (input_tensor == nullptr) {
+    return;
+  }
+
+  if (!input_node->isa<Parameter>()) {
+    return;
+  }
+
+  auto input_param = input_node->cast<ParameterPtr>();
+  MS_EXCEPTION_IF_NULL(input_param);
+  if (!input_param->has_dynamic_shape()) {
+    return;
+  }
+
+  auto shape = input_tensor->shape();
+  std::vector<size_t> shape_tmp;
+  std::transform(shape.begin(), shape.end(), std::back_inserter(shape_tmp), IntToSize);
+  AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(input_node, 0)}, {shape_tmp}, input_node.get());
+}
+
 void DataPrepareActor::PrepareData(const std::vector<std::vector<TensorPtr>> &input_tensors,
                                    OpContext<DeviceTensor> *const context) {
   MS_EXCEPTION_IF_NULL(context);
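For reference, below is a minimal, self-contained sketch of the shape-conversion step that the new UpdateDynamicShape relies on: the concrete runtime shape carried by the input tensor (signed ints) is converted element-wise into the size_t vector expected by the node's inferred shape. The IntToSize helper and the shape values here are illustrative stand-ins, not the MindSpore implementations.

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <iterator>
#include <vector>

namespace {
// Local stand-in for MindSpore's IntToSize conversion helper.
size_t IntToSize(int v) { return v < 0 ? 0 : static_cast<size_t>(v); }
}  // namespace

int main() {
  // Concrete shape delivered with the input tensor at runtime, e.g. batch resolved to 32.
  std::vector<int> tensor_shape = {32, 3, 224, 224};

  // Same conversion idiom as the commit: signed shape -> size_t shape.
  std::vector<size_t> shape_tmp;
  std::transform(tensor_shape.begin(), tensor_shape.end(), std::back_inserter(shape_tmp), IntToSize);

  for (size_t dim : shape_tmp) {
    std::cout << dim << ' ';
  }
  std::cout << '\n';  // prints: 32 3 224 224
  return 0;
}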
@@ -229,6 +253,9 @@ void DataPrepareActor::PrepareDataForHostTensorQueue(const std::vector<std::vector<TensorPtr>> &input_tensors,
     if (!IsHostQueueDSActor(input_node, graph, graph_compiler_info_->origin_parameters_order_, strategy_)) {
       continue;
     }
+
+    UpdateDynamicShape(input_node, input_tensor);
+
     auto tensor_position = host_data_source_actor_->FetchNodePosition(input_node);
     if (tensor_position >= host_tensors.size()) {
       std::string error_info = "The position of tensor is out of range: " + std::to_string(tensor_position);
@@ -277,6 +304,8 @@ void DataPrepareActor::PrepareDataForStepMode(const std::vector<std::vector<TensorPtr>> &input_tensors,
       continue;
     }
 
+    UpdateDynamicShape(input_node, input_tensor);
+
     if ((host_data_source_actor_ != nullptr) && (host_tensor_queue_ != nullptr)) {
       auto tensor_position = host_data_source_actor_->FetchNodePosition(input_node);
       if (tensor_position >= host_tensors.size()) {
@@ -68,6 +68,8 @@ class DataPrepareActor : public DebugAwareActor {
  private:
  friend class GraphScheduler;
 
+  void UpdateDynamicShape(const AnfNodePtr &input_node, const TensorPtr &input_tensor);
+
  void PrepareDataForDeviceTensorStore(const std::vector<std::vector<TensorPtr>> &input_tensors,
                                       OpContext<DeviceTensor> *const context);
  void PrepareDataForHostTensorQueue(const std::vector<std::vector<TensorPtr>> &input_tensors,
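Taken together, the two new call sites mean that before each run every graph input that is a Parameter marked as dynamic-shape gets its inferred shape refreshed from the tensor actually fed for that step. A simplified mock of that behaviour is sketched below; MockParameter, MockTensor, and UpdateDynamicShapeSketch are illustrative stand-ins, not the MindSpore classes or the committed function.

#include <iostream>
#include <vector>

// Illustrative stand-ins for the graph input node and the fed tensor.
struct MockTensor {
  std::vector<int> shape;
};

struct MockParameter {
  bool has_dynamic_shape = false;
  std::vector<int> inferred_shape;  // -1 marks a dimension unknown at compile time
};

// Mirrors the guard order of the new function: skip null tensors and
// non-dynamic parameters, otherwise overwrite the inferred shape with the
// concrete runtime shape of the fed tensor.
void UpdateDynamicShapeSketch(MockParameter *param, const MockTensor *tensor) {
  if (param == nullptr || tensor == nullptr || !param->has_dynamic_shape) {
    return;
  }
  param->inferred_shape = tensor->shape;
}

int main() {
  MockParameter param{true, {-1, 3, 224, 224}};  // batch dimension unknown at compile time
  MockTensor step_input{{32, 3, 224, 224}};      // shape of the tensor fed for this step
  UpdateDynamicShapeSketch(&param, &step_input);
  for (int d : param.inferred_shape) {
    std::cout << d << ' ';  // prints: 32 3 224 224
  }
  std::cout << '\n';
  return 0;
}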