Merge pull request !31484 from caifubi/master-pynative-codex
This commit is contained in:
i-robot 2022-03-21 02:42:56 +00:00 committed by Gitee
commit 2486b17a29
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 16 additions and 13 deletions

View File

@ -17,6 +17,7 @@
#ifndef MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_ASCEND_LAUNCH_TRANSDATA_H_
#define MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_ASCEND_LAUNCH_TRANSDATA_H_
#include <utility>
#include <vector>
#include <memory>
#include <string>
@ -32,9 +33,9 @@ class AscendLaunchTransData : public AscendLaunchKernel {
total_size_(total_size),
transdata_graph_(nullptr),
input_addr_(nullptr),
src_format_(src_format),
dst_format_(dst_format),
shape_(host_shape) {}
src_format_(std::move(src_format)),
dst_format_(std::move(dst_format)),
shape_(std::move(host_shape)) {}
~AscendLaunchTransData() override = default;

View File

@ -26,6 +26,7 @@
#include "runtime/device/ms_device_shape_transfer.h"
#include "runtime/pynative/op_runtime_info.h"
#include "runtime/pynative/op_executor.h"
#include "runtime/graph_scheduler/actor/actor_common.h"
namespace mindspore::runtime {
namespace {
@ -100,6 +101,7 @@ void UpdateInputNodeDeviceAddress(const std::vector<AnfNodePtr> &input_nodes,
input_tensor->set_sync_status(kNeedSyncHostToDeviceImmediately);
input_tensor->set_lazy_callback([]() { runtime::OpExecutor::GetInstance().Wait(); });
node_address->set_from_persistent_mem(input_tensor->is_parameter());
UpdateRefCount(node_address.get(), true);
}
// The DeviceType and format of DeviceAddress is always the same after UpdateInputTensor
@ -163,12 +165,13 @@ void CopyValueNodeTensorToDevice(const ValueNodePtr &node, const device::DeviceC
const auto &tensor = tensors[i];
MS_EXCEPTION_IF_NULL(tensor);
const auto &device_tensor = AnfAlgo::GetMutableOutputAddr(node, i, false);
MS_EXCEPTION_IF_NULL(device_tensor);
if (device_tensor->GetPtr() != nullptr) {
const auto &node_address = AnfAlgo::GetMutableOutputAddr(node, i, false);
MS_EXCEPTION_IF_NULL(node_address);
if (node_address->GetPtr() != nullptr) {
return;
}
tensor->set_device_address(device_tensor);
tensor->set_device_address(node_address);
UpdateRefCount(node_address.get(), true);
CopyTensorDataToDevice(tensor, node, device_context);
}
}
@ -176,14 +179,13 @@ void CopyValueNodeTensorToDevice(const ValueNodePtr &node, const device::DeviceC
void CopyValueNodeStringToDevice(const ValueNodePtr &node, const device::DeviceContext *device_context) {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(device_context);
const auto &device_tensor = AnfAlgo::GetMutableOutputAddr(node, 0, false);
MS_EXCEPTION_IF_NULL(device_tensor);
// If the ptr of device tensor is not nullptr, it indicates that the device data has been prepared.
if (device_tensor->GetPtr() != nullptr) {
const auto &node_address = AnfAlgo::GetMutableOutputAddr(node, 0, false);
MS_EXCEPTION_IF_NULL(node_address);
if (node_address->GetPtr() != nullptr) {
return;
}
if (!device_context->AllocateMemory(device_tensor.get(), device_tensor->GetSize())) {
if (!device_context->AllocateMemory(node_address.get(), node_address->GetSize())) {
MS_LOG(EXCEPTION) << "Allocate memory failed";
}
@ -193,7 +195,7 @@ void CopyValueNodeStringToDevice(const ValueNodePtr &node, const device::DeviceC
auto value = GetValue<std::string>(node_value);
size_t tensor_size = value.size();
ShapeVector shape = {1, SizeToLong(tensor_size)};
if (!device_tensor->SyncHostToDevice(shape, tensor_size, kNumberTypeUInt8, value.data())) {
if (!node_address->SyncHostToDevice(shape, tensor_size, kNumberTypeUInt8, value.data())) {
MS_LOG(EXCEPTION) << "SyncHostToDevice failed";
}
}