forked from mindspore-Ecosystem/mindspore
!17983 Fix dropout and matrix_inverse op bug.
Merge pull request !17983 from linqingke/gpu_ops
This commit is contained in:
commit
587063f92b
|
@@ -39,17 +39,22 @@ class MatrixInverseGpuKernel : public GpuKernel {
|
||||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
|
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
|
||||||
T *input_addr = GetDeviceAddress<T>(inputs, 0);
|
T *input_addr = GetDeviceAddress<T>(inputs, 0);
|
||||||
T *output_addr = GetDeviceAddress<T>(outputs, 0);
|
T *output_addr = GetDeviceAddress<T>(outputs, 0);
|
||||||
auto lu_batch_addr = GetDeviceAddress<T *>(workspace, 0);
|
auto compute_input_addr = GetDeviceAddress<T>(workspace, 0);
|
||||||
auto inv_batch_addr = GetDeviceAddress<T *>(workspace, 1);
|
auto lu_batch_addr = GetDeviceAddress<T *>(workspace, 1);
|
||||||
auto pivo_addr = GetDeviceAddress<int>(workspace, 2);
|
auto inv_batch_addr = GetDeviceAddress<T *>(workspace, 2);
|
||||||
auto info_addr = GetDeviceAddress<int>(workspace, 3);
|
auto pivo_addr = GetDeviceAddress<int>(workspace, 3);
|
||||||
|
auto info_addr = GetDeviceAddress<int>(workspace, 4);
|
||||||
|
|
||||||
int len = SizeToInt(size_);
|
int len = SizeToInt(size_);
|
||||||
int batchsize = SizeToInt(batch_size_);
|
int batchsize = SizeToInt(batch_size_);
|
||||||
for (size_t i = 0; i < batch_size_; i++) {
|
for (size_t i = 0; i < batch_size_; i++) {
|
||||||
lu_addr_[i] = input_addr + i * len * len;
|
lu_addr_[i] = compute_input_addr + i * len * len;
|
||||||
inv_addr_[i] = output_addr + i * len * len;
|
inv_addr_[i] = output_addr + i * len * len;
|
||||||
}
|
}
|
||||||
|
CHECK_CUDA_RET_WITH_ERROR(kernel_node_,
|
||||||
|
cudaMemcpyAsync(compute_input_addr, input_addr, input_size_, cudaMemcpyDeviceToDevice,
|
||||||
|
reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||||
|
"cuda memcopy Fail");
|
||||||
CHECK_CUDA_RET_WITH_ERROR(kernel_node_,
|
CHECK_CUDA_RET_WITH_ERROR(kernel_node_,
|
||||||
cudaMemcpyAsync(lu_batch_addr, lu_addr_.data(), sizeof(T *) * batch_size_,
|
cudaMemcpyAsync(lu_batch_addr, lu_addr_.data(), sizeof(T *) * batch_size_,
|
||||||
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
|
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||||
|
@@ -114,16 +119,17 @@ class MatrixInverseGpuKernel : public GpuKernel {
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
void InitSizeLists() override {
|
void InitSizeLists() override {
|
||||||
input_size_list_.push_back(input_size_);
|
input_size_list_.emplace_back(input_size_);
|
||||||
output_size_list_.push_back(input_size_);
|
output_size_list_.emplace_back(input_size_);
|
||||||
|
workspace_size_list_.emplace_back(input_size_);
|
||||||
size_t lu_size = batch_size_ * sizeof(T *);
|
size_t lu_size = batch_size_ * sizeof(T *);
|
||||||
workspace_size_list_.push_back(lu_size);
|
workspace_size_list_.emplace_back(lu_size);
|
||||||
size_t inv_size = batch_size_ * sizeof(T *);
|
size_t inv_size = batch_size_ * sizeof(T *);
|
||||||
workspace_size_list_.push_back(inv_size);
|
workspace_size_list_.emplace_back(inv_size);
|
||||||
size_t pivo_size = batch_size_ * size_ * sizeof(int);
|
size_t pivo_size = batch_size_ * size_ * sizeof(int);
|
||||||
workspace_size_list_.push_back(pivo_size);
|
workspace_size_list_.emplace_back(pivo_size);
|
||||||
size_t info_size = batch_size_ * sizeof(int);
|
size_t info_size = batch_size_ * sizeof(int);
|
||||||
workspace_size_list_.push_back(info_size);
|
workspace_size_list_.emplace_back(info_size);
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|
|
@@ -87,6 +87,9 @@ class DropoutGpuFwdKernel : public GpuKernel {
|
||||||
int64_t seed = GetAttr<int64_t>(kernel_node, "Seed0");
|
int64_t seed = GetAttr<int64_t>(kernel_node, "Seed0");
|
||||||
if (seed == 0) {
|
if (seed == 0) {
|
||||||
seed = GetAttr<int64_t>(kernel_node, "Seed1");
|
seed = GetAttr<int64_t>(kernel_node, "Seed1");
|
||||||
|
if (seed == 0) {
|
||||||
|
seed = time(NULL);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
seed_ = static_cast<uint64_t>(seed);
|
seed_ = static_cast<uint64_t>(seed);
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue