forked from mindspore-Ecosystem/mindspore
!7143 GPU addn support same out/in
Merge pull request !7143 from VectorSL/addn-fix
This commit is contained in:
commit
97087fca5b
|
@ -44,28 +44,41 @@ class AddNGpuFwdKernel : public GpuKernel {
|
|||
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
|
||||
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
|
||||
if (is_null_input_) {
|
||||
return true;
|
||||
}
|
||||
T *output_addr = GetDeviceAddress<T>(outputs, 0);
|
||||
auto work_addr = output_addr;
|
||||
for (size_t i = 0; i < IntToSize(num_input_); i++) {
|
||||
if (output_addr == GetDeviceAddress<T>(inputs, i)) {
|
||||
work_addr = GetDeviceAddress<T>(workspace, 0);
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (cudnn_data_type_ == CUDNN_DATA_INT32) {
|
||||
FillDeviceArray(outputs[0]->size / sizeof(T), output_addr, 0.0f, reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
FillDeviceArray(outputs[0]->size / sizeof(T), work_addr, 0.0f, reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
}
|
||||
const float alpha = 1;
|
||||
const float beta = 0;
|
||||
for (size_t i = 0; i < IntToSize(num_input_); i++) {
|
||||
T *input_addr = GetDeviceAddress<T>(inputs, i);
|
||||
if (cudnn_data_type_ == CUDNN_DATA_INT32) {
|
||||
ElewiseArith(outputs[0]->size / sizeof(T), BROADCAST_TYPE_ADD, input_addr, output_addr, output_addr,
|
||||
ElewiseArith(outputs[0]->size / sizeof(T), BROADCAST_TYPE_ADD, input_addr, work_addr, work_addr,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
} else {
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnAddTensor(cudnn_handle_, &alpha, input_descriptor_, input_addr,
|
||||
&(i > 0 ? alpha : beta), input_descriptor_, output_addr),
|
||||
&(i > 0 ? alpha : beta), input_descriptor_, work_addr),
|
||||
"cudnnAddTensor failed");
|
||||
}
|
||||
}
|
||||
if (work_addr != output_addr) {
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(output_addr, work_addr, outputs[0]->size, cudaMemcpyDeviceToDevice,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
"Addn cudaMemcpyAsync outputs failed");
|
||||
}
|
||||
return true;
|
||||
}
|
||||
bool Init(const CNodePtr &kernel_node) override {
|
||||
|
@ -124,6 +137,7 @@ class AddNGpuFwdKernel : public GpuKernel {
|
|||
input_size_list_.push_back(input_size_);
|
||||
}
|
||||
output_size_list_.push_back(input_size_);
|
||||
workspace_size_list_.push_back(input_size_);
|
||||
}
|
||||
|
||||
private:
|
||||
|
|
Loading…
Reference in New Issue