forked from mindspore-Ecosystem/mindspore
graph input dynamic
This commit is contained in:
parent
9bd12517fb
commit
2e6afc07ac
|
@ -79,6 +79,7 @@
|
|||
#include "utils/ms_context.h"
|
||||
#include "utils/context/graph_kernel_flags.h"
|
||||
#include "utils/utils.h"
|
||||
#include "abstract/utils.h"
|
||||
#if ENABLE_CPU && ENABLE_GPU
|
||||
#include "ps/util.h"
|
||||
#include "ps/ps_cache/ps_cache_manager.h"
|
||||
|
@ -269,6 +270,19 @@ bool UpdatedByAssign(const KernelGraphPtr &kernel_graph, const AnfNodePtr &node)
|
|||
}
|
||||
} // namespace
|
||||
|
||||
// Computes how many bytes of `tensor` should be synced to the device for the graph input `input_node`.
//
// By default the result is the byte size reported by the tensor's data buffer. When `input_node` is a
// Parameter flagged as used by a dynamic kernel, the node's inferred output shape is overwritten with the
// tensor's current shape (the inferred data type of output 0 is kept), and the result is recomputed as
// element-count * element-size from that shape and the tensor's data type.
//
// Params:
//   input_node - graph input node; must not be null.
//   tensor     - host tensor bound to that input; must not be null.
// Returns: the number of bytes to copy for this input.
// Raises: via MS_EXCEPTION_IF_NULL when either argument is null.
size_t GPUSession::UpdateGraphInputAbstract(AnfNodePtr input_node, tensor::TensorPtr tensor) {
  // Both arguments are dereferenced below; fail with a diagnosable exception instead of a null dereference.
  MS_EXCEPTION_IF_NULL(input_node);
  MS_EXCEPTION_IF_NULL(tensor);

  size_t size = LongToSize(tensor->data().nbytes());
  if (!input_node->isa<Parameter>()) {
    return size;
  }
  auto param = input_node->cast<ParameterPtr>();
  MS_EXCEPTION_IF_NULL(param);
  if (!param->is_used_by_dynamic_kernel()) {
    return size;
  }

  // Refresh the parameter's inferred shape from the tensor actually fed at run time.
  const auto &tensor_shape = tensor->shape();
  std::vector<size_t> shape_tmp;
  shape_tmp.reserve(tensor_shape.size());
  // NOTE(review): IntToSize is kept from the original; confirm it matches the element type of tensor->shape().
  (void)std::transform(tensor_shape.begin(), tensor_shape.end(), std::back_inserter(shape_tmp), IntToSize);
  AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(input_node, 0)}, {shape_tmp},
                                      input_node.get());
  size = abstract::ShapeSize(shape_tmp) * abstract::TypeIdSize(tensor->data_type());
  return size;
}
|
||||
|
||||
void GPUSession::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
|
||||
const std::vector<tensor::TensorPtr> &inputs_const) const {
|
||||
std::vector<tensor::TensorPtr> inputs(inputs_const);
|
||||
|
@ -314,8 +328,8 @@ void GPUSession::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
|
|||
tensor->set_device_address(device_address);
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(device_address);
|
||||
if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(pk_node, 0),
|
||||
LongToSize(tensor->data().nbytes()), tensor->data_type(),
|
||||
auto size = UpdateGraphInputAbstract(input_node, tensor);
|
||||
if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(pk_node, 0), size, tensor->data_type(),
|
||||
tensor->data_c())) {
|
||||
MS_LOG(EXCEPTION) << "SyncHostToDevice failed.";
|
||||
}
|
||||
|
|
|
@ -89,6 +89,8 @@ class GPUSession : public SessionBasic {
|
|||
bool DumpDataEnabledIteration() const;
|
||||
|
||||
GraphId CompileGraphImpl(KernelGraphPtr kernel_graph);
|
||||
|
||||
size_t UpdateGraphInputAbstract(AnfNodePtr input_node, tensor::TensorPtr tensor);
|
||||
};
|
||||
using GPUSessionPtr = std::shared_ptr<GPUSession>;
|
||||
MS_REG_SESSION(kGPUDevice, GPUSession);
|
||||
|
|
Loading…
Reference in New Issue