!38830 add SetDeviceId api for kernel mod

Merge pull request !38830 from zhengyuanhua/master
This commit is contained in:
i-robot 2022-08-02 02:49:11 +00:00 committed by Gitee
commit 3a047c05e2
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 2 additions and 2 deletions

View File

@ -288,6 +288,7 @@ class KernelMod {
const std::vector<AddressPtr> &GetOutputsAddr() const { return outputs_addr_; }
void set_stream(StreamType stream) { stream_ = stream; }
StreamType stream() const { return stream_; }
void SetDevicedId(uint32_t device_id) { device_id_ = device_id; }
virtual enum KernelModType GetKernelModType() const { return KernelModType::KernelMod; }
bool Launch(const KernelLaunchInfo &kernel_launch_address, void *stream_ptr) {
return Launch(kernel_launch_address.inputs_, kernel_launch_address.workspaces_, kernel_launch_address.outputs_,
@ -313,6 +314,7 @@ class KernelMod {
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
bool is_need_retrieve_output_shape_ = false;
uint32_t device_id_;
private:
std::vector<AddressPtr> inputs_addr_;

View File

@ -138,7 +138,6 @@ class NativeGpuKernelMod : public GpuKernelMod {
}
return Factory<NativeGpuKernelMod>::Instance().Create(kernel_name)->GetAllSupportedList(kernel_name);
}
void SetDevicedId(uint32_t device_id) { device_id_ = device_id; }
static bool GpuCheckSupport(const std::string &kernel_name, const KernelAttr &kernel_attr);
static ReducePrecisonRes GpuReducePrecisionCheck(const std::string &kernel_name, const KernelAttr &kernel_attr) {
@ -149,7 +148,6 @@ class NativeGpuKernelMod : public GpuKernelMod {
protected:
virtual void InitResource() {}
uint32_t device_id_;
static mindspore::HashMap<std::string, std::vector<KernelAttr>> support_map_;
};