fix bug of output none and share weight
This commit is contained in:
parent
1cd02eab24
commit
8d126cb95b
|
@ -323,7 +323,9 @@ std::vector<KernelWithIndex> AnfRuntimeAlgorithm::GetAllOutputWithIndex(const An
|
|||
if (node->isa<ValueNode>()) {
|
||||
auto value = node->cast<ValueNodePtr>()->value();
|
||||
MS_EXCEPTION_IF_NULL(value);
|
||||
if (value->isa<ValueTuple>()) {
|
||||
if (value->isa<None>()) {
|
||||
return ret;
|
||||
} else if (value->isa<ValueTuple>()) {
|
||||
auto value_tuple = value->cast<ValueTuplePtr>();
|
||||
auto value_tuple_size = CountValueNum(value_tuple);
|
||||
for (size_t i = 0; i < value_tuple_size; ++i) {
|
||||
|
|
|
@ -82,20 +82,6 @@ class DeviceAddress : public mindspore::DeviceSync {
|
|||
virtual DeviceAddressType DeviceType() const { return DeviceAddressType::kUnknown; }
|
||||
void *GetMutablePtr() const override { return ptr_; }
|
||||
virtual void SetNodeIndex(const AnfNodePtr &node, size_t out_index) { node_index_ = {node, out_index}; }
|
||||
|
||||
// The related interface of reference count operation.
|
||||
void set_original_ref_count(size_t original_ref_count) { original_ref_count_ = original_ref_count; }
|
||||
size_t original_ref_count() const { return original_ref_count_; }
|
||||
void set_ref_count(size_t ref_count) { ref_count_ = ref_count; }
|
||||
size_t ref_count() const { return ref_count_; }
|
||||
void IncreaseOriginalRefCount() {
|
||||
if (original_ref_count_ < SIZE_MAX) {
|
||||
original_ref_count_++;
|
||||
}
|
||||
}
|
||||
void DecreaseRefCount() { ref_count_--; }
|
||||
void ResetRefCount() { ref_count_ = original_ref_count_; }
|
||||
|
||||
virtual bool DumpMemToFile(const std::string &filepath, const std::string &host_fmt, const ShapeVector &host_shape,
|
||||
TypeId host_type, bool trans_flag) const {
|
||||
return true;
|
||||
|
@ -117,9 +103,6 @@ class DeviceAddress : public mindspore::DeviceSync {
|
|||
}
|
||||
mutable void *ptr_{nullptr};
|
||||
size_t size_{0};
|
||||
mutable size_t original_ref_count_{1};
|
||||
// It will be decreased in the running, and reset by original_ref_count_ when it is zero.
|
||||
mutable size_t ref_count_{1};
|
||||
string format_{"DefaultFormat"};
|
||||
TypeId type_id_{kNumberTypeFloat16};
|
||||
mutable bool from_mem_pool_{false};
|
||||
|
|
|
@ -2019,6 +2019,9 @@ void GraphScheduler::LinkDeviceTensorStoreForAutoMonadActor(const std::vector<Ke
|
|||
// Create the copy actor.
|
||||
std::string name = "copy_from:" + kernel_actor->GetAID().Name() +
|
||||
"_device_tensor_store:" + device_tensor_store_key.second->fullname_with_scope();
|
||||
if (FetchActor(name) != nullptr) {
|
||||
continue;
|
||||
}
|
||||
auto copy_actor = std::make_shared<CopyActor>(name, memory_manager_aid_);
|
||||
MS_EXCEPTION_IF_NULL(copy_actor);
|
||||
copy_actors_.emplace_back(copy_actor);
|
||||
|
@ -3016,7 +3019,10 @@ void GraphScheduler::DumpOutputActor(const OutputActor *actor, std::ofstream &of
|
|||
|
||||
ofs << "\t\tdevice_contexts:" << actor->device_contexts_.size() << "\n ";
|
||||
for (const auto &device_context : actor->device_contexts_) {
|
||||
MS_EXCEPTION_IF_NULL(device_context);
|
||||
if (device_context == nullptr) {
|
||||
ofs << "\t\t\tdevice_context:" << device_context << "\n";
|
||||
continue;
|
||||
}
|
||||
ofs << "\t\t\tdevice_context:" << device_context->device_context_key().ToString() << "\n";
|
||||
}
|
||||
}
|
||||
|
|
|
@ -303,10 +303,12 @@ size_t CountValueNum(const ValueTuplePtr &value_tuple) {
|
|||
size_t cnt = 0;
|
||||
const auto &value_list = value_tuple->value();
|
||||
for (const auto &value : value_list) {
|
||||
if (!value->isa<ValueTuple>()) {
|
||||
cnt++;
|
||||
} else {
|
||||
if (value->isa<None>()) {
|
||||
continue;
|
||||
} else if (value->isa<ValueTuple>()) {
|
||||
cnt += CountValueNum(value->cast<ValueTuplePtr>());
|
||||
} else {
|
||||
cnt++;
|
||||
}
|
||||
}
|
||||
return cnt;
|
||||
|
|
|
@ -42,6 +42,24 @@ class DeviceSync {
|
|||
|
||||
virtual void *GetMutablePtr() const = 0;
|
||||
virtual void ClearDeviceMemory() = 0;
|
||||
|
||||
// The related interface of reference count operation.
|
||||
void set_original_ref_count(size_t original_ref_count) { original_ref_count_ = original_ref_count; }
|
||||
size_t original_ref_count() const { return original_ref_count_; }
|
||||
void set_ref_count(size_t ref_count) { ref_count_ = ref_count; }
|
||||
size_t ref_count() const { return ref_count_; }
|
||||
void IncreaseOriginalRefCount() {
|
||||
if (original_ref_count_ < SIZE_MAX) {
|
||||
original_ref_count_++;
|
||||
}
|
||||
}
|
||||
void DecreaseRefCount() { ref_count_--; }
|
||||
void ResetRefCount() { ref_count_ = original_ref_count_; }
|
||||
|
||||
protected:
|
||||
mutable size_t original_ref_count_{1};
|
||||
// It will be decreased in the running, and reset by original_ref_count_ when it is zero.
|
||||
mutable size_t ref_count_{1};
|
||||
};
|
||||
using DeviceSyncPtr = std::shared_ptr<DeviceSync>;
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -286,7 +286,14 @@ class Tensor : public MetaTensor {
|
|||
void set_init_flag(bool flag) { init_flag_ = flag; }
|
||||
|
||||
DeviceSyncPtr device_address() const { return device_sync_; }
|
||||
void set_device_address(const DeviceSyncPtr &device_sync) { device_sync_ = device_sync; }
|
||||
void set_device_address(const DeviceSyncPtr &device_sync) {
|
||||
device_sync_ = device_sync;
|
||||
// To support the old and new runtime coexistence.
|
||||
if (device_sync_ != nullptr) {
|
||||
device_sync_->set_original_ref_count(SIZE_MAX);
|
||||
device_sync_->ResetRefCount();
|
||||
}
|
||||
}
|
||||
void set_padding_type(const std::string padding_type) { padding_type_ = padding_type; }
|
||||
std::string padding_type() const { return padding_type_; }
|
||||
|
||||
|
|
Loading…
Reference in New Issue