!27700 Add host shape to DeviceAddress

Merge pull request !27700 from caifubi/master-pynative-mindrt-bugfix-host-shape
This commit is contained in:
i-robot 2021-12-16 11:51:58 +00:00 committed by Gitee
commit 7caed8835f
3 changed files with 3 additions and 0 deletions

View File

@ -101,6 +101,7 @@ class DeviceAddress : public mindspore::DeviceSync {
bool is_ptr_persisted() const { return is_ptr_persisted_; }
void set_is_ptr_persisted(bool is_ptr_persisted) { is_ptr_persisted_ = is_ptr_persisted; }
void set_host_shape(const ShapeVector &shape) { host_shape_ = shape; }
ShapeVector host_shape() const { return host_shape_; }
bool from_persistent_mem() const { return from_persistent_mem_; }
void set_from_persistent_mem(bool from_persistent_mem) { from_persistent_mem_ = from_persistent_mem; }
virtual void set_status(DeviceAddressStatus status) {}

View File

@ -195,6 +195,7 @@ void CreateKernelOutputDeviceAddress(const DeviceContext *device_context, const
auto output_type = AnfAlgo::GetOutputDeviceDataType(kernel, i);
auto address_size = AnfAlgo::GetOutputTensorMemSize(kernel, i);
auto device_address = device_context->CreateDeviceAddress(nullptr, address_size, output_format, output_type);
device_address->set_host_shape(trans::GetRuntimePaddingShape(kernel, i));
if (is_gradient_out) {
device_address->set_from_persistent_mem(true);
}

View File

@ -250,6 +250,7 @@ void ClearGraphDeviceAddress(const KernelGraphPtr &graph, const DeviceContext *d
auto new_device_address = device_context->CreateDeviceAddress(
nullptr, device_address->GetSize(), device_address->format(), device_address->type_id());
MS_EXCEPTION_IF_NULL(new_device_address);
new_device_address->set_host_shape(device_address->host_shape());
new_device_address->set_original_ref_count(device_address->original_ref_count());
new_device_address->ResetRefCount();
if (is_gradient_out) {