graph input dynamic

This commit is contained in:
wilfChen 2021-06-09 16:30:09 +08:00
parent 9bd12517fb
commit 2e6afc07ac
2 changed files with 18 additions and 2 deletions

View File

@ -79,6 +79,7 @@
#include "utils/ms_context.h"
#include "utils/context/graph_kernel_flags.h"
#include "utils/utils.h"
#include "abstract/utils.h"
#if ENABLE_CPU && ENABLE_GPU
#include "ps/util.h"
#include "ps/ps_cache/ps_cache_manager.h"
@ -269,6 +270,19 @@ bool UpdatedByAssign(const KernelGraphPtr &kernel_graph, const AnfNodePtr &node)
}
} // namespace
// Computes how many bytes of `tensor` should be synced to the device for the graph input
// `input_node`, refreshing the node's inferred shape first when the input feeds a dynamic kernel.
//
// input_node: graph input node; must not be null.
// tensor:     host tensor bound to that input; must not be null.
//
// Returns tensor->data().nbytes() converted with LongToSize by default. When `input_node` is a
// Parameter whose is_used_by_dynamic_kernel() is true, the node's inferred shape is overwritten
// with the tensor's shape (the inferred data type of output 0 is kept) and the returned size is
// abstract::ShapeSize(shape) * abstract::TypeIdSize(tensor->data_type()).
//
// Null arguments are rejected up front with MS_EXCEPTION_IF_NULL, matching the checks used
// elsewhere in this file, instead of being dereferenced.
//
// NOTE(review): the visible LoadInputData is a const member and calls this function, which is
// declared non-const — confirm the intended constness of the declaration.
size_t GPUSession::UpdateGraphInputAbstract(AnfNodePtr input_node, tensor::TensorPtr tensor) {
  MS_EXCEPTION_IF_NULL(input_node);
  MS_EXCEPTION_IF_NULL(tensor);
  size_t size = LongToSize(tensor->data().nbytes());
  if (!input_node->isa<Parameter>()) {
    return size;
  }
  auto param = input_node->cast<ParameterPtr>();
  MS_EXCEPTION_IF_NULL(param);
  if (!param->is_used_by_dynamic_kernel()) {
    return size;
  }
  // Dynamic input: take the shape from the tensor actually fed in this step.
  auto tensor_shape = tensor->shape();
  std::vector<size_t> shape_tmp;
  shape_tmp.reserve(tensor_shape.size());
  (void)std::transform(tensor_shape.begin(), tensor_shape.end(), std::back_inserter(shape_tmp), IntToSize);
  AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(input_node, 0)}, {shape_tmp},
                                      input_node.get());
  // Recompute the byte count from the refreshed shape and the tensor's data type.
  size = abstract::ShapeSize(shape_tmp) * abstract::TypeIdSize(tensor->data_type());
  return size;
}
void GPUSession::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
const std::vector<tensor::TensorPtr> &inputs_const) const {
std::vector<tensor::TensorPtr> inputs(inputs_const);
@ -314,8 +328,8 @@ void GPUSession::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
tensor->set_device_address(device_address);
}
MS_EXCEPTION_IF_NULL(device_address);
if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(pk_node, 0),
LongToSize(tensor->data().nbytes()), tensor->data_type(),
auto size = UpdateGraphInputAbstract(input_node, tensor);
if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(pk_node, 0), size, tensor->data_type(),
tensor->data_c())) {
MS_LOG(EXCEPTION) << "SyncHostToDevice failed.";
}

View File

@ -89,6 +89,8 @@ class GPUSession : public SessionBasic {
bool DumpDataEnabledIteration() const;
GraphId CompileGraphImpl(KernelGraphPtr kernel_graph);
// Returns the number of bytes of `tensor` to sync to the device for `input_node`; for a Parameter
// used by a dynamic kernel it first overwrites the node's inferred shape with the tensor's shape.
size_t UpdateGraphInputAbstract(AnfNodePtr input_node, tensor::TensorPtr tensor);
};
// Shared-pointer alias for GPUSession.
using GPUSessionPtr = std::shared_ptr<GPUSession>;
// NOTE(review): presumably registers GPUSession as the session implementation for kGPUDevice —
// confirm against the MS_REG_SESSION definition.
MS_REG_SESSION(kGPUDevice, GPUSession);