remove synchronous error check

This commit is contained in:
tom__chen 2021-05-18 15:40:25 -04:00
parent f56079d67b
commit 799cb79873
4 changed files with 12 additions and 75 deletions

View File

@ -13,27 +13,11 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <iostream>
#include "backend/kernel_compiler/gpu/cuda_impl/index_add_impl.cuh"
#include "backend/kernel_compiler/gpu/cuda_impl/util.cuh"
#include "runtime/device/gpu/cuda_common.h"
#include "include/cuda_fp16.h"
__global__ void InitErrorCode(IndexAddErrorCode *error_code) {
*error_code = IndexAddErrorCode::kOk;
}
__global__ void ValidateIndexValues(const int *index, const size_t src_axis_size, const size_t dst_axis_size,
IndexAddErrorCode *error_code) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < src_axis_size; pos += blockDim.x * gridDim.x) {
const int idx_value = index[pos];
if (idx_value < 0 || idx_value >= dst_axis_size) {
*error_code = IndexAddErrorCode::kIndexOutOfRange;
return;
}
}
return;
}
template <typename T>
__global__ void IndexAddAtomic(T *dst, const int *index, const T *src, const size_t src_size, const size_t outer_size,
const size_t src_axis_size, const size_t dst_axis_size, const size_t inner_size) {
@ -41,9 +25,11 @@ __global__ void IndexAddAtomic(T *dst, const int *index, const T *src, const siz
const size_t src_axis_idx = (pos / inner_size) % src_axis_size;
const size_t src_outer_idx = pos / (src_axis_size * inner_size);
const size_t dst_axis_idx = static_cast<size_t>(index[src_axis_idx]);
const size_t dst_inner_idx = pos % inner_size;
const size_t dst_idx = src_outer_idx * (dst_axis_size * inner_size) + dst_axis_idx * inner_size + dst_inner_idx;
MsAtomicAdd(&dst[dst_idx], src[pos]);
if (dst_axis_idx >= 0 && dst_axis_idx < dst_axis_size) {
const size_t dst_inner_idx = pos % inner_size;
const size_t dst_idx = src_outer_idx * (dst_axis_size * inner_size) + dst_axis_idx * inner_size + dst_inner_idx;
MsAtomicAdd(&dst[dst_idx], src[pos]);
}
}
return;
}
@ -55,20 +41,15 @@ __global__ void IndexAdd(T *dst, const int *index, const T *src, const size_t sr
const size_t src_axis_idx = (pos / inner_size) % src_axis_size;
const size_t src_outer_idx = pos / (src_axis_size * inner_size);
const size_t dst_axis_idx = static_cast<size_t>(index[src_axis_idx]);
const size_t dst_inner_idx = pos % inner_size;
const size_t dst_idx = src_outer_idx * (dst_axis_size * inner_size) + dst_axis_idx * inner_size + dst_inner_idx;
dst[dst_idx] += src[pos];
if (dst_axis_idx >= 0 && dst_axis_idx < dst_axis_size) {
const size_t dst_inner_idx = pos % inner_size;
const size_t dst_idx = src_outer_idx * (dst_axis_size * inner_size) + dst_axis_idx * inner_size + dst_inner_idx;
dst[dst_idx] += src[pos];
}
}
return;
}
void ValidateIndexAddInputValues(const int *index, const size_t src_axis_size, const size_t dst_axis_size,
IndexAddErrorCode *error_code, cudaStream_t cuda_stream) {
InitErrorCode<<<1, 1, 0, cuda_stream>>>(error_code);
ValidateIndexValues<<<GET_BLOCKS(src_axis_size), GET_THREADS, 0, cuda_stream>>>(index, src_axis_size, dst_axis_size,
error_code);
}
template <typename T>
void CalIndexAdd(T *dst, const int *index, const T *src, const size_t outer_size, const size_t src_axis_size,
const size_t dst_axis_size, const size_t inner_size, const bool use_lock, cudaStream_t cuda_stream) {

View File

@ -16,16 +16,7 @@
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_INDEXADD_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_INDEXADD_H_
enum class IndexAddErrorCode {
kOk = 0,
kIndexOutOfRange
};
void ValidateIndexAddInputValues(const int *index, const size_t src_axis_size, const size_t dst_axis_size,
IndexAddErrorCode *error_code, cudaStream_t cuda_stream);
template <typename T>
void CalIndexAdd(T *dst, const int *index, const T *src, const size_t outer_size, const size_t src_axis_size,
const size_t dst_axis_size, const size_t inner_size, const bool use_lock, cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_INDEXADD_H_

View File

@ -35,8 +35,7 @@ class IndexAddGpuKernel : public GpuKernel {
src_axis_size_(0),
dst_axis_size_(0),
inner_size_(0),
use_lock_(true),
check_index_bound_(true) {}
use_lock_(true) {}
~IndexAddGpuKernel() override = default;
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
@ -49,19 +48,6 @@ class IndexAddGpuKernel : public GpuKernel {
int *index = GetDeviceAddress<int>(inputs, 1);
T *src = GetDeviceAddress<T>(inputs, 2);
T *dst_out = GetDeviceAddress<T>(outputs, 0);
if (check_index_bound_) {
IndexAddErrorCode *error_code_addr = GetDeviceAddress<IndexAddErrorCode>(workspace, 0);
IndexAddErrorCode error_code = IndexAddErrorCode::kOk;
ValidateIndexAddInputValues(index, src_axis_size_, dst_axis_size_, error_code_addr,
reinterpret_cast<cudaStream_t>(stream_ptr));
CHECK_CUDA_RET_WITH_ERROR(kernel_node_,
cudaMemcpyAsync(&error_code, error_code_addr, sizeof(IndexAddErrorCode),
cudaMemcpyDeviceToHost, reinterpret_cast<cudaStream_t>(stream_ptr)),
"Failed to copy error code to host.");
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, cudaDeviceSynchronize(), "cudaDeviceSyncFailed");
LogExceptionIfNotOk(error_code);
}
CalIndexAdd(dst, index, src, outer_size_, src_axis_size_, dst_axis_size_, inner_size_, use_lock_,
reinterpret_cast<cudaStream_t>(stream_ptr));
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
@ -119,22 +105,9 @@ class IndexAddGpuKernel : public GpuKernel {
input_size_list_.push_back(index_size_);
input_size_list_.push_back(src_size_);
output_size_list_.push_back(output_size_);
workspace_size_list_.push_back(sizeof(IndexAddErrorCode));
}
private:
void LogExceptionIfNotOk(IndexAddErrorCode error_code) {
switch (error_code) {
case IndexAddErrorCode::kOk:
return;
case IndexAddErrorCode::kIndexOutOfRange:
MS_LOG(EXCEPTION) << "gpu IndexAdd op error: values of index tensor is out of range";
break;
default:
MS_LOG(EXCEPTION) << "gpu IndexAdd op unknown error";
}
}
size_t dst_size_;
size_t index_size_;
size_t src_size_;
@ -144,7 +117,6 @@ class IndexAddGpuKernel : public GpuKernel {
size_t dst_axis_size_;
size_t inner_size_;
bool use_lock_;
bool check_index_bound_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;

View File

@ -255,13 +255,6 @@ def test_index_add_invalid_inputs():
net = NetIndexAdd(x, 1)
_ = net(Tensor(idx), Tensor(y))
with pytest.raises(RuntimeError) as info:
#index value not in the range of 0 to len(x[axis])
idx = np.array([5, 6]).astype(np.int32)
net = NetIndexAdd(x, 1)
_ = net(Tensor(idx), Tensor(y))
assert "out of range" in str(info.value)
class IndexAddGradNet(nn.Cell):
def __init__(self, network):