!38830 add SetDeviceId api for kernel mod
Merge pull request !38830 from zhengyuanhua/master
This commit is contained in:
commit
3a047c05e2
|
@ -288,6 +288,7 @@ class KernelMod {
|
|||
const std::vector<AddressPtr> &GetOutputsAddr() const { return outputs_addr_; }
|
||||
void set_stream(StreamType stream) { stream_ = stream; }
|
||||
StreamType stream() const { return stream_; }
|
||||
void SetDevicedId(uint32_t device_id) { device_id_ = device_id; }
|
||||
virtual enum KernelModType GetKernelModType() const { return KernelModType::KernelMod; }
|
||||
bool Launch(const KernelLaunchInfo &kernel_launch_address, void *stream_ptr) {
|
||||
return Launch(kernel_launch_address.inputs_, kernel_launch_address.workspaces_, kernel_launch_address.outputs_,
|
||||
|
@ -313,6 +314,7 @@ class KernelMod {
|
|||
std::vector<size_t> output_size_list_;
|
||||
std::vector<size_t> workspace_size_list_;
|
||||
bool is_need_retrieve_output_shape_ = false;
|
||||
uint32_t device_id_;
|
||||
|
||||
private:
|
||||
std::vector<AddressPtr> inputs_addr_;
|
||||
|
|
|
@ -138,7 +138,6 @@ class NativeGpuKernelMod : public GpuKernelMod {
|
|||
}
|
||||
return Factory<NativeGpuKernelMod>::Instance().Create(kernel_name)->GetAllSupportedList(kernel_name);
|
||||
}
|
||||
void SetDevicedId(uint32_t device_id) { device_id_ = device_id; }
|
||||
static bool GpuCheckSupport(const std::string &kernel_name, const KernelAttr &kernel_attr);
|
||||
|
||||
static ReducePrecisonRes GpuReducePrecisionCheck(const std::string &kernel_name, const KernelAttr &kernel_attr) {
|
||||
|
@ -149,7 +148,6 @@ class NativeGpuKernelMod : public GpuKernelMod {
|
|||
|
||||
protected:
|
||||
virtual void InitResource() {}
|
||||
uint32_t device_id_;
|
||||
static mindspore::HashMap<std::string, std::vector<KernelAttr>> support_map_;
|
||||
};
|
||||
|
||||
|
|
Loading…
Reference in New Issue