!32024 fix transdata error for pynative

Merge pull request !32024 from chujinjin/fix_transdata_error_for_pynative
This commit is contained in:
i-robot 2022-03-28 02:54:51 +00:00 committed by Gitee
commit efae1fc3fa
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
4 changed files with 15 additions and 3 deletions

View File

@ -234,6 +234,7 @@ TensorPtr CreateOutputTensor(const AnfNodePtr &output_node, size_t output_index)
// Put device tensor into host tensor.
const auto &device_tensor = AnfAlgo::GetMutableOutputAddr(output_node, output_index, false);
MS_EXCEPTION_IF_NULL(device_tensor);
device_tensor->SetNodeIndex(output_node, output_index);
tensor->set_device_address(device_tensor);
tensor->set_sync_status(kNeedSyncDeviceToHost);
@ -260,6 +261,8 @@ device::DeviceAddressPtr CloneEmptyDeviceAddress(const device::DeviceAddressPtr
MS_EXCEPTION_IF_NULL(new_device_address);
new_device_address->set_original_ref_count(old_device_address->original_ref_count());
new_device_address->ResetRefCount();
auto node = old_device_address->GetNodeIndex();
new_device_address->SetNodeIndex(node.first, node.second);
return new_device_address;
}

View File

@ -289,8 +289,13 @@ std::shared_ptr<LaunchKernel> AscendDeviceAddress::CreateLaunchTransData(const s
auto runtime_instance = device::KernelRuntimeManager::Instance().GetCurrentKernelRuntime();
MS_EXCEPTION_IF_NULL(runtime_instance);
auto stream = runtime_instance->compute_stream();
auto node = GetNodeIndex();
int64_t groups = 1;
if (format_ == kOpFormat_FRAC_Z && node.first != nullptr) {
groups = common::AnfAlgo::GetAttrGroups(node.first, node.second);
}
auto launch_trans_data =
std::make_shared<AscendLaunchTransData>(stream, type_id_, size_, ori_format, dst_format, host_shape);
std::make_shared<AscendLaunchTransData>(stream, type_id_, size_, ori_format, dst_format, host_shape, groups);
MS_EXCEPTION_IF_NULL(launch_trans_data);
return launch_trans_data;
}

View File

@ -114,6 +114,8 @@ void AscendLaunchTransData::ConstructKernelGraphAndSetAttr() {
// set attr
common::AnfAlgo::SetNodeAttr(kAttrSrcFormat, MakeValue(src_format_), transdata_node);
common::AnfAlgo::SetNodeAttr(kAttrDstFormat, MakeValue(dst_format_), transdata_node);
common::AnfAlgo::SetNodeAttr(kAttrGroups, MakeValue(groups_), transdata_node);
common::AnfAlgo::SetNodeAttr(kAttrFracZGroup, MakeValue(groups_), transdata_node);
}
}
} // namespace mindspore::device::ascend

View File

@ -27,7 +27,7 @@ namespace mindspore::device::ascend {
class AscendLaunchTransData : public AscendLaunchKernel {
public:
AscendLaunchTransData(void *stream, TypeId dtype, size_t total_size, std::string src_format, std::string dst_format,
std::vector<size_t> host_shape)
std::vector<size_t> host_shape, int64_t groups)
: AscendLaunchKernel(stream),
dtype_(dtype),
total_size_(total_size),
@ -35,7 +35,8 @@ class AscendLaunchTransData : public AscendLaunchKernel {
input_addr_(nullptr),
src_format_(std::move(src_format)),
dst_format_(std::move(dst_format)),
shape_(std::move(host_shape)) {}
shape_(std::move(host_shape)),
groups_(groups) {}
~AscendLaunchTransData() override = default;
@ -57,6 +58,7 @@ class AscendLaunchTransData : public AscendLaunchKernel {
std::string src_format_;
std::string dst_format_;
std::vector<size_t> shape_;
int64_t groups_;
private:
std::shared_ptr<session::KernelGraph> ObtainTransDataKernelGraph();