forked from mindspore-Ecosystem/mindspore

add gpu buffer kernel

This commit is contained in:
parent c84e09cf45
commit 6b4f02af58
backend/kernel_compiler/common_utils.cc:
@@ -967,5 +967,39 @@ size_t CalOffset(const std::vector<int64_t> &start, const std::vector<int64_t> &
   }
   return offset;
 }
+
+size_t UnitSizeInBytes(const mindspore::TypeId &t) {
+  size_t bytes = 0;
+  switch (t) {
+    case kNumberTypeBool:
+    case kNumberTypeInt8:
+    case kNumberTypeUInt8:
+      bytes = sizeof(int8_t);
+      break;
+    case kNumberTypeInt16:
+    case kNumberTypeUInt16:
+    case kNumberTypeFloat16:
+      bytes = sizeof(int16_t);
+      break;
+    case kNumberTypeInt:
+    case kNumberTypeUInt:
+    case kNumberTypeInt32:
+    case kNumberTypeUInt32:
+    case kNumberTypeFloat:
+    case kNumberTypeFloat32:
+      bytes = sizeof(int32_t);
+      break;
+    case kNumberTypeUInt64:
+    case kNumberTypeInt64:
+    case kNumberTypeFloat64:
+      bytes = sizeof(int64_t);
+      break;
+    default:
+      MS_LOG(EXCEPTION) << "Invalid types " << t;
+      break;
+  }
+
+  return bytes;
+}
 }  // namespace kernel
 }  // namespace mindspore

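UnitSizeInBytes moves into common_utils so the new RL buffer kernels below can share it with the dataset kernels. As a quick illustration of the calling pattern the kernels use when sizing buffers (a hedged sketch: the shape, the dtype value, and the function name here are illustrative, not part of this diff):

// Sketch (hypothetical caller): bytes needed for one experience element of
// shape (4,) stored as float32 -- the pattern the kernels' Init() methods use.
size_t ElementBytesExample() {
  size_t element_count = 4;  // e.g. a state vector of four floats
  return element_count * UnitSizeInBytes(kNumberTypeFloat32);  // 4 * sizeof(int32_t) = 16 bytes
}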
backend/kernel_compiler/common_utils.h:
@@ -143,6 +143,7 @@ size_t CalOffset(const std::vector<int64_t> &start, const std::vector<int64_t> &
 std::vector<int64_t> CalDimOffset(const std::vector<int64_t> &input_shape);
 size_t GetCopySize(const std::vector<int64_t> &dim_offset, const std::vector<int64_t> &start,
                    const std::vector<int64_t> &stop);
+size_t UnitSizeInBytes(const mindspore::TypeId &t);
 }  // namespace kernel
 }  // namespace mindspore

backend/kernel_compiler/gpu/cuda_impl/rl/rl_buffer_impl.cu (new file):
@@ -0,0 +1,121 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "backend/kernel_compiler/gpu/cuda_impl/rl/rl_buffer_impl.cuh"

// Copy one batch of experiences (`size` bytes) into the ring buffer, wrapping at `capacity` slots.
__global__ void BufferAppendKernel(const int64_t capacity, const size_t size, const int *index, const int exp_batch,
                                   unsigned char *buffer, const unsigned char *exp) {
  size_t index_ = index[0];
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
    if (i >= size / exp_batch * (capacity - index[0])) {
      index_ = i - size / exp_batch * (capacity - index[0]);
    } else {
      index_ = i + index[0] * size / exp_batch;
    }
    buffer[index_] = exp[i];
  }
}

// Single-thread bookkeeping: advance count/head and emit the slot the next batch is written to.
__global__ void IncreaseCountKernel(const int64_t capacity, const int exp_batch, int *count, int *head, int *index) {
  int index_ = 0;
  if (count[0] <= capacity - 1 && head[0] == 0) {
    index_ = count[0];
    count[0] += exp_batch;
    if (count[0] > capacity) {
      count[0] = capacity;
      head[0] = (exp_batch + count[0] - capacity) % capacity;
    }
  } else {
    index_ = head[0];
    if (head[0] == count[0])
      head[0] = 0;
    else
      head[0] = (exp_batch + head[0]) % capacity;
  }
  index[0] = index_;
}

// Translate a logical (possibly negative) index into a physical slot; -1 signals out of range.
__global__ void ReMappingIndexKernel(const int *count, const int *head, const int *origin_index, int *index) {
  index[0] = origin_index[0];
  if (index[0] < 0) {
    index[0] += count[0];
  }
  if (!(index[0] >= 0 && index[0] < count[0])) {
    printf("[ERROR] The index %d is out of range:[%d, %d).", origin_index[0], -1 * count[0], count[0]);
    index[0] = -1;
    return;
  }
  int t = count[0] - head[0];
  if (index[0] < t) {
    index[0] += head[0];
  } else {
    index[0] -= t;
  }
}

__global__ void BufferGetItemKernel(const size_t size, const int *index, const size_t one_exp_len,
                                    const unsigned char *buffer, unsigned char *out) {
  if (index[0] == -1) {
    return;
  }
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
    out[i] = buffer[i + index[0] * one_exp_len];
  }
}

__global__ void CheckBatchSizeKernel(const int *count, const int *head, const size_t batch_size,
                                     const int64_t capacity) {
  if ((head[0] > 0 && int64_t(batch_size) > capacity) || (head[0] == 0 && batch_size > size_t(count[0]))) {
    printf("[ERROR] The batch size %d is larger than total buffer size %d", static_cast<int>(batch_size),
           (capacity > static_cast<int64_t>(count[0]) ? static_cast<int>(count[0]) : static_cast<int>(capacity)));
  }
}

// Gather sampled experiences: output row j comes from buffer row index[j].
__global__ void BufferSampleKernel(const size_t size, const size_t one_element, const int *index,
                                   const unsigned char *buffer, unsigned char *out) {
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
    out[i] = buffer[index[i / one_element] * one_element + i % one_element];
  }
}

void BufferAppend(const int64_t capacity, const size_t size, const int *index, const int exp_batch,
                  unsigned char *buffer, const unsigned char *exp, cudaStream_t cuda_stream) {
  BufferAppendKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(capacity, size, index, exp_batch, buffer, exp);
}

void IncreaseCount(const int64_t capacity, const int exp_batch, int *count, int *head, int *index,
                   cudaStream_t cuda_stream) {
  IncreaseCountKernel<<<1, 1, 0, cuda_stream>>>(capacity, exp_batch, count, head, index);
}

void ReMappingIndex(const int *count, const int *head, const int *origin_index, int *index, cudaStream_t cuda_stream) {
  ReMappingIndexKernel<<<1, 1, 0, cuda_stream>>>(count, head, origin_index, index);
}

void BufferGetItem(const size_t size, const int *index, const size_t one_exp_len, const unsigned char *buffer,
                   unsigned char *out, cudaStream_t cuda_stream) {
  BufferGetItemKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, index, one_exp_len, buffer, out);
}

void CheckBatchSize(const int *count, const int *head, const size_t batch_size, const int64_t capacity,
                    cudaStream_t cuda_stream) {
  CheckBatchSizeKernel<<<1, 1, 0, cuda_stream>>>(count, head, batch_size, capacity);
}

void BufferSample(const size_t size, const size_t one_element, const int *index, const unsigned char *buffer,
                  unsigned char *out, cudaStream_t cuda_stream) {
  BufferSampleKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, one_element, index, buffer, out);
}

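Since the head/count bookkeeping above is the subtlest part of the file, here is an equivalent single-threaded C++ model of IncreaseCountKernel for reference (an explanatory sketch, not part of the diff; the kernel runs this on device, hence the <<<1, 1>>> launch):

#include <cstdint>

// Host-side model of IncreaseCountKernel: `count` experiences are stored,
// `head` marks the oldest entry once the buffer has wrapped, and `index`
// receives the slot where the incoming batch should be written.
void IncreaseCountModel(int64_t capacity, int exp_batch, int *count, int *head, int *index) {
  int index_ = 0;
  if (*count <= capacity - 1 && *head == 0) {  // still filling, not yet wrapped
    index_ = *count;                           // append at the current end
    *count += exp_batch;
    if (*count > capacity) {                   // this batch overflows: start wrapping
      *count = capacity;
      *head = (exp_batch + *count - capacity) % capacity;
    }
  } else {                                     // full: overwrite starting at head
    index_ = *head;
    if (*head == *count) {
      *head = 0;
    } else {
      *head = (exp_batch + *head) % capacity;
    }
  }
  *index = index_;
}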
backend/kernel_compiler/gpu/cuda_impl/rl/rl_buffer_impl.cuh (new file; the original closed its include guard with a stale ADAM_IMPL comment, corrected here):
@@ -0,0 +1,33 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_RL_BUFFER_IMPL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_RL_BUFFER_IMPL_H_

#include "runtime/device/gpu/cuda_common.h"
void BufferAppend(const int64_t capacity, const size_t size, const int *index, const int exp_batch,
                  unsigned char *buffer, const unsigned char *exp, cudaStream_t cuda_stream);
void IncreaseCount(const int64_t capacity, const int exp_batch, int *count, int *head, int *index,
                   cudaStream_t cuda_stream);
void ReMappingIndex(const int *count, const int *head, const int *origin_index, int *index, cudaStream_t cuda_stream);
void BufferGetItem(const size_t size, const int *index, const size_t one_exp_len, const unsigned char *buffer,
                   unsigned char *out, cudaStream_t cuda_stream);
void CheckBatchSize(const int *count, const int *head, const size_t batch_size, const int64_t capacity,
                    cudaStream_t cuda_stream);
void BufferSample(const size_t size, const size_t one_element, const int *index, const unsigned char *buffer,
                  unsigned char *out, cudaStream_t cuda_stream);

#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_RL_BUFFER_IMPL_H_

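All of these wrappers take an explicit cudaStream_t, so the launched kernels join whatever stream the caller is running on. A hedged usage sketch (the helper name and pointer arguments are placeholders; device allocation happens elsewhere, as in the Launch() methods below):

// Sketch: appending one batch of experiences on a caller-provided stream,
// mirroring what BufferAppendKernel::Launch does with these two wrappers.
void AppendOnStream(int64_t capacity, size_t batch_bytes, int exp_batch, int *index, int *count,
                    int *head, unsigned char *buffer, const unsigned char *exp, cudaStream_t stream) {
  IncreaseCount(capacity, exp_batch, count, head, index, stream);                // compute the write slot
  BufferAppend(capacity, batch_bytes, index, exp_batch, buffer, exp, stream);    // copy the bytes
}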
backend/kernel_compiler/gpu/data/dataset_init_kernel.cc:
@@ -17,6 +17,7 @@
 #include "backend/kernel_compiler/gpu/data/dataset_init_kernel.h"
 #include <algorithm>
 #include "backend/kernel_compiler/gpu/data/dataset_utils.h"
+#include "backend/kernel_compiler/common_utils.h"
 #include "runtime/device/gpu/gpu_buffer_mgr.h"
 #include "runtime/device/gpu/gpu_memory_allocator.h"
 #include "utils/convert_utils.h"

@@ -21,6 +21,7 @@
 #include <vector>
 #include <algorithm>
 #include "backend/kernel_compiler/gpu/data/dataset_utils.h"
+#include "backend/kernel_compiler/common_utils.h"
 #include "profiler/device/gpu/gpu_profiling.h"
 #include "runtime/device/gpu/gpu_buffer_mgr.h"
 #include "runtime/device/gpu/gpu_common.h"

backend/kernel_compiler/gpu/data/dataset_utils.cc (UnitSizeInBytes moves out, to common_utils.cc above):
@@ -20,40 +20,6 @@

 namespace mindspore {
 namespace kernel {
-size_t UnitSizeInBytes(const mindspore::TypeId &t) {
-  size_t bytes = 0;
-  switch (t) {
-    case kNumberTypeBool:
-    case kNumberTypeInt8:
-    case kNumberTypeUInt8:
-      bytes = sizeof(int8_t);
-      break;
-    case kNumberTypeInt16:
-    case kNumberTypeUInt16:
-    case kNumberTypeFloat16:
-      bytes = sizeof(int16_t);
-      break;
-    case kNumberTypeInt:
-    case kNumberTypeUInt:
-    case kNumberTypeInt32:
-    case kNumberTypeUInt32:
-    case kNumberTypeFloat:
-    case kNumberTypeFloat32:
-      bytes = sizeof(int32_t);
-      break;
-    case kNumberTypeUInt64:
-    case kNumberTypeInt64:
-    case kNumberTypeFloat64:
-      bytes = sizeof(int64_t);
-      break;
-    default:
-      MS_LOG(EXCEPTION) << "Invalid types " << t;
-      break;
-  }
-
-  return bytes;
-}
-
 int ElementNums(const std::vector<int> &shape) {
   if (shape.size() == 0) {
     return 0;

backend/kernel_compiler/gpu/data/dataset_utils.h:
@@ -21,7 +21,6 @@
 #include "ir/dtype/type.h"
 namespace mindspore {
 namespace kernel {
-size_t UnitSizeInBytes(const mindspore::TypeId &t);
 int ElementNums(const std::vector<int> &shape);
 void GetShapeAndType(const CNodePtr &kernel_node, std::vector<std::vector<int>> *shapes, std::vector<TypePtr> *types);
 }  // namespace kernel

@@ -27,7 +27,7 @@
 #include "ir/tensor.h"
 #include "backend/kernel_compiler/gpu/gpu_kernel.h"
 #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
-#include "backend/kernel_compiler/gpu/data/dataset_utils.h"
+#include "backend/kernel_compiler/common_utils.h"

 using mindspore::tensor::Tensor;

backend/kernel_compiler/gpu/rl/buffer_append_gpu_kernel.cc (new file):
@@ -0,0 +1,86 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "backend/kernel_compiler/gpu/rl/buffer_append_gpu_kernel.h"

#include <memory>
#include <string>
#include <vector>
#include <algorithm>

#include "backend/kernel_compiler/common_utils.h"
#include "backend/kernel_compiler/gpu/cuda_impl/rl/rl_buffer_impl.cuh"
#include "runtime/device/gpu/gpu_common.h"

namespace mindspore {
namespace kernel {

BufferAppendKernel::BufferAppendKernel() : element_nums_(0), exp_batch_(0), capacity_(0) {}

BufferAppendKernel::~BufferAppendKernel() {}

void BufferAppendKernel::ReleaseResource() {}

const std::vector<size_t> &BufferAppendKernel::GetInputSizeList() const { return input_size_list_; }

const std::vector<size_t> &BufferAppendKernel::GetOutputSizeList() const { return output_size_list_; }

const std::vector<size_t> &BufferAppendKernel::GetWorkspaceSizeList() const { return workspace_size_list_; }

bool BufferAppendKernel::Init(const CNodePtr &kernel_node) {
  kernel_node_ = kernel_node;
  auto shapes = GetAttr<std::vector<int64_t>>(kernel_node, "buffer_elements");
  auto types = GetAttr<std::vector<TypePtr>>(kernel_node, "buffer_dtype");
  capacity_ = GetAttr<int64_t>(kernel_node, "capacity");
  exp_batch_ = GetAttr<int64_t>(kernel_node, "exp_batch");
  element_nums_ = shapes.size();
  for (size_t i = 0; i < element_nums_; i++) {
    exp_element_list.push_back(shapes[i] * UnitSizeInBytes(types[i]->type_id()));
  }
  // buffer size
  for (auto i : exp_element_list) {
    input_size_list_.push_back(i * capacity_);
  }
  // exp size
  for (auto i : exp_element_list) {
    input_size_list_.push_back(i * exp_batch_);
  }
  // count and head
  input_size_list_.push_back(sizeof(int));
  input_size_list_.push_back(sizeof(int));
  output_size_list_.push_back(0);
  workspace_size_list_.push_back(sizeof(int));
  return true;
}

void BufferAppendKernel::InitSizeLists() { return; }

bool BufferAppendKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
                                const std::vector<AddressPtr> &, void *stream) {
  int *count_addr = GetDeviceAddress<int>(inputs, 2 * element_nums_);
  int *head_addr = GetDeviceAddress<int>(inputs, 2 * element_nums_ + 1);
  int *index_addr = GetDeviceAddress<int>(workspace, 0);
  auto cuda_stream = reinterpret_cast<cudaStream_t>(stream);
  IncreaseCount(capacity_, LongToInt(exp_batch_), count_addr, head_addr, index_addr, cuda_stream);
  for (size_t i = 0; i < element_nums_; i++) {
    auto buffer_addr = GetDeviceAddress<unsigned char>(inputs, i);
    auto exp_addr = GetDeviceAddress<unsigned char>(inputs, i + element_nums_);
    size_t one_exp_len = input_size_list_[i + element_nums_];
    BufferAppend(capacity_, one_exp_len, index_addr, LongToInt(exp_batch_), buffer_addr, exp_addr, cuda_stream);
  }
  return true;
}
}  // namespace kernel
}  // namespace mindspore

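Note that for each element, Launch passes one_exp_len = element_bytes * exp_batch_ as the kernel's `size`, so one BufferAppend call copies the whole incoming batch of that element. A host-side C++ model of the wrap-around write performed by BufferAppendKernel (an explanatory sketch, not part of the diff):

#include <cstddef>

// Host model of BufferAppendKernel: write `exp_batch` experiences
// (`size` bytes total) starting at slot `start`, wrapping at `capacity` slots.
void BufferAppendModel(std::size_t capacity, std::size_t size, std::size_t start,
                       std::size_t exp_batch, unsigned char *buffer, const unsigned char *exp) {
  std::size_t one_exp = size / exp_batch;  // bytes per single experience
  for (std::size_t i = 0; i < size; ++i) {
    std::size_t dst;
    if (i >= one_exp * (capacity - start)) {
      dst = i - one_exp * (capacity - start);  // past the end: wrap to slot 0
    } else {
      dst = i + start * one_exp;               // linear region starting at `start`
    }
    buffer[dst] = exp[i];
  }
}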
backend/kernel_compiler/gpu/rl/buffer_append_gpu_kernel.h (new file):
@@ -0,0 +1,58 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_RL_BUFFER_APPEND_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_RL_BUFFER_APPEND_GPU_KERNEL_H_

#include <memory>
#include <string>
#include <vector>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"

namespace mindspore {
namespace kernel {
class BufferAppendKernel : public GpuKernel {
 public:
  BufferAppendKernel();
  ~BufferAppendKernel();

  const std::vector<size_t> &GetInputSizeList() const override;
  const std::vector<size_t> &GetOutputSizeList() const override;
  const std::vector<size_t> &GetWorkspaceSizeList() const override;
  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs, void *stream_ptr) override;
  bool Init(const CNodePtr &kernel_node) override;
  void ReleaseResource() override;

 protected:
  void InitSizeLists() override;

 private:
  size_t element_nums_;
  int64_t exp_batch_;
  int64_t capacity_;
  std::vector<size_t> exp_element_list;
  std::vector<size_t> input_size_list_;
  std::vector<size_t> output_size_list_;
  std::vector<size_t> workspace_size_list_;
};

MS_REG_GPU_KERNEL(BufferAppend, BufferAppendKernel)
}  // namespace kernel
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_RL_BUFFER_APPEND_GPU_KERNEL_H_

backend/kernel_compiler/gpu/rl/buffer_get_gpu_kernel.cc (new file):
@@ -0,0 +1,87 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "backend/kernel_compiler/gpu/rl/buffer_get_gpu_kernel.h"

#include <memory>
#include <string>
#include <vector>
#include <algorithm>

#include "backend/kernel_compiler/common_utils.h"
#include "backend/kernel_compiler/gpu/cuda_impl/rl/rl_buffer_impl.cuh"
#include "runtime/device/gpu/gpu_common.h"

namespace mindspore {
namespace kernel {

BufferGetKernel::BufferGetKernel() : element_nums_(0), capacity_(0) {}

BufferGetKernel::~BufferGetKernel() {}

void BufferGetKernel::ReleaseResource() {}

const std::vector<size_t> &BufferGetKernel::GetInputSizeList() const { return input_size_list_; }

const std::vector<size_t> &BufferGetKernel::GetOutputSizeList() const { return output_size_list_; }

const std::vector<size_t> &BufferGetKernel::GetWorkspaceSizeList() const { return workspace_size_list_; }

bool BufferGetKernel::Init(const CNodePtr &kernel_node) {
  kernel_node_ = kernel_node;
  auto shapes = GetAttr<std::vector<int64_t>>(kernel_node, "buffer_elements");
  auto types = GetAttr<std::vector<TypePtr>>(kernel_node, "buffer_dtype");
  capacity_ = GetAttr<int64_t>(kernel_node, "capacity");
  element_nums_ = shapes.size();
  for (size_t i = 0; i < element_nums_; i++) {
    exp_element_list.push_back(shapes[i] * UnitSizeInBytes(types[i]->type_id()));
  }
  // buffer size
  for (auto i : exp_element_list) {
    input_size_list_.push_back(i * capacity_);
    output_size_list_.push_back(i);
  }
  // count, head, index
  input_size_list_.push_back(sizeof(int));
  input_size_list_.push_back(sizeof(int));
  input_size_list_.push_back(sizeof(int));
  workspace_size_list_.push_back(sizeof(int));
  return true;
}

void BufferGetKernel::InitSizeLists() { return; }

bool BufferGetKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
                             const std::vector<AddressPtr> &outputs, void *stream) {
  int *count_addr = GetDeviceAddress<int>(inputs, element_nums_);
  int *head_addr = GetDeviceAddress<int>(inputs, element_nums_ + 1);
  int *origin_index_addr = GetDeviceAddress<int>(inputs, element_nums_ + 2);
  int *index_addr = GetDeviceAddress<int>(workspace, 0);
  auto cuda_stream = reinterpret_cast<cudaStream_t>(stream);
  ReMappingIndex(count_addr, head_addr, origin_index_addr, index_addr, cuda_stream);
  cudaError_t error = cudaGetLastError();
  if (error != cudaSuccess) {
    return false;
  }
  for (size_t i = 0; i < element_nums_; i++) {
    auto buffer_addr = GetDeviceAddress<unsigned char>(inputs, i);
    auto item_addr = GetDeviceAddress<unsigned char>(outputs, i);
    size_t one_exp_len = output_size_list_[i];
    BufferGetItem(one_exp_len, index_addr, one_exp_len, buffer_addr, item_addr, cuda_stream);
  }
  return true;
}
}  // namespace kernel
}  // namespace mindspore

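The only non-obvious step in this kernel is ReMappingIndex, which turns a logical, possibly negative, user index into a physical ring-buffer slot (negative indices count back from the newest entry, as in Python). A host-side C++ model for reference (an explanatory sketch, not part of the diff):

#include <cstdio>

// Host model of ReMappingIndexKernel: logical index -> physical slot.
int RemapIndexModel(int count, int head, int origin_index) {
  int index = origin_index;
  if (index < 0) index += count;  // e.g. -1 -> the newest stored experience
  if (index < 0 || index >= count) {
    std::printf("[ERROR] The index %d is out of range:[%d, %d).\n", origin_index, -count, count);
    return -1;  // sentinel that BufferGetItemKernel checks before copying
  }
  int tail_len = count - head;  // entries between head and the end of the buffer
  return index < tail_len ? index + head : index - tail_len;
}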
backend/kernel_compiler/gpu/rl/buffer_get_gpu_kernel.h (new file):
@@ -0,0 +1,57 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_RL_BUFFER_GET_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_RL_BUFFER_GET_GPU_KERNEL_H_

#include <memory>
#include <string>
#include <vector>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"

namespace mindspore {
namespace kernel {
class BufferGetKernel : public GpuKernel {
 public:
  BufferGetKernel();
  ~BufferGetKernel();

  const std::vector<size_t> &GetInputSizeList() const override;
  const std::vector<size_t> &GetOutputSizeList() const override;
  const std::vector<size_t> &GetWorkspaceSizeList() const override;
  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs, void *stream_ptr) override;
  bool Init(const CNodePtr &kernel_node) override;
  void ReleaseResource() override;

 protected:
  void InitSizeLists() override;

 private:
  size_t element_nums_;
  int64_t capacity_;
  std::vector<size_t> exp_element_list;
  std::vector<size_t> input_size_list_;
  std::vector<size_t> output_size_list_;
  std::vector<size_t> workspace_size_list_;
};

MS_REG_GPU_KERNEL(BufferGetItem, BufferGetKernel)
}  // namespace kernel
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_RL_BUFFER_GET_GPU_KERNEL_H_

backend/kernel_compiler/gpu/rl/buffer_sample_gpu_kernel.cc (new file):
@@ -0,0 +1,81 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "backend/kernel_compiler/gpu/rl/buffer_sample_gpu_kernel.h"

#include <memory>
#include <string>
#include <vector>
#include <algorithm>

#include "backend/kernel_compiler/common_utils.h"
#include "backend/kernel_compiler/gpu/cuda_impl/rl/rl_buffer_impl.cuh"
#include "runtime/device/gpu/gpu_common.h"

namespace mindspore {
namespace kernel {

BufferSampleKernel::BufferSampleKernel() : element_nums_(0), capacity_(0), batch_size_(0) {}

BufferSampleKernel::~BufferSampleKernel() {}

void BufferSampleKernel::ReleaseResource() {}

const std::vector<size_t> &BufferSampleKernel::GetInputSizeList() const { return input_size_list_; }

const std::vector<size_t> &BufferSampleKernel::GetOutputSizeList() const { return output_size_list_; }

const std::vector<size_t> &BufferSampleKernel::GetWorkspaceSizeList() const { return workspace_size_list_; }

bool BufferSampleKernel::Init(const CNodePtr &kernel_node) {
  kernel_node_ = kernel_node;
  auto shapes = GetAttr<std::vector<int64_t>>(kernel_node, "buffer_elements");
  auto types = GetAttr<std::vector<TypePtr>>(kernel_node, "buffer_dtype");
  capacity_ = GetAttr<int64_t>(kernel_node, "capacity");
  batch_size_ = LongToSize(GetAttr<int64_t>(kernel_node, "batch_size"));
  element_nums_ = shapes.size();
  for (size_t i = 0; i < element_nums_; i++) {
    auto element = shapes[i] * UnitSizeInBytes(types[i]->type_id());
    exp_element_list.push_back(element);
    input_size_list_.push_back(capacity_ * element);
    output_size_list_.push_back(batch_size_ * element);
  }
  // index
  input_size_list_.push_back(sizeof(int) * batch_size_);
  // count and head
  input_size_list_.push_back(sizeof(int));
  input_size_list_.push_back(sizeof(int));
  return true;
}

void BufferSampleKernel::InitSizeLists() { return; }

bool BufferSampleKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
                                const std::vector<AddressPtr> &outputs, void *stream) {
  int *index_addr = GetDeviceAddress<int>(inputs, element_nums_);
  int *count_addr = GetDeviceAddress<int>(inputs, element_nums_ + 1);
  int *head_addr = GetDeviceAddress<int>(inputs, element_nums_ + 2);
  auto cuda_stream = reinterpret_cast<cudaStream_t>(stream);
  CheckBatchSize(count_addr, head_addr, batch_size_, capacity_, cuda_stream);
  for (size_t i = 0; i < element_nums_; i++) {
    auto buffer_addr = GetDeviceAddress<unsigned char>(inputs, i);
    auto out_addr = GetDeviceAddress<unsigned char>(outputs, i);
    size_t size = batch_size_ * exp_element_list[i];
    BufferSample(size, exp_element_list[i], index_addr, buffer_addr, out_addr, cuda_stream);
  }
  return true;
}
}  // namespace kernel
}  // namespace mindspore

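Sampling is a flat byte gather: output row j is the one_element-byte slice of the buffer selected by index[j], with the index tensor supplied as an input (the tests below produce it with P.Randperm). A host-side C++ model of BufferSampleKernel (an explanatory sketch, not part of the diff):

#include <cstddef>
#include <cstring>

// Host model of BufferSampleKernel: gather `rows` experiences of
// `one_element` bytes each from `buffer` into `out`.
void BufferSampleModel(const unsigned char *buffer, const int *index, std::size_t rows,
                       std::size_t one_element, unsigned char *out) {
  for (std::size_t j = 0; j < rows; ++j) {
    std::memcpy(out + j * one_element,
                buffer + static_cast<std::size_t>(index[j]) * one_element, one_element);
  }
}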
backend/kernel_compiler/gpu/rl/buffer_sample_gpu_kernel.h (new file):
@@ -0,0 +1,58 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_RL_BUFFER_SAMPLE_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_RL_BUFFER_SAMPLE_GPU_KERNEL_H_

#include <memory>
#include <string>
#include <vector>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"

namespace mindspore {
namespace kernel {
class BufferSampleKernel : public GpuKernel {
 public:
  BufferSampleKernel();
  ~BufferSampleKernel();

  const std::vector<size_t> &GetInputSizeList() const override;
  const std::vector<size_t> &GetOutputSizeList() const override;
  const std::vector<size_t> &GetWorkspaceSizeList() const override;
  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs, void *stream_ptr) override;
  bool Init(const CNodePtr &kernel_node) override;
  void ReleaseResource() override;

 protected:
  void InitSizeLists() override;

 private:
  size_t element_nums_;
  int64_t capacity_;
  size_t batch_size_;
  std::vector<size_t> exp_element_list;
  std::vector<size_t> input_size_list_;
  std::vector<size_t> output_size_list_;
  std::vector<size_t> workspace_size_list_;
};

MS_REG_GPU_KERNEL(BufferSample, BufferSampleKernel)
}  // namespace kernel
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_RL_BUFFER_SAMPLE_GPU_KERNEL_H_

@@ -17,6 +17,7 @@

 #include "backend/kernel_compiler/gpu/data/dataset_utils.h"
 #include "backend/kernel_compiler/gpu/trt/trt_utils.h"
+#include "backend/kernel_compiler/common_utils.h"
 #include "runtime/device/gpu/trt_loader.h"

 namespace mindspore {

New end-to-end test (RLBuffer net); the misspelled `capcity` identifier is corrected to `capacity` throughout, and the loop indentation lost in extraction is restored:
@@ -0,0 +1,87 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import numpy as np
import pytest

import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common.api import ms_function
from mindspore.common.parameter import Parameter
from mindspore.ops import operations as P
import mindspore as ms


def create_tensor(capacity, shapes, dtypes):
    buffer = []
    for i in range(len(shapes)):
        buffer.append(Tensor(np.zeros((capacity,) + shapes[i]), dtypes[i]))
    return buffer


class RLBuffer(nn.Cell):
    def __init__(self, batch_size, capacity, shapes, types):
        super(RLBuffer, self).__init__()
        self.buffer = create_tensor(capacity, shapes, types)
        self._capacity = capacity
        self.count = Parameter(Tensor(0, ms.int32), name="count")
        self.head = Parameter(Tensor(0, ms.int32), name="head")
        self.buffer_append = P.BufferAppend(self._capacity, shapes, types)
        self.buffer_get = P.BufferGetItem(self._capacity, shapes, types)
        self.buffer_sample = P.BufferSample(
            self._capacity, batch_size, shapes, types)
        self.randperm = P.Randperm(max_length=capacity, pad=-1)
        self.reshape = P.Reshape()

    @ms_function
    def append(self, exps):
        return self.buffer_append(self.buffer, exps, self.count, self.head)

    @ms_function
    def get(self, index):
        return self.buffer_get(self.buffer, self.count, self.head, index)

    @ms_function
    def sample(self):
        count = self.reshape(self.count, (1,))
        index = self.randperm(count)
        return self.buffer_sample(self.buffer, index, self.count, self.head)


s = Tensor(np.array([2, 2, 2, 2]), ms.float32)
a = Tensor(np.array([0, 1]), ms.int32)
r = Tensor(np.array([1]), ms.float32)
s_ = Tensor(np.array([3, 3, 3, 3]), ms.float32)
exp = [s, a, r, s_]
exp1 = [s_, a, r, s]


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_Buffer():
    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
    buffer = RLBuffer(batch_size=32, capacity=100, shapes=[(4,), (2,), (1,), (4,)], types=[
                      ms.float32, ms.int32, ms.float32, ms.float32])
    print("init buffer:\n", buffer.buffer)
    for _ in range(0, 110):
        buffer.append(exp)
        buffer.append(exp1)
    print("buffer append:\n", buffer.buffer)
    b = buffer.get(-1)
    print("buffer get:\n", b)
    bs = buffer.sample()
    print("buffer sample:\n", bs)

New unit tests for the three buffer operators (again with `capcity` corrected to `capacity` and the pytest decorators repaired):
@@ -0,0 +1,157 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import numpy as np
import pytest

import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common.api import ms_function
from mindspore.common.parameter import Parameter
from mindspore.ops import operations as P
import mindspore as ms


class RLBufferAppend(nn.Cell):
    def __init__(self, capacity, shapes, types):
        super(RLBufferAppend, self).__init__()
        self._capacity = capacity
        self.count = Parameter(Tensor(0, ms.int32), name="count")
        self.head = Parameter(Tensor(0, ms.int32), name="head")
        self.buffer_append = P.BufferAppend(self._capacity, shapes, types)

    @ms_function
    def construct(self, buffer, exps):
        return self.buffer_append(buffer, exps, self.count, self.head)


class RLBufferGet(nn.Cell):
    def __init__(self, capacity, shapes, types):
        super(RLBufferGet, self).__init__()
        self._capacity = capacity
        self.count = Parameter(Tensor(5, ms.int32), name="count")
        self.head = Parameter(Tensor(0, ms.int32), name="head")
        self.buffer_get = P.BufferGetItem(self._capacity, shapes, types)

    @ms_function
    def construct(self, buffer, index):
        return self.buffer_get(buffer, self.count, self.head, index)


class RLBufferSample(nn.Cell):
    def __init__(self, capacity, batch_size, shapes, types):
        super(RLBufferSample, self).__init__()
        self._capacity = capacity
        count = 5
        self.count = Parameter(Tensor(5, ms.int32), name="count")
        self.head = Parameter(Tensor(0, ms.int32), name="head")
        self.input_x = Tensor(np.ones(shape=[count]), ms.bool_)
        self.buffer_sample = P.BufferSample(
            self._capacity, batch_size, shapes, types)
        self.index = Parameter(Tensor([0, 2, 4], ms.int32), name="index")

    @ms_function
    def construct(self, buffer):
        return self.buffer_sample(buffer, self.index, self.count, self.head)


states = Tensor(np.arange(4 * 5).reshape(5, 4).astype(np.float32) / 10.0)
actions = Tensor(np.arange(2 * 5).reshape(5, 2).astype(np.int32))
rewards = Tensor(np.ones((5, 1)).astype(np.int32))
states_ = Tensor(np.arange(4 * 5).reshape(5, 4).astype(np.float32))
b = [states, actions, rewards, states_]

s = Tensor(np.array([2, 2, 2, 2]), ms.float32)
a = Tensor(np.array([0, 0]), ms.int32)
r = Tensor(np.array([0]), ms.int32)
s_ = Tensor(np.array([3, 3, 3, 3]), ms.float32)
exp = [s, a, r, s_]
exp1 = [s_, a, r, s]

c = [Tensor(np.array([[6, 6, 6, 6], [6, 6, 6, 6]]), ms.float32),
     Tensor(np.array([[6, 6], [6, 6]]), ms.int32),
     Tensor(np.array([[6], [6]]), ms.int32),
     Tensor(np.array([[6, 6, 6, 6], [6, 6, 6, 6]]), ms.float32)]


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_BufferSample():
    context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
    buffer_sample = RLBufferSample(capacity=5, batch_size=3, shapes=[(4,), (2,), (1,), (4,)], types=[
                                   ms.float32, ms.int32, ms.int32, ms.float32])
    ss, aa, rr, ss_ = buffer_sample(b)
    expect_s = [[0, 0.1, 0.2, 0.3], [0.8, 0.9, 1.0, 1.1], [1.6, 1.7, 1.8, 1.9]]
    expect_a = [[0, 1], [4, 5], [8, 9]]
    expect_r = [[1], [1], [1]]
    expect_s_ = [[0, 1, 2, 3], [8, 9, 10, 11], [16, 17, 18, 19]]
    np.testing.assert_almost_equal(ss.asnumpy(), expect_s)
    np.testing.assert_almost_equal(aa.asnumpy(), expect_a)
    np.testing.assert_almost_equal(rr.asnumpy(), expect_r)
    np.testing.assert_almost_equal(ss_.asnumpy(), expect_s_)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_BufferGet():
    context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
    buffer_get = RLBufferGet(capacity=5, shapes=[(4,), (2,), (1,), (4,)], types=[
                             ms.float32, ms.int32, ms.int32, ms.float32])
    ss, aa, rr, ss_ = buffer_get(b, 1)
    expect_s = [0.4, 0.5, 0.6, 0.7]
    expect_a = [2, 3]
    expect_r = [1]
    expect_s_ = [4, 5, 6, 7]
    np.testing.assert_almost_equal(ss.asnumpy(), expect_s)
    np.testing.assert_almost_equal(aa.asnumpy(), expect_a)
    np.testing.assert_almost_equal(rr.asnumpy(), expect_r)
    np.testing.assert_almost_equal(ss_.asnumpy(), expect_s_)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_BufferAppend():
    context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
    buffer_append = RLBufferAppend(capacity=5, shapes=[(4,), (2,), (1,), (4,)], types=[
                                   ms.float32, ms.int32, ms.int32, ms.float32])

    buffer_append(b, exp)
    buffer_append(b, exp)
    buffer_append(b, exp)
    buffer_append(b, exp)
    buffer_append(b, exp)
    buffer_append(b, exp1)
    expect_s = [[3, 3, 3, 3], [2, 2, 2, 2], [2, 2, 2, 2], [2, 2, 2, 2], [2, 2, 2, 2]]
    expect_a = [[0, 0], [0, 0], [0, 0], [0, 0], [0, 0]]
    expect_r = [[0], [0], [0], [0], [0]]
    expect_s_ = [[2, 2, 2, 2], [3, 3, 3, 3], [3, 3, 3, 3], [3, 3, 3, 3], [3, 3, 3, 3]]
    np.testing.assert_almost_equal(b[0].asnumpy(), expect_s)
    np.testing.assert_almost_equal(b[1].asnumpy(), expect_a)
    np.testing.assert_almost_equal(b[2].asnumpy(), expect_r)
    np.testing.assert_almost_equal(b[3].asnumpy(), expect_s_)
    buffer_append(b, exp1)
    buffer_append(b, c)
    buffer_append(b, c)
    expect_s2 = [[6, 6, 6, 6], [3, 3, 3, 3], [6, 6, 6, 6], [6, 6, 6, 6], [6, 6, 6, 6]]
    expect_a2 = [[6, 6], [0, 0], [6, 6], [6, 6], [6, 6]]
    expect_r2 = [[6], [0], [6], [6], [6]]
    expect_s2_ = [[6, 6, 6, 6], [2, 2, 2, 2], [6, 6, 6, 6], [6, 6, 6, 6], [6, 6, 6, 6]]
    np.testing.assert_almost_equal(b[0].asnumpy(), expect_s2)
    np.testing.assert_almost_equal(b[1].asnumpy(), expect_a2)
    np.testing.assert_almost_equal(b[2].asnumpy(), expect_r2)
    np.testing.assert_almost_equal(b[3].asnumpy(), expect_s2_)