forked from mindspore-Ecosystem/mindspore
!31484 clean codex
Merge pull request !31484 from caifubi/master-pynative-codex
This commit is contained in:
commit
2486b17a29
|
@ -17,6 +17,7 @@
|
|||
#ifndef MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_ASCEND_LAUNCH_TRANSDATA_H_
|
||||
#define MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_ASCEND_LAUNCH_TRANSDATA_H_
|
||||
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
@ -32,9 +33,9 @@ class AscendLaunchTransData : public AscendLaunchKernel {
|
|||
total_size_(total_size),
|
||||
transdata_graph_(nullptr),
|
||||
input_addr_(nullptr),
|
||||
src_format_(src_format),
|
||||
dst_format_(dst_format),
|
||||
shape_(host_shape) {}
|
||||
src_format_(std::move(src_format)),
|
||||
dst_format_(std::move(dst_format)),
|
||||
shape_(std::move(host_shape)) {}
|
||||
|
||||
~AscendLaunchTransData() override = default;
|
||||
|
||||
|
|
|
@ -26,6 +26,7 @@
|
|||
#include "runtime/device/ms_device_shape_transfer.h"
|
||||
#include "runtime/pynative/op_runtime_info.h"
|
||||
#include "runtime/pynative/op_executor.h"
|
||||
#include "runtime/graph_scheduler/actor/actor_common.h"
|
||||
|
||||
namespace mindspore::runtime {
|
||||
namespace {
|
||||
|
@ -100,6 +101,7 @@ void UpdateInputNodeDeviceAddress(const std::vector<AnfNodePtr> &input_nodes,
|
|||
input_tensor->set_sync_status(kNeedSyncHostToDeviceImmediately);
|
||||
input_tensor->set_lazy_callback([]() { runtime::OpExecutor::GetInstance().Wait(); });
|
||||
node_address->set_from_persistent_mem(input_tensor->is_parameter());
|
||||
UpdateRefCount(node_address.get(), true);
|
||||
}
|
||||
|
||||
// The DeviceType and format of DeviceAddress is always the same after UpdateInputTensor
|
||||
|
@ -163,12 +165,13 @@ void CopyValueNodeTensorToDevice(const ValueNodePtr &node, const device::DeviceC
|
|||
const auto &tensor = tensors[i];
|
||||
MS_EXCEPTION_IF_NULL(tensor);
|
||||
|
||||
const auto &device_tensor = AnfAlgo::GetMutableOutputAddr(node, i, false);
|
||||
MS_EXCEPTION_IF_NULL(device_tensor);
|
||||
if (device_tensor->GetPtr() != nullptr) {
|
||||
const auto &node_address = AnfAlgo::GetMutableOutputAddr(node, i, false);
|
||||
MS_EXCEPTION_IF_NULL(node_address);
|
||||
if (node_address->GetPtr() != nullptr) {
|
||||
return;
|
||||
}
|
||||
tensor->set_device_address(device_tensor);
|
||||
tensor->set_device_address(node_address);
|
||||
UpdateRefCount(node_address.get(), true);
|
||||
CopyTensorDataToDevice(tensor, node, device_context);
|
||||
}
|
||||
}
|
||||
|
@ -176,14 +179,13 @@ void CopyValueNodeTensorToDevice(const ValueNodePtr &node, const device::DeviceC
|
|||
void CopyValueNodeStringToDevice(const ValueNodePtr &node, const device::DeviceContext *device_context) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
MS_EXCEPTION_IF_NULL(device_context);
|
||||
const auto &device_tensor = AnfAlgo::GetMutableOutputAddr(node, 0, false);
|
||||
MS_EXCEPTION_IF_NULL(device_tensor);
|
||||
// If the ptr of device tensor is not nullptr, it indicates that the device data has been prepared.
|
||||
if (device_tensor->GetPtr() != nullptr) {
|
||||
const auto &node_address = AnfAlgo::GetMutableOutputAddr(node, 0, false);
|
||||
MS_EXCEPTION_IF_NULL(node_address);
|
||||
if (node_address->GetPtr() != nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (!device_context->AllocateMemory(device_tensor.get(), device_tensor->GetSize())) {
|
||||
if (!device_context->AllocateMemory(node_address.get(), node_address->GetSize())) {
|
||||
MS_LOG(EXCEPTION) << "Allocate memory failed";
|
||||
}
|
||||
|
||||
|
@ -193,7 +195,7 @@ void CopyValueNodeStringToDevice(const ValueNodePtr &node, const device::DeviceC
|
|||
auto value = GetValue<std::string>(node_value);
|
||||
size_t tensor_size = value.size();
|
||||
ShapeVector shape = {1, SizeToLong(tensor_size)};
|
||||
if (!device_tensor->SyncHostToDevice(shape, tensor_size, kNumberTypeUInt8, value.data())) {
|
||||
if (!node_address->SyncHostToDevice(shape, tensor_size, kNumberTypeUInt8, value.data())) {
|
||||
MS_LOG(EXCEPTION) << "SyncHostToDevice failed";
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue