diff --git a/mindspore/ccsrc/backend/kernel_compiler/CMakeLists.txt b/mindspore/ccsrc/backend/kernel_compiler/CMakeLists.txt
index 954402e5c9e..90496bcfc96 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/CMakeLists.txt
+++ b/mindspore/ccsrc/backend/kernel_compiler/CMakeLists.txt
@@ -36,6 +36,7 @@ if(ENABLE_CPU)
         "cpu/ps/*.cc"
         "cpu/quantum/*.cc"
         "cpu/pyfunc/*.cc"
+        "cpu/rl/*.cc"
     )
 
     if(NOT ENABLE_MPI)
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.h
index b85568f505e..a0714b3dc76 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.h
@@ -152,6 +152,19 @@ class CPUKernel : public kernel::KernelMod {
   std::vector<size_t> output_size_list_;
   std::vector<size_t> workspace_size_list_;
   ParallelSearchInfo parallel_search_info_;
+
+  template <typename T>
+  inline T *GetDeviceAddress(const std::vector<AddressPtr> &addr_list, size_t index) {
+    if (index >= addr_list.size()) {
+      MS_LOG(EXCEPTION) << "Address index(" << index << ") out of range(" << addr_list.size() << ")";
+    }
+
+    if ((addr_list[index] == nullptr) || (addr_list[index]->addr == nullptr) || (addr_list[index]->size == 0)) {
+      MS_LOG(EXCEPTION) << "The device address is empty, address index: " << index;
+    }
+
+    return reinterpret_cast<T *>(addr_list[index]->addr);
+  }
 };
 
 class CPUKernelUtils {
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/rl/buffer_append_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/rl/buffer_append_cpu_kernel.cc
new file mode 100644
index 00000000000..5c2cf9be9ec
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/rl/buffer_append_cpu_kernel.cc
@@ -0,0 +1,23 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/cpu/rl/buffer_append_cpu_kernel.h"
+
+namespace mindspore {
+namespace kernel {
+MS_REG_CPU_KERNEL(BufferAppend, KernelAttr(), BufferCPUAppendKernel);
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/rl/buffer_append_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/rl/buffer_append_cpu_kernel.h
new file mode 100644
index 00000000000..973e4e63345
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/rl/buffer_append_cpu_kernel.h
@@ -0,0 +1,109 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_RL_BUFFER_APPEND_CPU_KERNEL_H_
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_RL_BUFFER_APPEND_CPU_KERNEL_H_
+#include <memory>
+#include <string>
+#include <vector>
+#include <algorithm>
+#include "backend/kernel_compiler/cpu/cpu_kernel.h"
+#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
+
+namespace mindspore {
+namespace kernel {
+class BufferCPUAppendKernel : public CPUKernel {
+ public:
+  BufferCPUAppendKernel() : element_nums_(0), exp_batch_(0), capacity_(0) {}
+
+  ~BufferCPUAppendKernel() override = default;
+  void Init(const CNodePtr &kernel_node) {
+    auto shapes = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, "buffer_elements");
+    auto types = AnfAlgo::GetNodeAttr<std::vector<TypePtr>>(kernel_node, "buffer_dtype");
+    capacity_ = AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "capacity");
+    exp_batch_ = AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "exp_batch");
+    element_nums_ = shapes.size();
+    for (size_t i = 0; i < element_nums_; i++) {
+      exp_element_list.push_back(shapes[i] * UnitSizeInBytes(types[i]->type_id()));
+    }
+    // buffer size
+    for (auto i : exp_element_list) {
+      input_size_list_.push_back(i * capacity_);
+    }
+    // exp size
+    for (auto i : exp_element_list) {
+      input_size_list_.push_back(i * exp_batch_);
+    }
+    // count and head
+    input_size_list_.push_back(sizeof(int));
+    input_size_list_.push_back(sizeof(int));
+    output_size_list_.push_back(sizeof(int));
+  }
+
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, const std::vector<AddressPtr> &) {
+    auto count_addr = GetDeviceAddress<int>(inputs, 2 * element_nums_);
+    auto head_addr = GetDeviceAddress<int>(inputs, 2 * element_nums_ + 1);
+    int index = 0;
+    if (count_addr[0] <= capacity_ - 1 && head_addr[0] == 0) {
+      index = count_addr[0];
+      count_addr[0] = index + exp_batch_;
+      if (count_addr[0] > capacity_) {
+        count_addr[0] = capacity_;
+        head_addr[0] = (exp_batch_ + count_addr[0] - capacity_) % capacity_;
+      }
+    } else {
+      index = head_addr[0];
+      head_addr[0] = (exp_batch_ + head_addr[0]) % capacity_;
+    }
+    // If exp_batch > (capacity_ - index), wrap around to the buffer's head
+    int remain_size = (exp_batch_ > (capacity_ - index)) ? LongToInt(capacity_ - index) : LongToInt(exp_batch_);
+    int remap_size = (exp_batch_ > (capacity_ - index)) ? LongToInt(exp_batch_ - capacity_ + index) : 0;
+    auto task = [&](size_t start, size_t end) {
+      for (size_t i = start; i < end; i++) {
+        auto buffer_addr = GetDeviceAddress<unsigned char>(inputs, i);
+        auto exp_addr = GetDeviceAddress<unsigned char>(inputs, i + element_nums_);
+        size_t one_exp_len = exp_element_list[i];
+        size_t dist_len = one_exp_len;
+        if (memcpy_s(buffer_addr + IntToSize(index) * one_exp_len, one_exp_len * remain_size, exp_addr,
+                     dist_len * remain_size) != EOK) {
+          MS_LOG(EXCEPTION) << "Launch kernel error: memcpy failed";
+        }
+        if (remap_size > 0) {
+          if (memcpy_s(buffer_addr, one_exp_len * remap_size, exp_addr, dist_len * remap_size) != EOK) {
+            MS_LOG(EXCEPTION) << "Launch kernel error: memcpy failed";
+          }
+        }
+      }
+    };
+    CPUKernelUtils::ParallelFor(task, element_nums_);
+    return true;
+  }
+
+  void InitKernel(const CNodePtr &kernel_node) { return; }
+
+ protected:
+  void InitSizeLists() { return; }
+
+ private:
+  size_t element_nums_;
+  int64_t exp_batch_;
+  int64_t capacity_;
+  std::vector<size_t> exp_element_list;
+};
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_RL_BUFFER_APPEND_CPU_KERNEL_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/rl/buffer_get_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/rl/buffer_get_cpu_kernel.cc
new file mode 100644
index 00000000000..4f5810390c3
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/rl/buffer_get_cpu_kernel.cc
@@ -0,0 +1,23 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/cpu/rl/buffer_get_cpu_kernel.h"
+
+namespace mindspore {
+namespace kernel {
+MS_REG_CPU_KERNEL(BufferGetItem, KernelAttr(), BufferCPUGetKernel);
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/rl/buffer_get_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/rl/buffer_get_cpu_kernel.h
new file mode 100644
index 00000000000..7f7137c3602
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/rl/buffer_get_cpu_kernel.h
@@ -0,0 +1,97 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_RL_BUFFER_GET_CPU_KERNEL_H_
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_RL_BUFFER_GET_CPU_KERNEL_H_
+#include <memory>
+#include <string>
+#include <vector>
+#include <algorithm>
+#include "backend/kernel_compiler/cpu/cpu_kernel.h"
+#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
+
+namespace mindspore {
+namespace kernel {
+class BufferCPUGetKernel : public CPUKernel {
+ public:
+  BufferCPUGetKernel() : element_nums_(0), capacity_(0) {}
+
+  ~BufferCPUGetKernel() override = default;
+  void Init(const CNodePtr &kernel_node) {
+    auto shapes = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, "buffer_elements");
+    auto types = AnfAlgo::GetNodeAttr<std::vector<TypePtr>>(kernel_node, "buffer_dtype");
+    capacity_ = AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "capacity");
+    element_nums_ = shapes.size();
+    for (size_t i = 0; i < element_nums_; i++) {
+      exp_element_list.push_back(shapes[i] * UnitSizeInBytes(types[i]->type_id()));
+    }
+    // buffer size
+    for (auto i : exp_element_list) {
+      input_size_list_.push_back(i * capacity_);
+      output_size_list_.push_back(i);
+    }
+    // count, head, index
+    input_size_list_.push_back(sizeof(int));
+    input_size_list_.push_back(sizeof(int));
+    input_size_list_.push_back(sizeof(int));
+  }
+
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
+              const std::vector<AddressPtr> &outputs) {
+    auto count_addr = GetDeviceAddress<int>(inputs, element_nums_);
+    auto head_addr = GetDeviceAddress<int>(inputs, element_nums_ + 1);
+    auto index_addr = GetDeviceAddress<int>(inputs, element_nums_ + 2);
+    int index = index_addr[0];
+    if (index_addr[0] < 0) index += count_addr[0];
+    if (!(index >= 0 && index < count_addr[0])) {
+      MS_LOG(ERROR) << "The index " << index_addr[0] << " is out of range: [" << -1 * count_addr[0] << ", "
+                    << count_addr[0] << ").";
+    }
+    int t = count_addr[0] - head_addr[0];
+    if (index < t) {
+      index += head_addr[0];
+    } else {
+      index -= t;
+    }
+    auto task = [&](size_t start, size_t end) {
+      for (size_t i = start; i < end; i++) {
+        auto buffer_addr = GetDeviceAddress<unsigned char>(inputs, i);
+        auto item_addr = GetDeviceAddress<unsigned char>(outputs, i);
+        size_t one_exp_len = output_size_list_[i];
+        size_t dist_len = one_exp_len;
+        if (memcpy_s(item_addr, one_exp_len, buffer_addr + IntToSize(index) * one_exp_len, dist_len) != EOK) {
+          MS_LOG(EXCEPTION) << "Launch kernel error: memcpy failed";
+        }
+      }
+    };
+    CPUKernelUtils::ParallelFor(task, element_nums_);
+    return true;
+  }
+
+  void InitKernel(const CNodePtr &kernel_node) { return; }
+
+ protected:
+  void InitSizeLists() { return; }
+
+ private:
+  size_t element_nums_;
+  int64_t capacity_;
+  std::vector<size_t> exp_element_list;
+};
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_RL_BUFFER_GET_CPU_KERNEL_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/rl/buffer_sample_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/rl/buffer_sample_cpu_kernel.cc
new file mode 100644
index 00000000000..ebe3c3a5695
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/rl/buffer_sample_cpu_kernel.cc
@@ -0,0 +1,23 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/cpu/rl/buffer_sample_cpu_kernel.h"
+
+namespace mindspore {
+namespace kernel {
+MS_REG_CPU_KERNEL(BufferSample, KernelAttr(), BufferCPUSampleKernel);
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/rl/buffer_sample_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/rl/buffer_sample_cpu_kernel.h
new file mode 100644
index 00000000000..4a18c47b7c7
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/rl/buffer_sample_cpu_kernel.h
@@ -0,0 +1,100 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_RL_BUFFER_SAMPLE_CPU_KERNEL_H_
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_RL_BUFFER_SAMPLE_CPU_KERNEL_H_
+#include <memory>
+#include <string>
+#include <vector>
+#include <algorithm>
+#include "backend/kernel_compiler/cpu/cpu_kernel.h"
+#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
+
+namespace mindspore {
+namespace kernel {
+class BufferCPUSampleKernel : public CPUKernel {
+ public:
+  BufferCPUSampleKernel() : element_nums_(0), capacity_(0), batch_size_(0), exp_size_(0) {}
+
+  ~BufferCPUSampleKernel() override = default;
+  void Init(const CNodePtr &kernel_node) {
+    auto shapes = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, "buffer_elements");
+    auto types = AnfAlgo::GetNodeAttr<std::vector<TypePtr>>(kernel_node, "buffer_dtype");
+    capacity_ = AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "capacity");
+    batch_size_ = LongToSize(AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "batch_size"));
+    element_nums_ = shapes.size();
+    for (size_t i = 0; i < element_nums_; i++) {
+      exp_element_list.push_back(shapes[i] * UnitSizeInBytes(types[i]->type_id()));
+    }
+    // buffer size
+    for (auto i : exp_element_list) {
+      input_size_list_.push_back(i * capacity_);
+      output_size_list_.push_back(i * batch_size_);
+      exp_size_ += i;
+    }
+    // index
+    input_size_list_.push_back(sizeof(int) * batch_size_);
+    // count and head
+    input_size_list_.push_back(sizeof(int));
+    input_size_list_.push_back(sizeof(int));
+  }
+
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
+              const std::vector<AddressPtr> &outputs) {
+    auto indexes_addr = GetDeviceAddress<int>(inputs, element_nums_);
+    auto count_addr = GetDeviceAddress<int>(inputs, element_nums_ + 1);
+    auto head_addr = GetDeviceAddress<int>(inputs, element_nums_ + 2);
+
+    if ((head_addr[0] > 0 && SizeToLong(batch_size_) > capacity_) ||
+        (head_addr[0] == 0 && SizeToLong(batch_size_) > count_addr[0])) {
+      MS_LOG(ERROR) << "The batch size " << batch_size_ << " is larger than total buffer size "
+                    << std::min(capacity_, IntToLong(count_addr[0]));
+    }
+    auto task = [&](size_t start, size_t end) {
+      for (size_t j = start; j < end; j++) {
+        int64_t index = IntToSize(indexes_addr[j]);
+        for (size_t i = 0; i < element_nums_; i++) {
+          auto buffer_addr = GetDeviceAddress<unsigned char>(inputs, i);
+          auto output_addr = GetDeviceAddress<unsigned char>(outputs, i);
+          auto one_exp_len = exp_element_list[i];
+          size_t dist_len = one_exp_len;
+          if (memcpy_s(output_addr + j * one_exp_len, one_exp_len, buffer_addr + index * one_exp_len, dist_len) !=
+              EOK) {
+            MS_LOG(EXCEPTION) << "Launch kernel error: memcpy failed";
+          }
+        }
+      }
+    };
+    CPUKernelUtils::ParallelFor(task, batch_size_);
+    return true;
+  }
+
+  void InitKernel(const CNodePtr &kernel_node) { return; }
+
+ protected:
+  void InitSizeLists() { return; }
+
+ private:
+  size_t element_nums_;
+  int64_t capacity_;
+  size_t batch_size_;
+  int64_t exp_size_;
+  std::vector<size_t> exp_element_list;
+};
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_RL_BUFFER_SAMPLE_CPU_KERNEL_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/rl/rl_buffer_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/rl/rl_buffer_impl.cu
index 9027f5eac3c..e374c1eda24 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/rl/rl_buffer_impl.cu
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/rl/rl_buffer_impl.cu
@@ -20,8 +20,8 @@ __global__ void BufferAppendKernel(const int64_t capacity, const size_t size, co
                                    unsigned char *buffer, const unsigned char *exp) {
   size_t index_ = index[0];
   for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
-    if (i >= size / exp_batch * (capacity - index[0])) {
-      index_ = i - size / exp_batch * (capacity - index[0]);
+    if (i >= (size / exp_batch) * (capacity - index[0])) {
+      index_ = i - (size / exp_batch) * (capacity - index[0]);  // The exp_batch >= 1, guaranteed by op prim.
     } else {
       index_ = i + index[0] * size / exp_batch;
     }
diff --git a/tests/st/ops/cpu/test_rl_buffer_net.py b/tests/st/ops/cpu/test_rl_buffer_net.py
new file mode 100644
index 00000000000..c46a6169bb8
--- /dev/null
+++ b/tests/st/ops/cpu/test_rl_buffer_net.py
@@ -0,0 +1,90 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import pytest
+
+import mindspore.context as context
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore.common.api import ms_function
+from mindspore.common.parameter import Parameter
+from mindspore.ops import operations as P
+import mindspore as ms
+
+
+def create_tensor(capcity, shapes, types):
+    buffer = []
+    for i in range(len(shapes)):
+        buffer.append(Parameter(Tensor(np.zeros(((capcity,)+shapes[i])), types[i]), \
+                                name="buffer" + str(i)))
+    return buffer
+
+
+class RLBuffer(nn.Cell):
+    def __init__(self, batch_size, capcity, shapes, types):
+        super(RLBuffer, self).__init__()
+        self.buffer = create_tensor(capcity, shapes, types)
+        self._capacity = capcity
+        self._batch_size = batch_size
+        self.count = Parameter(Tensor(0, ms.int32), name="count")
+        self.head = Parameter(Tensor(0, ms.int32), name="head")
+        self.buffer_append = P.BufferAppend(self._capacity, shapes, types)
+        self.buffer_get = P.BufferGetItem(self._capacity, shapes, types)
+        self.buffer_sample = P.BufferSample(
+            self._capacity, batch_size, shapes, types)
+        self.dummy_tensor = Tensor(np.ones(shape=[batch_size]), ms.bool_)
+        self.rnd_choice_mask = P.RandomChoiceWithMask(count=batch_size)
+        self.reshape = P.Reshape()
+
+    @ms_function
+    def append(self, exps):
+        return self.buffer_append(self.buffer, exps, self.count, self.head)
+
+    @ms_function
+    def get(self, index):
+        return self.buffer_get(self.buffer, self.count, self.head, index)
+
+    @ms_function
+    def sample(self):
+        index, _ = self.rnd_choice_mask(self.dummy_tensor)
+        index = self.reshape(index, (self._batch_size,))
+        return self.buffer_sample(self.buffer, index, self.count, self.head)
+
+
+s = Tensor(np.array([2, 2, 2, 2]), ms.float32)
+a = Tensor(np.array([0, 1]), ms.int32)
+r = Tensor(np.array([1]), ms.float32)
+s_ = Tensor(np.array([3, 3, 3, 3]), ms.float32)
+exp = [s, a, r, s_]
+exp1 = [s_, a, r, s]
+
+
+@ pytest.mark.level0
+@ pytest.mark.platform_x86_cpu
+@ pytest.mark.env_onecard
+def test_Buffer():
+    context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
+    buffer = RLBuffer(batch_size=32, capcity=100, shapes=[(4,), (2,), (1,), (4,)], types=[
+        ms.float32, ms.int32, ms.float32, ms.float32])
+    print("init buffer:\n", buffer.buffer)
+    for _ in range(0, 110):
+        buffer.append(exp)
+        buffer.append(exp1)
+    print("buffer append:\n", buffer.buffer)
+    b = buffer.get(-1)
+    print("buffer get:\n", b)
+    bs = buffer.sample()
+    print("buffer sample:\n", bs)
diff --git a/tests/st/ops/cpu/test_rl_buffer_op.py b/tests/st/ops/cpu/test_rl_buffer_op.py
new file mode 100644
index 00000000000..fd8ed19783a
--- /dev/null
+++ b/tests/st/ops/cpu/test_rl_buffer_op.py
@@ -0,0 +1,156 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import pytest
+
+import mindspore.context as context
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore.common.api import ms_function
+from mindspore.common.parameter import Parameter
+from mindspore.ops import operations as P
+import mindspore as ms
+
+
+class RLBufferAppend(nn.Cell):
+    def __init__(self, capcity, shapes, types):
+        super(RLBufferAppend, self).__init__()
+        self._capacity = capcity
+        self.count = Parameter(Tensor(0, ms.int32), name="count")
+        self.head = Parameter(Tensor(0, ms.int32), name="head")
+        self.buffer_append = P.BufferAppend(self._capacity, shapes, types)
+
+    @ms_function
+    def construct(self, buffer, exps):
+        return self.buffer_append(buffer, exps, self.count, self.head)
+
+
+class RLBufferGet(nn.Cell):
+    def __init__(self, capcity, shapes, types):
+        super(RLBufferGet, self).__init__()
+        self._capacity = capcity
+        self.count = Parameter(Tensor(5, ms.int32), name="count")
+        self.head = Parameter(Tensor(0, ms.int32), name="head")
+        self.buffer_get = P.BufferGetItem(self._capacity, shapes, types)
+
+    @ms_function
+    def construct(self, buffer, index):
+        return self.buffer_get(buffer, self.count, self.head, index)
+
+
+class RLBufferSample(nn.Cell):
+    def __init__(self, capcity, batch_size, shapes, types):
+        super(RLBufferSample, self).__init__()
+        self._capacity = capcity
+        count = 5
+        self.count = Parameter(Tensor(5, ms.int32), name="count")
+        self.head = Parameter(Tensor(0, ms.int32), name="head")
+        self.input_x = Tensor(np.ones(shape=[count]), ms.bool_)
+        self.buffer_sample = P.BufferSample(self._capacity, batch_size, shapes, types)
+        self.index = Parameter(Tensor([0, 2, 4], ms.int32), name="index")
+
+    @ms_function
+    def construct(self, buffer):
+        return self.buffer_sample(buffer, self.index, self.count, self.head)
+
+
+states = Tensor(np.arange(4*5).reshape(5, 4).astype(np.float32)/10.0)
+actions = Tensor(np.arange(2*5).reshape(5, 2).astype(np.int32))
+rewards = Tensor(np.ones((5, 1)).astype(np.int32))
+states_ = Tensor(np.arange(4*5).reshape(5, 4).astype(np.float32))
+b = [states, actions, rewards, states_]
+
+s = Tensor(np.array([2, 2, 2, 2]), ms.float32)
+a = Tensor(np.array([0, 0]), ms.int32)
+r = Tensor(np.array([0]), ms.int32)
+s_ = Tensor(np.array([3, 3, 3, 3]), ms.float32)
+exp = [s, a, r, s_]
+exp1 = [s_, a, r, s]
+
+c = [Tensor(np.array([[6, 6, 6, 6], [6, 6, 6, 6]]), ms.float32),
+     Tensor(np.array([[6, 6], [6, 6]]), ms.int32),
+     Tensor(np.array([[6], [6]]), ms.int32),
+     Tensor(np.array([[6, 6, 6, 6], [6, 6, 6, 6]]), ms.float32)]
+
+
+@ pytest.mark.level0
+@ pytest.mark.platform_x86_cpu
+@ pytest.mark.env_onecard
+def test_BufferSample():
+    context.set_context(mode=context.PYNATIVE_MODE, device_target='CPU')
+    buffer_sample = RLBufferSample(capcity=5, batch_size=3, shapes=[(4,), (2,), (1,), (4,)], types=[
+        ms.float32, ms.int32, ms.int32, ms.float32])
+    ss, aa, rr, ss_ = buffer_sample(b)
+    expect_s = [[0, 0.1, 0.2, 0.3], [0.8, 0.9, 1.0, 1.1], [1.6, 1.7, 1.8, 1.9]]
+    expect_a = [[0, 1], [4, 5], [8, 9]]
+    expect_r = [[1], [1], [1]]
+    expect_s_ = [[0, 1, 2, 3], [8, 9, 10, 11], [16, 17, 18, 19]]
+    np.testing.assert_almost_equal(ss.asnumpy(), expect_s)
+    np.testing.assert_almost_equal(aa.asnumpy(), expect_a)
+    np.testing.assert_almost_equal(rr.asnumpy(), expect_r)
+    np.testing.assert_almost_equal(ss_.asnumpy(), expect_s_)
+
+
+@ pytest.mark.level0
+@ pytest.mark.platform_x86_cpu
+@ pytest.mark.env_onecard
+def test_BufferGet():
+    context.set_context(mode=context.PYNATIVE_MODE, device_target='CPU')
+    buffer_get = RLBufferGet(capcity=5, shapes=[(4,), (2,), (1,), (4,)], types=[
+        ms.float32, ms.int32, ms.int32, ms.float32])
+    ss, aa, rr, ss_ = buffer_get(b, 1)
+    expect_s = [0.4, 0.5, 0.6, 0.7]
+    expect_a = [2, 3]
+    expect_r = [1]
+    expect_s_ = [4, 5, 6, 7]
+    np.testing.assert_almost_equal(ss.asnumpy(), expect_s)
+    np.testing.assert_almost_equal(aa.asnumpy(), expect_a)
+    np.testing.assert_almost_equal(rr.asnumpy(), expect_r)
+    np.testing.assert_almost_equal(ss_.asnumpy(), expect_s_)
+
+
+@ pytest.mark.level0
+@ pytest.mark.platform_x86_cpu
+@ pytest.mark.env_onecard
+def test_BufferAppend():
+    context.set_context(mode=context.PYNATIVE_MODE, device_target='CPU')
+    buffer_append = RLBufferAppend(capcity=5, shapes=[(4,), (2,), (1,), (4,)], types=[
+        ms.float32, ms.int32, ms.int32, ms.float32])
+
+    buffer_append(b, exp)
+    buffer_append(b, exp)
+    buffer_append(b, exp)
+    buffer_append(b, exp)
+    buffer_append(b, exp)
+    buffer_append(b, exp1)
+    expect_s = [[3, 3, 3, 3], [2, 2, 2, 2], [2, 2, 2, 2], [2, 2, 2, 2], [2, 2, 2, 2]]
+    expect_a = [[0, 0], [0, 0], [0, 0], [0, 0], [0, 0]]
+    expect_r = [[0], [0], [0], [0], [0]]
+    expect_s_ = [[2, 2, 2, 2], [3, 3, 3, 3], [3, 3, 3, 3], [3, 3, 3, 3], [3, 3, 3, 3]]
+    np.testing.assert_almost_equal(b[0].asnumpy(), expect_s)
+    np.testing.assert_almost_equal(b[1].asnumpy(), expect_a)
+    np.testing.assert_almost_equal(b[2].asnumpy(), expect_r)
+    np.testing.assert_almost_equal(b[3].asnumpy(), expect_s_)
+    buffer_append(b, exp1)
+    buffer_append(b, c)
+    buffer_append(b, c)
+    expect_s2 = [[6, 6, 6, 6], [3, 3, 3, 3], [6, 6, 6, 6], [6, 6, 6, 6], [6, 6, 6, 6]]
+    expect_a2 = [[6, 6], [0, 0], [6, 6], [6, 6], [6, 6]]
+    expect_r2 = [[6], [0], [6], [6], [6]]
+    expect_s2_ = [[6, 6, 6, 6], [2, 2, 2, 2], [6, 6, 6, 6], [6, 6, 6, 6], [6, 6, 6, 6]]
+    np.testing.assert_almost_equal(b[0].asnumpy(), expect_s2)
+    np.testing.assert_almost_equal(b[1].asnumpy(), expect_a2)
+    np.testing.assert_almost_equal(b[2].asnumpy(), expect_r2)
+    np.testing.assert_almost_equal(b[3].asnumpy(), expect_s2_)
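
Reviewer note (illustration only, not part of the patch): BufferAppend and BufferGetItem above share one circular-buffer convention — count saturates at capacity, and once the buffer is full, head marks the oldest physical slot, so BufferGetItem has to translate a chronological index into a slot index. The standalone NumPy sketch below mirrors that head/count bookkeeping for the single-experience case (it assumes exp_batch == 1); the RingBuffer class and its method names are illustrative assumptions, not MindSpore APIs.

# ring_buffer_sketch.py -- illustrative only; mirrors the head/count logic of
# BufferCPUAppendKernel / BufferCPUGetKernel under the assumption exp_batch == 1.
import numpy as np


class RingBuffer:
    def __init__(self, capacity, shape, dtype=np.float32):
        self.capacity = capacity
        self.data = np.zeros((capacity,) + shape, dtype=dtype)
        self.count = 0  # number of valid entries, saturates at capacity
        self.head = 0   # oldest physical slot once the buffer has wrapped

    def append(self, exp):
        if self.count <= self.capacity - 1 and self.head == 0:
            index = self.count                   # still filling: write at count
            self.count = min(self.count + 1, self.capacity)
        else:
            index = self.head                    # full: overwrite the oldest slot
            self.head = (self.head + 1) % self.capacity
        self.data[index] = exp

    def get(self, index):
        if index < 0:
            index += self.count                  # Python-style negative indexing
        assert 0 <= index < self.count
        t = self.count - self.head               # entries stored before the wrap point
        pos = index + self.head if index < t else index - t
        return self.data[pos]


buf = RingBuffer(capacity=5, shape=(4,))
for i in range(7):                               # wraps after the 5th append
    buf.append(np.full(4, i))
print(buf.get(0))    # oldest surviving experience: [2. 2. 2. 2.]
print(buf.get(-1))   # newest experience:           [6. 6. 6. 6.]

Under this convention, get(-1) returns the most recently appended experience, which is what buffer.get(-1) exercises in test_rl_buffer_net.py above.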