forked from mindspore-Ecosystem/mindspore
!27700 Add host shape to DeviceAddress
Merge pull request !27700 from caifubi/master-pynative-mindrt-bugfix-host-shape
This commit is contained in:
commit
7caed8835f
|
@ -101,6 +101,7 @@ class DeviceAddress : public mindspore::DeviceSync {
|
|||
bool is_ptr_persisted() const { return is_ptr_persisted_; }
|
||||
void set_is_ptr_persisted(bool is_ptr_persisted) { is_ptr_persisted_ = is_ptr_persisted; }
|
||||
void set_host_shape(const ShapeVector &shape) { host_shape_ = shape; }
|
||||
ShapeVector host_shape() const { return host_shape_; }
|
||||
bool from_persistent_mem() const { return from_persistent_mem_; }
|
||||
void set_from_persistent_mem(bool from_persistent_mem) { from_persistent_mem_ = from_persistent_mem; }
|
||||
virtual void set_status(DeviceAddressStatus status) {}
|
||||
|
|
|
@ -195,6 +195,7 @@ void CreateKernelOutputDeviceAddress(const DeviceContext *device_context, const
|
|||
auto output_type = AnfAlgo::GetOutputDeviceDataType(kernel, i);
|
||||
auto address_size = AnfAlgo::GetOutputTensorMemSize(kernel, i);
|
||||
auto device_address = device_context->CreateDeviceAddress(nullptr, address_size, output_format, output_type);
|
||||
device_address->set_host_shape(trans::GetRuntimePaddingShape(kernel, i));
|
||||
if (is_gradient_out) {
|
||||
device_address->set_from_persistent_mem(true);
|
||||
}
|
||||
|
|
|
@ -250,6 +250,7 @@ void ClearGraphDeviceAddress(const KernelGraphPtr &graph, const DeviceContext *d
|
|||
auto new_device_address = device_context->CreateDeviceAddress(
|
||||
nullptr, device_address->GetSize(), device_address->format(), device_address->type_id());
|
||||
MS_EXCEPTION_IF_NULL(new_device_address);
|
||||
new_device_address->set_host_shape(device_address->host_shape());
|
||||
new_device_address->set_original_ref_count(device_address->original_ref_count());
|
||||
new_device_address->ResetRefCount();
|
||||
if (is_gradient_out) {
|
||||
|
|
Loading…
Reference in New Issue