!20150 fix bug of output none

Merge pull request !20150 from limingqi107/r1.3
This commit is contained in:
zhangzhenghai 2021-07-14 01:02:58 +00:00 committed by Gitee
commit 46b4878b44
6 changed files with 41 additions and 23 deletions

View File

@ -323,7 +323,9 @@ std::vector<KernelWithIndex> AnfRuntimeAlgorithm::GetAllOutputWithIndex(const An
if (node->isa<ValueNode>()) {
auto value = node->cast<ValueNodePtr>()->value();
MS_EXCEPTION_IF_NULL(value);
if (value->isa<ValueTuple>()) {
if (value->isa<None>()) {
return ret;
} else if (value->isa<ValueTuple>()) {
auto value_tuple = value->cast<ValueTuplePtr>();
auto value_tuple_size = CountValueNum(value_tuple);
for (size_t i = 0; i < value_tuple_size; ++i) {

View File

@ -82,20 +82,6 @@ class DeviceAddress : public mindspore::DeviceSync {
virtual DeviceAddressType DeviceType() const { return DeviceAddressType::kUnknown; }
void *GetMutablePtr() const override { return ptr_; }
virtual void SetNodeIndex(const AnfNodePtr &node, size_t out_index) { node_index_ = {node, out_index}; }
// The related interface of reference count operation.
void set_original_ref_count(size_t original_ref_count) { original_ref_count_ = original_ref_count; }
size_t original_ref_count() const { return original_ref_count_; }
void set_ref_count(size_t ref_count) { ref_count_ = ref_count; }
size_t ref_count() const { return ref_count_; }
void IncreaseOriginalRefCount() {
if (original_ref_count_ < SIZE_MAX) {
original_ref_count_++;
}
}
void DecreaseRefCount() { ref_count_--; }
void ResetRefCount() { ref_count_ = original_ref_count_; }
virtual bool DumpMemToFile(const std::string &filepath, const std::string &host_fmt, const ShapeVector &host_shape,
TypeId host_type, bool trans_flag) const {
return true;
@ -117,9 +103,6 @@ class DeviceAddress : public mindspore::DeviceSync {
}
mutable void *ptr_{nullptr};
size_t size_{0};
mutable size_t original_ref_count_{1};
// It will be decreased in the running, and reset by original_ref_count_ when it is zero.
mutable size_t ref_count_{1};
string format_{"DefaultFormat"};
TypeId type_id_{kNumberTypeFloat16};
mutable bool from_mem_pool_{false};

View File

@ -2019,6 +2019,9 @@ void GraphScheduler::LinkDeviceTensorStoreForAutoMonadActor(const std::vector<Ke
// Create the copy actor.
std::string name = "copy_from:" + kernel_actor->GetAID().Name() +
"_device_tensor_store:" + device_tensor_store_key.second->fullname_with_scope();
if (FetchActor(name) != nullptr) {
continue;
}
auto copy_actor = std::make_shared<CopyActor>(name, memory_manager_aid_);
MS_EXCEPTION_IF_NULL(copy_actor);
copy_actors_.emplace_back(copy_actor);
@ -3016,7 +3019,10 @@ void GraphScheduler::DumpOutputActor(const OutputActor *actor, std::ofstream &of
ofs << "\t\tdevice_contexts:" << actor->device_contexts_.size() << "\n ";
for (const auto &device_context : actor->device_contexts_) {
MS_EXCEPTION_IF_NULL(device_context);
if (device_context == nullptr) {
ofs << "\t\t\tdevice_context:" << device_context << "\n";
continue;
}
ofs << "\t\t\tdevice_context:" << device_context->device_context_key().ToString() << "\n";
}
}

View File

@ -303,10 +303,12 @@ size_t CountValueNum(const ValueTuplePtr &value_tuple) {
size_t cnt = 0;
const auto &value_list = value_tuple->value();
for (const auto &value : value_list) {
if (!value->isa<ValueTuple>()) {
cnt++;
} else {
if (value->isa<None>()) {
continue;
} else if (value->isa<ValueTuple>()) {
cnt += CountValueNum(value->cast<ValueTuplePtr>());
} else {
cnt++;
}
}
return cnt;

View File

@ -42,6 +42,24 @@ class DeviceSync {
virtual void *GetMutablePtr() const = 0;
virtual void ClearDeviceMemory() = 0;
// The related interface of reference count operation.
void set_original_ref_count(size_t original_ref_count) { original_ref_count_ = original_ref_count; }
size_t original_ref_count() const { return original_ref_count_; }
void set_ref_count(size_t ref_count) { ref_count_ = ref_count; }
size_t ref_count() const { return ref_count_; }
void IncreaseOriginalRefCount() {
if (original_ref_count_ < SIZE_MAX) {
original_ref_count_++;
}
}
void DecreaseRefCount() { ref_count_--; }
void ResetRefCount() { ref_count_ = original_ref_count_; }
protected:
mutable size_t original_ref_count_{1};
// It will be decreased in the running, and reset by original_ref_count_ when it is zero.
mutable size_t ref_count_{1};
};
using DeviceSyncPtr = std::shared_ptr<DeviceSync>;
} // namespace mindspore

View File

@ -286,7 +286,14 @@ class Tensor : public MetaTensor {
void set_init_flag(bool flag) { init_flag_ = flag; }
DeviceSyncPtr device_address() const { return device_sync_; }
void set_device_address(const DeviceSyncPtr &device_sync) { device_sync_ = device_sync; }
void set_device_address(const DeviceSyncPtr &device_sync) {
device_sync_ = device_sync;
// To support the old and new runtime coexistence.
if (device_sync_ != nullptr) {
device_sync_->set_original_ref_count(SIZE_MAX);
device_sync_->ResetRefCount();
}
}
void set_padding_type(const std::string padding_type) { padding_type_ = padding_type; }
std::string padding_type() const { return padding_type_; }