gpu fix addn bug and supported list bug

This commit is contained in:
VectorSL 2020-07-14 14:54:10 +08:00
parent 180b3029e5
commit d22a597689
2 changed files with 18 additions and 8 deletions

View File

@ -88,10 +88,11 @@ std::string SupportedTypeList(const CNodePtr &kernel_node) {
supported_akg_type_list = supported_akg_type_list + mindspore::kernel::TypeId2String(type);
}
supported_type_lists = supported_type_lists + supported_akg_type_list + "], out[";
supported_akg_type_list.clear();
for (auto type : supported_akg_type_out) {
supported_akg_type_list = supported_akg_type_list + mindspore::kernel::TypeId2String(type);
}
supported_type_lists += "]; ";
supported_type_lists = supported_type_lists + supported_akg_type_list + "]; ";
}
return supported_type_lists;
}

View File

@ -21,6 +21,8 @@
#include <vector>
#include "kernel/gpu/gpu_kernel.h"
#include "kernel/gpu/gpu_kernel_factory.h"
#include "kernel/gpu/math/broadcast_gpu_kernel.h"
#include "kernel/gpu/cuda_impl/slice_impl.cuh"
#include "kernel/gpu/kernel_constants.h"
namespace mindspore {
@ -43,18 +45,26 @@ class AddNGpuFwdKernel : public GpuKernel {
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs, void *) override {
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
if (is_null_input_) {
return true;
}
T *output_addr = GetDeviceAddress<T>(outputs, 0);
if (cudnn_data_type_ == CUDNN_DATA_INT32) {
FillDeviceArray(outputs[0]->size / sizeof(T), output_addr, 0.0f, reinterpret_cast<cudaStream_t>(stream_ptr));
}
const float alpha = 1;
const float beta = 0;
for (size_t i = 0; i < IntToSize(num_input_); i++) {
T *input_addr = GetDeviceAddress<T>(inputs, i);
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnAddTensor(cudnn_handle_, &alpha, input_descriptor_, input_addr,
&(i > 0 ? alpha : beta), input_descriptor_, output_addr),
"cudnnAddTensor failed");
if (cudnn_data_type_ == CUDNN_DATA_INT32) {
NoBroadcast(outputs[0]->size / sizeof(T), BROADCAST_TYPE_ADD, input_addr, output_addr, output_addr,
reinterpret_cast<cudaStream_t>(stream_ptr));
} else {
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnAddTensor(cudnn_handle_, &alpha, input_descriptor_, input_addr,
&(i > 0 ? alpha : beta), input_descriptor_, output_addr),
"cudnnAddTensor failed");
}
}
return true;
}
@ -100,9 +110,8 @@ class AddNGpuFwdKernel : public GpuKernel {
}
void InitSizeLists() override {
if (!is_null_input_) {
CHECK_CUDNN_RET_WITH_EXCEPT(
cudnnGetTensorSizeInBytes(input_descriptor_, reinterpret_cast<size_t *>(&input_size_)),
"cudnnGetTensorSizeInBytes failed");
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(input_descriptor_, &input_size_),
"cudnnGetTensorSizeInBytes failed");
}
for (int i = 0; i < num_input_; i++) {
input_size_list_.push_back(input_size_);