forked from mindspore-Ecosystem/mindspore
!32024 fix transdata error for pynative
Merge pull request !32024 from chujinjin/fix_transdata_error_for_pynative
This commit is contained in:
commit
efae1fc3fa
|
@ -234,6 +234,7 @@ TensorPtr CreateOutputTensor(const AnfNodePtr &output_node, size_t output_index)
|
|||
// Put device tensor into host tensor.
|
||||
const auto &device_tensor = AnfAlgo::GetMutableOutputAddr(output_node, output_index, false);
|
||||
MS_EXCEPTION_IF_NULL(device_tensor);
|
||||
device_tensor->SetNodeIndex(output_node, output_index);
|
||||
tensor->set_device_address(device_tensor);
|
||||
tensor->set_sync_status(kNeedSyncDeviceToHost);
|
||||
|
||||
|
@ -260,6 +261,8 @@ device::DeviceAddressPtr CloneEmptyDeviceAddress(const device::DeviceAddressPtr
|
|||
MS_EXCEPTION_IF_NULL(new_device_address);
|
||||
new_device_address->set_original_ref_count(old_device_address->original_ref_count());
|
||||
new_device_address->ResetRefCount();
|
||||
auto node = old_device_address->GetNodeIndex();
|
||||
new_device_address->SetNodeIndex(node.first, node.second);
|
||||
return new_device_address;
|
||||
}
|
||||
|
||||
|
|
|
@ -289,8 +289,13 @@ std::shared_ptr<LaunchKernel> AscendDeviceAddress::CreateLaunchTransData(const s
|
|||
auto runtime_instance = device::KernelRuntimeManager::Instance().GetCurrentKernelRuntime();
|
||||
MS_EXCEPTION_IF_NULL(runtime_instance);
|
||||
auto stream = runtime_instance->compute_stream();
|
||||
auto node = GetNodeIndex();
|
||||
int64_t groups = 1;
|
||||
if (format_ == kOpFormat_FRAC_Z && node.first != nullptr) {
|
||||
groups = common::AnfAlgo::GetAttrGroups(node.first, node.second);
|
||||
}
|
||||
auto launch_trans_data =
|
||||
std::make_shared<AscendLaunchTransData>(stream, type_id_, size_, ori_format, dst_format, host_shape);
|
||||
std::make_shared<AscendLaunchTransData>(stream, type_id_, size_, ori_format, dst_format, host_shape, groups);
|
||||
MS_EXCEPTION_IF_NULL(launch_trans_data);
|
||||
return launch_trans_data;
|
||||
}
|
||||
|
|
|
@ -114,6 +114,8 @@ void AscendLaunchTransData::ConstructKernelGraphAndSetAttr() {
|
|||
// set attr
|
||||
common::AnfAlgo::SetNodeAttr(kAttrSrcFormat, MakeValue(src_format_), transdata_node);
|
||||
common::AnfAlgo::SetNodeAttr(kAttrDstFormat, MakeValue(dst_format_), transdata_node);
|
||||
common::AnfAlgo::SetNodeAttr(kAttrGroups, MakeValue(groups_), transdata_node);
|
||||
common::AnfAlgo::SetNodeAttr(kAttrFracZGroup, MakeValue(groups_), transdata_node);
|
||||
}
|
||||
}
|
||||
} // namespace mindspore::device::ascend
|
||||
|
|
|
@ -27,7 +27,7 @@ namespace mindspore::device::ascend {
|
|||
class AscendLaunchTransData : public AscendLaunchKernel {
|
||||
public:
|
||||
AscendLaunchTransData(void *stream, TypeId dtype, size_t total_size, std::string src_format, std::string dst_format,
|
||||
std::vector<size_t> host_shape)
|
||||
std::vector<size_t> host_shape, int64_t groups)
|
||||
: AscendLaunchKernel(stream),
|
||||
dtype_(dtype),
|
||||
total_size_(total_size),
|
||||
|
@ -35,7 +35,8 @@ class AscendLaunchTransData : public AscendLaunchKernel {
|
|||
input_addr_(nullptr),
|
||||
src_format_(std::move(src_format)),
|
||||
dst_format_(std::move(dst_format)),
|
||||
shape_(std::move(host_shape)) {}
|
||||
shape_(std::move(host_shape)),
|
||||
groups_(groups) {}
|
||||
|
||||
~AscendLaunchTransData() override = default;
|
||||
|
||||
|
@ -57,6 +58,7 @@ class AscendLaunchTransData : public AscendLaunchKernel {
|
|||
std::string src_format_;
|
||||
std::string dst_format_;
|
||||
std::vector<size_t> shape_;
|
||||
int64_t groups_;
|
||||
|
||||
private:
|
||||
std::shared_ptr<session::KernelGraph> ObtainTransDataKernelGraph();
|
||||
|
|
Loading…
Reference in New Issue