add gpu buffer kernel

This commit is contained in:
VectorSL 2021-08-09 10:02:39 +08:00 committed by vector
parent c84e09cf45
commit 6b4f02af58
18 changed files with 864 additions and 36 deletions

View File

@ -967,5 +967,39 @@ size_t CalOffset(const std::vector<int64_t> &start, const std::vector<int64_t> &
} }
return offset; return offset;
} }
size_t UnitSizeInBytes(const mindspore::TypeId &t) {
size_t bytes = 0;
switch (t) {
case kNumberTypeBool:
case kNumberTypeInt8:
case kNumberTypeUInt8:
bytes = sizeof(int8_t);
break;
case kNumberTypeInt16:
case kNumberTypeUInt16:
case kNumberTypeFloat16:
bytes = sizeof(int16_t);
break;
case kNumberTypeInt:
case kNumberTypeUInt:
case kNumberTypeInt32:
case kNumberTypeUInt32:
case kNumberTypeFloat:
case kNumberTypeFloat32:
bytes = sizeof(int32_t);
break;
case kNumberTypeUInt64:
case kNumberTypeInt64:
case kNumberTypeFloat64:
bytes = sizeof(int64_t);
break;
default:
MS_LOG(EXCEPTION) << "Invalid types " << t;
break;
}
return bytes;
}
} // namespace kernel } // namespace kernel
} // namespace mindspore } // namespace mindspore

View File

@ -143,6 +143,7 @@ size_t CalOffset(const std::vector<int64_t> &start, const std::vector<int64_t> &
std::vector<int64_t> CalDimOffset(const std::vector<int64_t> &input_shape); std::vector<int64_t> CalDimOffset(const std::vector<int64_t> &input_shape);
size_t GetCopySize(const std::vector<int64_t> &dim_offset, const std::vector<int64_t> &start, size_t GetCopySize(const std::vector<int64_t> &dim_offset, const std::vector<int64_t> &start,
const std::vector<int64_t> &stop); const std::vector<int64_t> &stop);
size_t UnitSizeInBytes(const mindspore::TypeId &t);
} // namespace kernel } // namespace kernel
} // namespace mindspore } // namespace mindspore

View File

@ -0,0 +1,121 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/cuda_impl/rl/rl_buffer_impl.cuh"
__global__ void BufferAppendKernel(const int64_t capacity, const size_t size, const int *index, const int exp_batch,
unsigned char *buffer, const unsigned char *exp) {
size_t index_ = index[0];
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
if (i >= size / exp_batch * (capacity - index[0])) {
index_ = i - size / exp_batch * (capacity - index[0]);
} else {
index_ = i + index[0] * size / exp_batch;
}
buffer[index_] = exp[i];
}
}
__global__ void IncreaseCountKernel(const int64_t capacity, const int exp_batch, int *count, int *head, int *index) {
int index_ = 0;
if (count[0] <= capacity - 1 && head[0] == 0) {
index_ = count[0];
count[0] += exp_batch;
if (count[0] > capacity) {
count[0] = capacity;
head[0] = (exp_batch + count[0] - capacity) % capacity;
}
} else {
index_ = head[0];
if (head[0] == count[0])
head[0] = 0;
else
head[0] = (exp_batch + head[0]) % capacity;
}
index[0] = index_;
}
__global__ void ReMappingIndexKernel(const int *count, const int *head, const int *origin_index, int *index) {
index[0] = origin_index[0];
if (index[0] < 0) {
index[0] += count[0];
}
if (!(index[0] >= 0 && index[0] < count[0])) {
printf("[ERROR] The index %d is out of range:[%d, %d).", origin_index[0], -1 * count[0], count[0]);
index[0] = -1;
return;
}
int t = count[0] - head[0];
if (index[0] < t) {
index[0] += head[0];
} else {
index[0] -= t;
}
}
__global__ void BufferGetItemKernel(const size_t size, const int *index, const size_t one_exp_len,
const unsigned char *buffer, unsigned char *out) {
if (index[0] == -1) {
return;
}
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
out[i] = buffer[i + index[0] * one_exp_len];
}
}
__global__ void CheckBatchSizeKernel(const int *count, const int *head, const size_t batch_size,
const int64_t capacity) {
if ((head[0] > 0 && int64_t(batch_size) > capacity) || (head[0] == 0 && batch_size > size_t(count[0]))) {
printf("[ERROR] The batch size %d is larger than total buffer size %d", static_cast<int>(batch_size),
(capacity > static_cast<int64_t>(count[0]) ? static_cast<int>(count[0]) : static_cast<int>(capacity)));
}
}
__global__ void BufferSampleKernel(const size_t size, const size_t one_element, const int *index,
const unsigned char *buffer, unsigned char *out) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
out[i] = buffer[index[i / one_element] * one_element + i % one_element];
}
}
void BufferAppend(const int64_t capacity, const size_t size, const int *index, const int exp_batch,
unsigned char *buffer, const unsigned char *exp, cudaStream_t cuda_stream) {
BufferAppendKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(capacity, size, index, exp_batch, buffer, exp);
}
void IncreaseCount(const int64_t capacity, const int exp_batch, int *count, int *head, int *index,
cudaStream_t cuda_stream) {
IncreaseCountKernel<<<1, 1, 0, cuda_stream>>>(capacity, exp_batch, count, head, index);
}
void ReMappingIndex(const int *count, const int *head, const int *origin_index, int *index, cudaStream_t cuda_stream) {
ReMappingIndexKernel<<<1, 1, 0, cuda_stream>>>(count, head, origin_index, index);
}
void BufferGetItem(const size_t size, const int *index, const size_t one_exp_len, const unsigned char *buffer,
unsigned char *out, cudaStream_t cuda_stream) {
BufferGetItemKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, index, one_exp_len, buffer, out);
}
void CheckBatchSize(const int *count, const int *head, const size_t batch_size, const int64_t capacity,
cudaStream_t cuda_stream) {
CheckBatchSizeKernel<<<1, 1, 0, cuda_stream>>>(count, head, batch_size, capacity);
}
void BufferSample(const size_t size, const size_t one_element, const int *index, const unsigned char *buffer,
unsigned char *out, cudaStream_t cuda_stream) {
BufferSampleKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, one_element, index, buffer, out);
}

View File

@ -0,0 +1,33 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_RL_BUFFER_IMPL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_RL_BUFFER_IMPL_H_
#include "runtime/device/gpu/cuda_common.h"
void BufferAppend(const int64_t capacity, const size_t size, const int *index, const int exp_batch,
unsigned char *buffer, const unsigned char *exp, cudaStream_t cuda_stream);
void IncreaseCount(const int64_t capacity, const int exp_batch, int *count, int *head, int *index,
cudaStream_t cuda_stream);
void ReMappingIndex(const int *count, const int *head, const int *origin_index, int *index, cudaStream_t cuda_stream);
void BufferGetItem(const size_t size, const int *index, const size_t one_exp_len, const unsigned char *buffer,
unsigned char *out, cudaStream_t cuda_stream);
void CheckBatchSize(const int *count, const int *head, const size_t batch_size, const int64_t capacity,
cudaStream_t cuda_stream);
void BufferSample(const size_t size, const size_t one_element, const int *index, const unsigned char *buffer,
unsigned char *out, cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_ADAM_IMPL_H_

View File

@ -17,6 +17,7 @@
#include "backend/kernel_compiler/gpu/data/dataset_init_kernel.h" #include "backend/kernel_compiler/gpu/data/dataset_init_kernel.h"
#include <algorithm> #include <algorithm>
#include "backend/kernel_compiler/gpu/data/dataset_utils.h" #include "backend/kernel_compiler/gpu/data/dataset_utils.h"
#include "backend/kernel_compiler/common_utils.h"
#include "runtime/device/gpu/gpu_buffer_mgr.h" #include "runtime/device/gpu/gpu_buffer_mgr.h"
#include "runtime/device/gpu/gpu_memory_allocator.h" #include "runtime/device/gpu/gpu_memory_allocator.h"
#include "utils/convert_utils.h" #include "utils/convert_utils.h"

View File

@ -21,6 +21,7 @@
#include <vector> #include <vector>
#include <algorithm> #include <algorithm>
#include "backend/kernel_compiler/gpu/data/dataset_utils.h" #include "backend/kernel_compiler/gpu/data/dataset_utils.h"
#include "backend/kernel_compiler/common_utils.h"
#include "profiler/device/gpu/gpu_profiling.h" #include "profiler/device/gpu/gpu_profiling.h"
#include "runtime/device/gpu/gpu_buffer_mgr.h" #include "runtime/device/gpu/gpu_buffer_mgr.h"
#include "runtime/device/gpu/gpu_common.h" #include "runtime/device/gpu/gpu_common.h"

View File

@ -20,40 +20,6 @@
namespace mindspore { namespace mindspore {
namespace kernel { namespace kernel {
size_t UnitSizeInBytes(const mindspore::TypeId &t) {
size_t bytes = 0;
switch (t) {
case kNumberTypeBool:
case kNumberTypeInt8:
case kNumberTypeUInt8:
bytes = sizeof(int8_t);
break;
case kNumberTypeInt16:
case kNumberTypeUInt16:
case kNumberTypeFloat16:
bytes = sizeof(int16_t);
break;
case kNumberTypeInt:
case kNumberTypeUInt:
case kNumberTypeInt32:
case kNumberTypeUInt32:
case kNumberTypeFloat:
case kNumberTypeFloat32:
bytes = sizeof(int32_t);
break;
case kNumberTypeUInt64:
case kNumberTypeInt64:
case kNumberTypeFloat64:
bytes = sizeof(int64_t);
break;
default:
MS_LOG(EXCEPTION) << "Invalid types " << t;
break;
}
return bytes;
}
int ElementNums(const std::vector<int> &shape) { int ElementNums(const std::vector<int> &shape) {
if (shape.size() == 0) { if (shape.size() == 0) {
return 0; return 0;

View File

@ -21,7 +21,6 @@
#include "ir/dtype/type.h" #include "ir/dtype/type.h"
namespace mindspore { namespace mindspore {
namespace kernel { namespace kernel {
size_t UnitSizeInBytes(const mindspore::TypeId &t);
int ElementNums(const std::vector<int> &shape); int ElementNums(const std::vector<int> &shape);
void GetShapeAndType(const CNodePtr &kernel_node, std::vector<std::vector<int>> *shapes, std::vector<TypePtr> *types); void GetShapeAndType(const CNodePtr &kernel_node, std::vector<std::vector<int>> *shapes, std::vector<TypePtr> *types);
} // namespace kernel } // namespace kernel

View File

@ -27,7 +27,7 @@
#include "ir/tensor.h" #include "ir/tensor.h"
#include "backend/kernel_compiler/gpu/gpu_kernel.h" #include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/data/dataset_utils.h" #include "backend/kernel_compiler/common_utils.h"
using mindspore::tensor::Tensor; using mindspore::tensor::Tensor;

View File

@ -0,0 +1,86 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/rl/buffer_append_gpu_kernel.h"
#include <memory>
#include <string>
#include <vector>
#include <algorithm>
#include "backend/kernel_compiler/common_utils.h"
#include "backend/kernel_compiler/gpu/cuda_impl/rl/rl_buffer_impl.cuh"
#include "runtime/device/gpu/gpu_common.h"
namespace mindspore {
namespace kernel {
BufferAppendKernel::BufferAppendKernel() : element_nums_(0), exp_batch_(0), capacity_(0) {}
BufferAppendKernel::~BufferAppendKernel() {}
void BufferAppendKernel::ReleaseResource() {}
const std::vector<size_t> &BufferAppendKernel::GetInputSizeList() const { return input_size_list_; }
const std::vector<size_t> &BufferAppendKernel::GetOutputSizeList() const { return output_size_list_; }
const std::vector<size_t> &BufferAppendKernel::GetWorkspaceSizeList() const { return workspace_size_list_; }
bool BufferAppendKernel::Init(const CNodePtr &kernel_node) {
kernel_node_ = kernel_node;
auto shapes = GetAttr<std::vector<int64_t>>(kernel_node, "buffer_elements");
auto types = GetAttr<std::vector<TypePtr>>(kernel_node, "buffer_dtype");
capacity_ = GetAttr<int64_t>(kernel_node, "capacity");
exp_batch_ = GetAttr<int64_t>(kernel_node, "exp_batch");
element_nums_ = shapes.size();
for (size_t i = 0; i < element_nums_; i++) {
exp_element_list.push_back(shapes[i] * UnitSizeInBytes(types[i]->type_id()));
}
// buffer size
for (auto i : exp_element_list) {
input_size_list_.push_back(i * capacity_);
}
// exp size
for (auto i : exp_element_list) {
input_size_list_.push_back(i * exp_batch_);
}
// count and head
input_size_list_.push_back(sizeof(int));
input_size_list_.push_back(sizeof(int));
output_size_list_.push_back(0);
workspace_size_list_.push_back(sizeof(int));
return true;
}
void BufferAppendKernel::InitSizeLists() { return; }
bool BufferAppendKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &, void *stream) {
int *count_addr = GetDeviceAddress<int>(inputs, 2 * element_nums_);
int *head_addr = GetDeviceAddress<int>(inputs, 2 * element_nums_ + 1);
int *index_addr = GetDeviceAddress<int>(workspace, 0);
auto cuda_stream = reinterpret_cast<cudaStream_t>(stream);
IncreaseCount(capacity_, LongToInt(exp_batch_), count_addr, head_addr, index_addr, cuda_stream);
for (size_t i = 0; i < element_nums_; i++) {
auto buffer_addr = GetDeviceAddress<unsigned char>(inputs, i);
auto exp_addr = GetDeviceAddress<unsigned char>(inputs, i + element_nums_);
size_t one_exp_len = input_size_list_[i + element_nums_];
BufferAppend(capacity_, one_exp_len, index_addr, LongToInt(exp_batch_), buffer_addr, exp_addr, cuda_stream);
}
return true;
}
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,58 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_RL_BUFFER_APPEND_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_RL_BUFFER_APPEND_GPU_KERNEL_H_
#include <memory>
#include <string>
#include <vector>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
namespace mindspore {
namespace kernel {
class BufferAppendKernel : public GpuKernel {
public:
BufferAppendKernel();
~BufferAppendKernel();
const std::vector<size_t> &GetInputSizeList() const override;
const std::vector<size_t> &GetOutputSizeList() const override;
const std::vector<size_t> &GetWorkspaceSizeList() const override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override;
bool Init(const CNodePtr &kernel_node) override;
void ReleaseResource() override;
protected:
void InitSizeLists() override;
private:
size_t element_nums_;
int64_t exp_batch_;
int64_t capacity_;
std::vector<size_t> exp_element_list;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
};
MS_REG_GPU_KERNEL(BufferAppend, BufferAppendKernel)
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_RL_BUFFER_APPEND_GPU_KERNEL_H_

View File

@ -0,0 +1,87 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/rl/buffer_get_gpu_kernel.h"
#include <memory>
#include <string>
#include <vector>
#include <algorithm>
#include "backend/kernel_compiler/common_utils.h"
#include "backend/kernel_compiler/gpu/cuda_impl/rl/rl_buffer_impl.cuh"
#include "runtime/device/gpu/gpu_common.h"
namespace mindspore {
namespace kernel {
BufferGetKernel::BufferGetKernel() : element_nums_(0), capacity_(0) {}
BufferGetKernel::~BufferGetKernel() {}
void BufferGetKernel::ReleaseResource() {}
const std::vector<size_t> &BufferGetKernel::GetInputSizeList() const { return input_size_list_; }
const std::vector<size_t> &BufferGetKernel::GetOutputSizeList() const { return output_size_list_; }
const std::vector<size_t> &BufferGetKernel::GetWorkspaceSizeList() const { return workspace_size_list_; }
bool BufferGetKernel::Init(const CNodePtr &kernel_node) {
kernel_node_ = kernel_node;
auto shapes = GetAttr<std::vector<int64_t>>(kernel_node, "buffer_elements");
auto types = GetAttr<std::vector<TypePtr>>(kernel_node, "buffer_dtype");
capacity_ = GetAttr<int64_t>(kernel_node, "capacity");
element_nums_ = shapes.size();
for (size_t i = 0; i < element_nums_; i++) {
exp_element_list.push_back(shapes[i] * UnitSizeInBytes(types[i]->type_id()));
}
// buffer size
for (auto i : exp_element_list) {
input_size_list_.push_back(i * capacity_);
output_size_list_.push_back(i);
}
// count, head, index
input_size_list_.push_back(sizeof(int));
input_size_list_.push_back(sizeof(int));
input_size_list_.push_back(sizeof(int));
workspace_size_list_.push_back(sizeof(int));
return true;
}
void BufferGetKernel::InitSizeLists() { return; }
bool BufferGetKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream) {
int *count_addr = GetDeviceAddress<int>(inputs, element_nums_);
int *head_addr = GetDeviceAddress<int>(inputs, element_nums_ + 1);
int *origin_index_addr = GetDeviceAddress<int>(inputs, element_nums_ + 2);
int *index_addr = GetDeviceAddress<int>(workspace, 0);
auto cuda_stream = reinterpret_cast<cudaStream_t>(stream);
ReMappingIndex(count_addr, head_addr, origin_index_addr, index_addr, cuda_stream);
cudaError_t error = cudaGetLastError();
if (error != cudaSuccess) {
return false;
}
for (size_t i = 0; i < element_nums_; i++) {
auto buffer_addr = GetDeviceAddress<unsigned char>(inputs, i);
auto item_addr = GetDeviceAddress<unsigned char>(outputs, i);
size_t one_exp_len = output_size_list_[i];
BufferGetItem(one_exp_len, index_addr, one_exp_len, buffer_addr, item_addr, cuda_stream);
}
return true;
}
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,57 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_RL_BUFFER_GET_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_RL_BUFFER_GET_GPU_KERNEL_H_
#include <memory>
#include <string>
#include <vector>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
namespace mindspore {
namespace kernel {
class BufferGetKernel : public GpuKernel {
public:
BufferGetKernel();
~BufferGetKernel();
const std::vector<size_t> &GetInputSizeList() const override;
const std::vector<size_t> &GetOutputSizeList() const override;
const std::vector<size_t> &GetWorkspaceSizeList() const override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override;
bool Init(const CNodePtr &kernel_node) override;
void ReleaseResource() override;
protected:
void InitSizeLists() override;
private:
size_t element_nums_;
int64_t capacity_;
std::vector<size_t> exp_element_list;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
};
MS_REG_GPU_KERNEL(BufferGetItem, BufferGetKernel)
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_RL_BUFFER_GET_GPU_KERNEL_H_

View File

@ -0,0 +1,81 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/rl/buffer_sample_gpu_kernel.h"
#include <memory>
#include <string>
#include <vector>
#include <algorithm>
#include "backend/kernel_compiler/common_utils.h"
#include "backend/kernel_compiler/gpu/cuda_impl/rl/rl_buffer_impl.cuh"
#include "runtime/device/gpu/gpu_common.h"
namespace mindspore {
namespace kernel {
BufferSampleKernel::BufferSampleKernel() : element_nums_(0), capacity_(0), batch_size_(0) {}
BufferSampleKernel::~BufferSampleKernel() {}
void BufferSampleKernel::ReleaseResource() {}
const std::vector<size_t> &BufferSampleKernel::GetInputSizeList() const { return input_size_list_; }
const std::vector<size_t> &BufferSampleKernel::GetOutputSizeList() const { return output_size_list_; }
const std::vector<size_t> &BufferSampleKernel::GetWorkspaceSizeList() const { return workspace_size_list_; }
bool BufferSampleKernel::Init(const CNodePtr &kernel_node) {
kernel_node_ = kernel_node;
auto shapes = GetAttr<std::vector<int64_t>>(kernel_node, "buffer_elements");
auto types = GetAttr<std::vector<TypePtr>>(kernel_node, "buffer_dtype");
capacity_ = GetAttr<int64_t>(kernel_node, "capacity");
batch_size_ = LongToSize(GetAttr<int64_t>(kernel_node, "batch_size"));
element_nums_ = shapes.size();
for (size_t i = 0; i < element_nums_; i++) {
auto element = shapes[i] * UnitSizeInBytes(types[i]->type_id());
exp_element_list.push_back(element);
input_size_list_.push_back(capacity_ * element);
output_size_list_.push_back(batch_size_ * element);
}
// index
input_size_list_.push_back(sizeof(int) * batch_size_);
// count and head
input_size_list_.push_back(sizeof(int));
input_size_list_.push_back(sizeof(int));
return true;
}
void BufferSampleKernel::InitSizeLists() { return; }
bool BufferSampleKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs, void *stream) {
int *index_addr = GetDeviceAddress<int>(inputs, element_nums_);
int *count_addr = GetDeviceAddress<int>(inputs, element_nums_ + 1);
int *head_addr = GetDeviceAddress<int>(inputs, element_nums_ + 2);
auto cuda_stream = reinterpret_cast<cudaStream_t>(stream);
CheckBatchSize(count_addr, head_addr, batch_size_, capacity_, cuda_stream);
for (size_t i = 0; i < element_nums_; i++) {
auto buffer_addr = GetDeviceAddress<unsigned char>(inputs, i);
auto out_addr = GetDeviceAddress<unsigned char>(outputs, i);
size_t size = batch_size_ * exp_element_list[i];
BufferSample(size, exp_element_list[i], index_addr, buffer_addr, out_addr, cuda_stream);
}
return true;
}
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,58 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_RL_BUFFER_SAMPLE_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_RL_BUFFER_SAMPLE_GPU_KERNEL_H_
#include <memory>
#include <string>
#include <vector>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
namespace mindspore {
namespace kernel {
class BufferSampleKernel : public GpuKernel {
public:
BufferSampleKernel();
~BufferSampleKernel();
const std::vector<size_t> &GetInputSizeList() const override;
const std::vector<size_t> &GetOutputSizeList() const override;
const std::vector<size_t> &GetWorkspaceSizeList() const override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override;
bool Init(const CNodePtr &kernel_node) override;
void ReleaseResource() override;
protected:
void InitSizeLists() override;
private:
size_t element_nums_;
int64_t capacity_;
size_t batch_size_;
std::vector<size_t> exp_element_list;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
};
MS_REG_GPU_KERNEL(BufferSample, BufferSampleKernel)
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_RL_BUFFER_SAMPLE_GPU_KERNEL_H_

View File

@ -17,6 +17,7 @@
#include "backend/kernel_compiler/gpu/data/dataset_utils.h" #include "backend/kernel_compiler/gpu/data/dataset_utils.h"
#include "backend/kernel_compiler/gpu/trt/trt_utils.h" #include "backend/kernel_compiler/gpu/trt/trt_utils.h"
#include "backend/kernel_compiler/common_utils.h"
#include "runtime/device/gpu/trt_loader.h" #include "runtime/device/gpu/trt_loader.h"
namespace mindspore { namespace mindspore {

View File

@ -0,0 +1,87 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common.api import ms_function
from mindspore.common.parameter import Parameter
from mindspore.ops import operations as P
import mindspore as ms
def create_tensor(capcity, shapes, dtypes):
buffer = []
for i in range(len(shapes)):
buffer.append(Tensor(np.zeros(((capcity,)+shapes[i])), dtypes[i]))
return buffer
class RLBuffer(nn.Cell):
def __init__(self, batch_size, capcity, shapes, types):
super(RLBuffer, self).__init__()
self.buffer = create_tensor(capcity, shapes, types)
self._capacity = capcity
self.count = Parameter(Tensor(0, ms.int32), name="count")
self.head = Parameter(Tensor(0, ms.int32), name="head")
self.buffer_append = P.BufferAppend(self._capacity, shapes, types)
self.buffer_get = P.BufferGetItem(self._capacity, shapes, types)
self.buffer_sample = P.BufferSample(
self._capacity, batch_size, shapes, types)
self.randperm = P.Randperm(max_length=capcity, pad=-1)
self.reshape = P.Reshape()
@ms_function
def append(self, exps):
return self.buffer_append(self.buffer, exps, self.count, self.head)
@ms_function
def get(self, index):
return self.buffer_get(self.buffer, self.count, self.head, index)
@ms_function
def sample(self):
count = self.reshape(self.count, (1,))
index = self.randperm(count)
return self.buffer_sample(self.buffer, index, self.count, self.head)
s = Tensor(np.array([2, 2, 2, 2]), ms.float32)
a = Tensor(np.array([0, 1]), ms.int32)
r = Tensor(np.array([1]), ms.float32)
s_ = Tensor(np.array([3, 3, 3, 3]), ms.float32)
exp = [s, a, r, s_]
exp1 = [s_, a, r, s]
@ pytest.mark.level0
@ pytest.mark.platform_x86_gpu_training
@ pytest.mark.env_onecard
def test_Buffer():
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
buffer = RLBuffer(batch_size=32, capcity=100, shapes=[(4,), (2,), (1,), (4,)], types=[
ms.float32, ms.int32, ms.float32, ms.float32])
print("init buffer:\n", buffer.buffer)
for _ in range(0, 110):
buffer.append(exp)
buffer.append(exp1)
print("buffer append:\n", buffer.buffer)
b = buffer.get(-1)
print("buffer get:\n", b)
bs = buffer.sample()
print("buffer sample:\n", bs)

View File

@ -0,0 +1,157 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common.api import ms_function
from mindspore.common.parameter import Parameter
from mindspore.ops import operations as P
import mindspore as ms
class RLBufferAppend(nn.Cell):
def __init__(self, capcity, shapes, types):
super(RLBufferAppend, self).__init__()
self._capacity = capcity
self.count = Parameter(Tensor(0, ms.int32), name="count")
self.head = Parameter(Tensor(0, ms.int32), name="head")
self.buffer_append = P.BufferAppend(self._capacity, shapes, types)
@ms_function
def construct(self, buffer, exps):
return self.buffer_append(buffer, exps, self.count, self.head)
class RLBufferGet(nn.Cell):
def __init__(self, capcity, shapes, types):
super(RLBufferGet, self).__init__()
self._capacity = capcity
self.count = Parameter(Tensor(5, ms.int32), name="count")
self.head = Parameter(Tensor(0, ms.int32), name="head")
self.buffer_get = P.BufferGetItem(self._capacity, shapes, types)
@ms_function
def construct(self, buffer, index):
return self.buffer_get(buffer, self.count, self.head, index)
class RLBufferSample(nn.Cell):
def __init__(self, capcity, batch_size, shapes, types):
super(RLBufferSample, self).__init__()
self._capacity = capcity
count = 5
self.count = Parameter(Tensor(5, ms.int32), name="count")
self.head = Parameter(Tensor(0, ms.int32), name="head")
self.input_x = Tensor(np.ones(shape=[count]), ms.bool_)
self.buffer_sample = P.BufferSample(
self._capacity, batch_size, shapes, types)
self.index = Parameter(Tensor([0, 2, 4], ms.int32), name="index")
@ms_function
def construct(self, buffer):
return self.buffer_sample(buffer, self.index, self.count, self.head)
states = Tensor(np.arange(4*5).reshape(5, 4).astype(np.float32)/10.0)
actions = Tensor(np.arange(2*5).reshape(5, 2).astype(np.int32))
rewards = Tensor(np.ones((5, 1)).astype(np.int32))
states_ = Tensor(np.arange(4*5).reshape(5, 4).astype(np.float32))
b = [states, actions, rewards, states_]
s = Tensor(np.array([2, 2, 2, 2]), ms.float32)
a = Tensor(np.array([0, 0]), ms.int32)
r = Tensor(np.array([0]), ms.int32)
s_ = Tensor(np.array([3, 3, 3, 3]), ms.float32)
exp = [s, a, r, s_]
exp1 = [s_, a, r, s]
c = [Tensor(np.array([[6, 6, 6, 6], [6, 6, 6, 6]]), ms.float32),
Tensor(np.array([[6, 6], [6, 6]]), ms.int32),
Tensor(np.array([[6], [6]]), ms.int32),
Tensor(np.array([[6, 6, 6, 6], [6, 6, 6, 6]]), ms.float32)]
@ pytest.mark.level0
@ pytest.mark.platform_x86_gpu_training
@ pytest.mark.env_onecard
def test_BufferSample():
context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
buffer_sample = RLBufferSample(capcity=5, batch_size=3, shapes=[(4,), (2,), (1,), (4,)], types=[
ms.float32, ms.int32, ms.int32, ms.float32])
ss, aa, rr, ss_ = buffer_sample(b)
expect_s = [[0, 0.1, 0.2, 0.3], [0.8, 0.9, 1.0, 1.1], [1.6, 1.7, 1.8, 1.9]]
expect_a = [[0, 1], [4, 5], [8, 9]]
expect_r = [[1], [1], [1]]
expect_s_ = [[0, 1, 2, 3], [8, 9, 10, 11], [16, 17, 18, 19]]
np.testing.assert_almost_equal(ss.asnumpy(), expect_s)
np.testing.assert_almost_equal(aa.asnumpy(), expect_a)
np.testing.assert_almost_equal(rr.asnumpy(), expect_r)
np.testing.assert_almost_equal(ss_.asnumpy(), expect_s_)
@ pytest.mark.level0
@ pytest.mark.platform_x86_gpu_training
@ pytest.mark.env_onecard
def test_BufferGet():
context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
buffer_get = RLBufferGet(capcity=5, shapes=[(4,), (2,), (1,), (4,)], types=[
ms.float32, ms.int32, ms.int32, ms.float32])
ss, aa, rr, ss_ = buffer_get(b, 1)
expect_s = [0.4, 0.5, 0.6, 0.7]
expect_a = [2, 3]
expect_r = [1]
expect_s_ = [4, 5, 6, 7]
np.testing.assert_almost_equal(ss.asnumpy(), expect_s)
np.testing.assert_almost_equal(aa.asnumpy(), expect_a)
np.testing.assert_almost_equal(rr.asnumpy(), expect_r)
np.testing.assert_almost_equal(ss_.asnumpy(), expect_s_)
@ pytest.mark.level0
@ pytest.mark.platform_x86_gpu_training
@ pytest.mark.env_onecard
def test_BufferAppend():
context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
buffer_append = RLBufferAppend(capcity=5, shapes=[(4,), (2,), (1,), (4,)], types=[
ms.float32, ms.int32, ms.int32, ms.float32])
buffer_append(b, exp)
buffer_append(b, exp)
buffer_append(b, exp)
buffer_append(b, exp)
buffer_append(b, exp)
buffer_append(b, exp1)
expect_s = [[3, 3, 3, 3], [2, 2, 2, 2], [2, 2, 2, 2], [2, 2, 2, 2], [2, 2, 2, 2]]
expect_a = [[0, 0], [0, 0], [0, 0], [0, 0], [0, 0]]
expect_r = [[0], [0], [0], [0], [0]]
expect_s_ = [[2, 2, 2, 2], [3, 3, 3, 3], [3, 3, 3, 3], [3, 3, 3, 3], [3, 3, 3, 3]]
np.testing.assert_almost_equal(b[0].asnumpy(), expect_s)
np.testing.assert_almost_equal(b[1].asnumpy(), expect_a)
np.testing.assert_almost_equal(b[2].asnumpy(), expect_r)
np.testing.assert_almost_equal(b[3].asnumpy(), expect_s_)
buffer_append(b, exp1)
buffer_append(b, c)
buffer_append(b, c)
expect_s2 = [[6, 6, 6, 6], [3, 3, 3, 3], [6, 6, 6, 6], [6, 6, 6, 6], [6, 6, 6, 6]]
expect_a2 = [[6, 6], [0, 0], [6, 6], [6, 6], [6, 6]]
expect_r2 = [[6], [0], [6], [6], [6]]
expect_s2_ = [[6, 6, 6, 6], [2, 2, 2, 2], [6, 6, 6, 6], [6, 6, 6, 6], [6, 6, 6, 6]]
np.testing.assert_almost_equal(b[0].asnumpy(), expect_s2)
np.testing.assert_almost_equal(b[1].asnumpy(), expect_a2)
np.testing.assert_almost_equal(b[2].asnumpy(), expect_r2)
np.testing.assert_almost_equal(b[3].asnumpy(), expect_s2_)