From 8d126cb95b349930c2aaddf0a796bb454a2cf869 Mon Sep 17 00:00:00 2001 From: limingqi107 Date: Tue, 13 Jul 2021 12:27:09 +0800 Subject: [PATCH] fix bug of output none and share weight --- .../backend/session/anf_runtime_algorithm.cc | 4 +++- .../ccsrc/runtime/device/device_address.h | 17 ----------------- .../ccsrc/runtime/framework/graph_scheduler.cc | 8 +++++++- mindspore/ccsrc/utils/convert_utils.cc | 8 +++++--- mindspore/core/ir/device_sync.h | 18 ++++++++++++++++++ mindspore/core/ir/tensor.h | 9 ++++++++- 6 files changed, 41 insertions(+), 23 deletions(-) diff --git a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc index 3bb5d8a43c5..4aece6b1f96 100644 --- a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc +++ b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc @@ -323,7 +323,9 @@ std::vector AnfRuntimeAlgorithm::GetAllOutputWithIndex(const An if (node->isa()) { auto value = node->cast()->value(); MS_EXCEPTION_IF_NULL(value); - if (value->isa()) { + if (value->isa()) { + return ret; + } else if (value->isa()) { auto value_tuple = value->cast(); auto value_tuple_size = CountValueNum(value_tuple); for (size_t i = 0; i < value_tuple_size; ++i) { diff --git a/mindspore/ccsrc/runtime/device/device_address.h b/mindspore/ccsrc/runtime/device/device_address.h index 07e6e3e7737..d49f0e28922 100644 --- a/mindspore/ccsrc/runtime/device/device_address.h +++ b/mindspore/ccsrc/runtime/device/device_address.h @@ -82,20 +82,6 @@ class DeviceAddress : public mindspore::DeviceSync { virtual DeviceAddressType DeviceType() const { return DeviceAddressType::kUnknown; } void *GetMutablePtr() const override { return ptr_; } virtual void SetNodeIndex(const AnfNodePtr &node, size_t out_index) { node_index_ = {node, out_index}; } - - // The related interface of reference count operation. - void set_original_ref_count(size_t original_ref_count) { original_ref_count_ = original_ref_count; } - size_t original_ref_count() const { return original_ref_count_; } - void set_ref_count(size_t ref_count) { ref_count_ = ref_count; } - size_t ref_count() const { return ref_count_; } - void IncreaseOriginalRefCount() { - if (original_ref_count_ < SIZE_MAX) { - original_ref_count_++; - } - } - void DecreaseRefCount() { ref_count_--; } - void ResetRefCount() { ref_count_ = original_ref_count_; } - virtual bool DumpMemToFile(const std::string &filepath, const std::string &host_fmt, const ShapeVector &host_shape, TypeId host_type, bool trans_flag) const { return true; @@ -117,9 +103,6 @@ class DeviceAddress : public mindspore::DeviceSync { } mutable void *ptr_{nullptr}; size_t size_{0}; - mutable size_t original_ref_count_{1}; - // It will be decreased in the running, and reset by original_ref_count_ when it is zero. - mutable size_t ref_count_{1}; string format_{"DefaultFormat"}; TypeId type_id_{kNumberTypeFloat16}; mutable bool from_mem_pool_{false}; diff --git a/mindspore/ccsrc/runtime/framework/graph_scheduler.cc b/mindspore/ccsrc/runtime/framework/graph_scheduler.cc index ef130a37c9e..0bd38df19c7 100644 --- a/mindspore/ccsrc/runtime/framework/graph_scheduler.cc +++ b/mindspore/ccsrc/runtime/framework/graph_scheduler.cc @@ -2019,6 +2019,9 @@ void GraphScheduler::LinkDeviceTensorStoreForAutoMonadActor(const std::vectorGetAID().Name() + "_device_tensor_store:" + device_tensor_store_key.second->fullname_with_scope(); + if (FetchActor(name) != nullptr) { + continue; + } auto copy_actor = std::make_shared(name, memory_manager_aid_); MS_EXCEPTION_IF_NULL(copy_actor); copy_actors_.emplace_back(copy_actor); @@ -3016,7 +3019,10 @@ void GraphScheduler::DumpOutputActor(const OutputActor *actor, std::ofstream &of ofs << "\t\tdevice_contexts:" << actor->device_contexts_.size() << "\n "; for (const auto &device_context : actor->device_contexts_) { - MS_EXCEPTION_IF_NULL(device_context); + if (device_context == nullptr) { + ofs << "\t\t\tdevice_context:" << device_context << "\n"; + continue; + } ofs << "\t\t\tdevice_context:" << device_context->device_context_key().ToString() << "\n"; } } diff --git a/mindspore/ccsrc/utils/convert_utils.cc b/mindspore/ccsrc/utils/convert_utils.cc index 9567d9b941a..d516de73f8a 100644 --- a/mindspore/ccsrc/utils/convert_utils.cc +++ b/mindspore/ccsrc/utils/convert_utils.cc @@ -303,10 +303,12 @@ size_t CountValueNum(const ValueTuplePtr &value_tuple) { size_t cnt = 0; const auto &value_list = value_tuple->value(); for (const auto &value : value_list) { - if (!value->isa()) { - cnt++; - } else { + if (value->isa()) { + continue; + } else if (value->isa()) { cnt += CountValueNum(value->cast()); + } else { + cnt++; } } return cnt; diff --git a/mindspore/core/ir/device_sync.h b/mindspore/core/ir/device_sync.h index 42493060fc2..80b08e3f73d 100644 --- a/mindspore/core/ir/device_sync.h +++ b/mindspore/core/ir/device_sync.h @@ -42,6 +42,24 @@ class DeviceSync { virtual void *GetMutablePtr() const = 0; virtual void ClearDeviceMemory() = 0; + + // The related interface of reference count operation. + void set_original_ref_count(size_t original_ref_count) { original_ref_count_ = original_ref_count; } + size_t original_ref_count() const { return original_ref_count_; } + void set_ref_count(size_t ref_count) { ref_count_ = ref_count; } + size_t ref_count() const { return ref_count_; } + void IncreaseOriginalRefCount() { + if (original_ref_count_ < SIZE_MAX) { + original_ref_count_++; + } + } + void DecreaseRefCount() { ref_count_--; } + void ResetRefCount() { ref_count_ = original_ref_count_; } + + protected: + mutable size_t original_ref_count_{1}; + // It will be decreased in the running, and reset by original_ref_count_ when it is zero. + mutable size_t ref_count_{1}; }; using DeviceSyncPtr = std::shared_ptr; } // namespace mindspore diff --git a/mindspore/core/ir/tensor.h b/mindspore/core/ir/tensor.h index 28520b02768..f6115567e2f 100644 --- a/mindspore/core/ir/tensor.h +++ b/mindspore/core/ir/tensor.h @@ -286,7 +286,14 @@ class Tensor : public MetaTensor { void set_init_flag(bool flag) { init_flag_ = flag; } DeviceSyncPtr device_address() const { return device_sync_; } - void set_device_address(const DeviceSyncPtr &device_sync) { device_sync_ = device_sync; } + void set_device_address(const DeviceSyncPtr &device_sync) { + device_sync_ = device_sync; + // To support the old and new runtime coexistence. + if (device_sync_ != nullptr) { + device_sync_->set_original_ref_count(SIZE_MAX); + device_sync_->ResetRefCount(); + } + } void set_padding_type(const std::string padding_type) { padding_type_ = padding_type; } std::string padding_type() const { return padding_type_; }