!21682 RL add buffer cpu kernel

Merge pull request !21682 from VectorSL/push-buffer-op-cpu
This commit is contained in:
i-robot 2021-08-14 01:37:30 +00:00 committed by Gitee
commit 2151b927ba
11 changed files with 637 additions and 2 deletions

View File

@ -36,6 +36,7 @@ if(ENABLE_CPU)
"cpu/ps/*.cc"
"cpu/quantum/*.cc"
"cpu/pyfunc/*.cc"
"cpu/rl/*.cc"
)
if(NOT ENABLE_MPI)

View File

@ -157,6 +157,19 @@ class CPUKernel : public kernel::KernelMod {
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
ParallelSearchInfo parallel_search_info_;
template <typename T>
inline T *GetDeviceAddress(const std::vector<AddressPtr> &addr_list, size_t index) {
if (index >= addr_list.size()) {
MS_LOG(EXCEPTION) << "Address index(" << index << ") out of range(" << addr_list.size() << ")";
}
if ((addr_list[index] == nullptr) || (addr_list[index]->addr == nullptr) || (addr_list[index]->size == 0)) {
MS_LOG(EXCEPTION) << "The device address is empty, address index: " << index;
}
return reinterpret_cast<T *>(addr_list[index]->addr);
}
};
class CPUKernelUtils {

View File

@ -0,0 +1,23 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/cpu/rl/buffer_append_cpu_kernel.h"
namespace mindspore {
namespace kernel {
MS_REG_CPU_KERNEL(BufferAppend, KernelAttr(), BufferCPUAppendKernel);
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,109 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_RL_BUFFER_APPEND_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_RL_BUFFER_APPEND_CPU_KERNEL_H_
#include <memory>
#include <string>
#include <vector>
#include <algorithm>
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
namespace mindspore {
namespace kernel {
class BufferCPUAppendKernel : public CPUKernel {
public:
BufferCPUAppendKernel() : element_nums_(0), exp_batch_(0), capacity_(0) {}
~BufferCPUAppendKernel() override = default;
void Init(const CNodePtr &kernel_node) {
auto shapes = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, "buffer_elements");
auto types = AnfAlgo::GetNodeAttr<std::vector<TypePtr>>(kernel_node, "buffer_dtype");
capacity_ = AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "capacity");
exp_batch_ = AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "exp_batch");
element_nums_ = shapes.size();
for (size_t i = 0; i < element_nums_; i++) {
exp_element_list.push_back(shapes[i] * UnitSizeInBytes(types[i]->type_id()));
}
// buffer size
for (auto i : exp_element_list) {
input_size_list_.push_back(i * capacity_);
}
// exp size
for (auto i : exp_element_list) {
input_size_list_.push_back(i * exp_batch_);
}
// count and head
input_size_list_.push_back(sizeof(int));
input_size_list_.push_back(sizeof(int));
output_size_list_.push_back(sizeof(int));
}
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, const std::vector<AddressPtr> &) {
auto count_addr = GetDeviceAddress<int>(inputs, 2 * element_nums_);
auto head_addr = GetDeviceAddress<int>(inputs, 2 * element_nums_ + 1);
int index = 0;
if (count_addr[0] <= capacity_ - 1 && head_addr[0] == 0) {
index = count_addr[0];
count_addr[0] = index + exp_batch_;
if (count_addr[0] > capacity_) {
count_addr[0] = capacity_;
head_addr[0] = (exp_batch_ + count_addr[0] - capacity_) % capacity_;
}
} else {
index = head_addr[0];
head_addr[0] = (exp_batch_ + head_addr[0]) % capacity_;
}
// If exp_batch > (capcity_ - index), goto buffer's head
int remain_size = (exp_batch_ > (capacity_ - index)) ? LongToInt(capacity_ - index) : LongToInt(exp_batch_);
int remap_size = (exp_batch_ > (capacity_ - index)) ? LongToInt(exp_batch_ - capacity_ + index) : 0;
auto task = [&](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
auto buffer_addr = GetDeviceAddress<unsigned char>(inputs, i);
auto exp_addr = GetDeviceAddress<unsigned char>(inputs, i + element_nums_);
size_t one_exp_len = exp_element_list[i];
size_t dist_len = one_exp_len;
if (memcpy_s(buffer_addr + IntToSize(index) * one_exp_len, one_exp_len * remain_size, exp_addr,
dist_len * remain_size) != EOK) {
MS_LOG(EXCEPTION) << "Launch kernel error: memcpy failed";
}
if (remap_size > 0) {
if (memcpy_s(buffer_addr, one_exp_len * remap_size, exp_addr, dist_len * remap_size) != EOK) {
MS_LOG(EXCEPTION) << "Launch kernel error: memcpy failed";
}
}
}
};
CPUKernelUtils::ParallelFor(task, element_nums_);
return true;
}
void InitKernel(const CNodePtr &kernel_node) { return; }
protected:
void InitSizeLists() { return; }
private:
size_t element_nums_;
int64_t exp_batch_;
int64_t capacity_;
std::vector<size_t> exp_element_list;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_RL_BUFFER_APPEND_CPU_KERNEL_H_

View File

@ -0,0 +1,23 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/cpu/rl/buffer_get_cpu_kernel.h"
namespace mindspore {
namespace kernel {
MS_REG_CPU_KERNEL(BufferGetItem, KernelAttr(), BufferCPUGetKernel);
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,97 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_RL_BUFFER_GET_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_RL_BUFFER_GET_CPU_KERNEL_H_
#include <memory>
#include <string>
#include <vector>
#include <algorithm>
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
namespace mindspore {
namespace kernel {
class BufferCPUGetKernel : public CPUKernel {
public:
BufferCPUGetKernel() : element_nums_(0), capacity_(0) {}
~BufferCPUGetKernel() override = default;
void Init(const CNodePtr &kernel_node) {
auto shapes = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, "buffer_elements");
auto types = AnfAlgo::GetNodeAttr<std::vector<TypePtr>>(kernel_node, "buffer_dtype");
capacity_ = AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "capacity");
element_nums_ = shapes.size();
for (size_t i = 0; i < element_nums_; i++) {
exp_element_list.push_back(shapes[i] * UnitSizeInBytes(types[i]->type_id()));
}
// buffer size
for (auto i : exp_element_list) {
input_size_list_.push_back(i * capacity_);
output_size_list_.push_back(i);
}
// count, head, index
input_size_list_.push_back(sizeof(int));
input_size_list_.push_back(sizeof(int));
input_size_list_.push_back(sizeof(int));
}
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) {
auto count_addr = GetDeviceAddress<int>(inputs, element_nums_);
auto head_addr = GetDeviceAddress<int>(inputs, element_nums_ + 1);
auto index_addr = GetDeviceAddress<int>(inputs, element_nums_ + 2);
int index = index_addr[0];
if (index_addr[0] < 0) index += count_addr[0];
if (!(index >= 0 && index < count_addr[0])) {
MS_LOG(ERROR) << "The index " << index_addr[0] << " is out of range:[ " << -1 * count_addr[0] << ", "
<< count_addr[0] << ").";
}
int t = count_addr[0] - head_addr[0];
if (index < t) {
index += head_addr[0];
} else {
index -= t;
}
auto task = [&](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
auto buffer_addr = GetDeviceAddress<unsigned char>(inputs, i);
auto item_addr = GetDeviceAddress<unsigned char>(outputs, i);
size_t one_exp_len = output_size_list_[i];
size_t dist_len = one_exp_len;
if (memcpy_s(item_addr, one_exp_len, buffer_addr + IntToSize(index) * one_exp_len, dist_len) != EOK) {
MS_LOG(EXCEPTION) << "Launch kernel error: memcpy failed";
}
}
};
CPUKernelUtils::ParallelFor(task, element_nums_);
return true;
}
void InitKernel(const CNodePtr &kernel_node) { return; }
protected:
void InitSizeLists() { return; }
private:
size_t element_nums_;
int64_t capacity_;
std::vector<size_t> exp_element_list;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_RL_BUFFER_GET_CPU_KERNEL_H_

View File

@ -0,0 +1,23 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/cpu/rl/buffer_sample_cpu_kernel.h"
namespace mindspore {
namespace kernel {
MS_REG_CPU_KERNEL(BufferSample, KernelAttr(), BufferCPUSampleKernel);
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,100 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_RL_BUFFER_SAMPLE_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_RL_BUFFER_SAMPLE_CPU_KERNEL_H_
#include <memory>
#include <string>
#include <vector>
#include <algorithm>
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
namespace mindspore {
namespace kernel {
class BufferCPUSampleKernel : public CPUKernel {
public:
BufferCPUSampleKernel() : element_nums_(0), capacity_(0), batch_size_(0), exp_size_(0) {}
~BufferCPUSampleKernel() override = default;
void Init(const CNodePtr &kernel_node) {
auto shapes = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, "buffer_elements");
auto types = AnfAlgo::GetNodeAttr<std::vector<TypePtr>>(kernel_node, "buffer_dtype");
capacity_ = AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "capacity");
batch_size_ = LongToSize(AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "batch_size"));
element_nums_ = shapes.size();
for (size_t i = 0; i < element_nums_; i++) {
exp_element_list.push_back(shapes[i] * UnitSizeInBytes(types[i]->type_id()));
}
// buffer size
for (auto i : exp_element_list) {
input_size_list_.push_back(i * capacity_);
output_size_list_.push_back(i * batch_size_);
exp_size_ += i;
}
// index
input_size_list_.push_back(sizeof(int) * batch_size_);
// count and head
input_size_list_.push_back(sizeof(int));
input_size_list_.push_back(sizeof(int));
}
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) {
auto indexes_addr = GetDeviceAddress<int>(inputs, element_nums_);
auto count_addr = GetDeviceAddress<int>(inputs, element_nums_ + 1);
auto head_addr = GetDeviceAddress<int>(inputs, element_nums_ + 2);
if ((head_addr[0] > 0 && SizeToLong(batch_size_) > capacity_) ||
(head_addr[0] == 0 && SizeToLong(batch_size_) > count_addr[0])) {
MS_LOG(ERROR) << "The batch size " << batch_size_ << " is larger than total buffer size "
<< std::min(capacity_, IntToLong(count_addr[0]));
}
auto task = [&](size_t start, size_t end) {
for (size_t j = start; j < end; j++) {
int64_t index = IntToSize(indexes_addr[j]);
for (size_t i = 0; i < element_nums_; i++) {
auto buffer_addr = GetDeviceAddress<unsigned char>(inputs, i);
auto output_addr = GetDeviceAddress<unsigned char>(outputs, i);
auto one_exp_len = exp_element_list[i];
size_t dist_len = one_exp_len;
if (memcpy_s(output_addr + j * one_exp_len, one_exp_len, buffer_addr + index * one_exp_len, dist_len) !=
EOK) {
MS_LOG(EXCEPTION) << "Launch kernel error: memcpy failed";
}
}
}
};
CPUKernelUtils::ParallelFor(task, batch_size_);
return true;
}
void InitKernel(const CNodePtr &kernel_node) { return; }
protected:
void InitSizeLists() { return; }
private:
size_t element_nums_;
int64_t capacity_;
size_t batch_size_;
int64_t exp_size_;
std::vector<size_t> exp_element_list;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_RL_BUFFER_SAMPLE_CPU_KERNEL_H_

View File

@ -20,8 +20,8 @@ __global__ void BufferAppendKernel(const int64_t capacity, const size_t size, co
unsigned char *buffer, const unsigned char *exp) {
size_t index_ = index[0];
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
if (i >= size / exp_batch * (capacity - index[0])) {
index_ = i - size / exp_batch * (capacity - index[0]);
if (i >= (size / exp_batch) * (capacity - index[0])) {
index_ = i - (size / exp_batch) * (capacity - index[0]); // The exp_batch >= 1, guaranteed by op prim.
} else {
index_ = i + index[0] * size / exp_batch;
}

View File

@ -0,0 +1,90 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common.api import ms_function
from mindspore.common.parameter import Parameter
from mindspore.ops import operations as P
import mindspore as ms
def create_tensor(capcity, shapes, types):
buffer = []
for i in range(len(shapes)):
buffer.append(Parameter(Tensor(np.zeros(((capcity,)+shapes[i])), types[i]), \
name="buffer" + str(i)))
return buffer
class RLBuffer(nn.Cell):
def __init__(self, batch_size, capcity, shapes, types):
super(RLBuffer, self).__init__()
self.buffer = create_tensor(capcity, shapes, types)
self._capacity = capcity
self._batch_size = batch_size
self.count = Parameter(Tensor(0, ms.int32), name="count")
self.head = Parameter(Tensor(0, ms.int32), name="head")
self.buffer_append = P.BufferAppend(self._capacity, shapes, types)
self.buffer_get = P.BufferGetItem(self._capacity, shapes, types)
self.buffer_sample = P.BufferSample(
self._capacity, batch_size, shapes, types)
self.dummy_tensor = Tensor(np.ones(shape=[batch_size]), ms.bool_)
self.rnd_choice_mask = P.RandomChoiceWithMask(count=batch_size)
self.reshape = P.Reshape()
@ms_function
def append(self, exps):
return self.buffer_append(self.buffer, exps, self.count, self.head)
@ms_function
def get(self, index):
return self.buffer_get(self.buffer, self.count, self.head, index)
@ms_function
def sample(self):
index, _ = self.rnd_choice_mask(self.dummy_tensor)
index = self.reshape(index, (self._batch_size,))
return self.buffer_sample(self.buffer, index, self.count, self.head)
s = Tensor(np.array([2, 2, 2, 2]), ms.float32)
a = Tensor(np.array([0, 1]), ms.int32)
r = Tensor(np.array([1]), ms.float32)
s_ = Tensor(np.array([3, 3, 3, 3]), ms.float32)
exp = [s, a, r, s_]
exp1 = [s_, a, r, s]
@ pytest.mark.level0
@ pytest.mark.platform_x86_cpu
@ pytest.mark.env_onecard
def test_Buffer():
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
buffer = RLBuffer(batch_size=32, capcity=100, shapes=[(4,), (2,), (1,), (4,)], types=[
ms.float32, ms.int32, ms.float32, ms.float32])
print("init buffer:\n", buffer.buffer)
for _ in range(0, 110):
buffer.append(exp)
buffer.append(exp1)
print("buffer append:\n", buffer.buffer)
b = buffer.get(-1)
print("buffer get:\n", b)
bs = buffer.sample()
print("buffer sample:\n", bs)

View File

@ -0,0 +1,156 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common.api import ms_function
from mindspore.common.parameter import Parameter
from mindspore.ops import operations as P
import mindspore as ms
class RLBufferAppend(nn.Cell):
def __init__(self, capcity, shapes, types):
super(RLBufferAppend, self).__init__()
self._capacity = capcity
self.count = Parameter(Tensor(0, ms.int32), name="count")
self.head = Parameter(Tensor(0, ms.int32), name="head")
self.buffer_append = P.BufferAppend(self._capacity, shapes, types)
@ms_function
def construct(self, buffer, exps):
return self.buffer_append(buffer, exps, self.count, self.head)
class RLBufferGet(nn.Cell):
def __init__(self, capcity, shapes, types):
super(RLBufferGet, self).__init__()
self._capacity = capcity
self.count = Parameter(Tensor(5, ms.int32), name="count")
self.head = Parameter(Tensor(0, ms.int32), name="head")
self.buffer_get = P.BufferGetItem(self._capacity, shapes, types)
@ms_function
def construct(self, buffer, index):
return self.buffer_get(buffer, self.count, self.head, index)
class RLBufferSample(nn.Cell):
def __init__(self, capcity, batch_size, shapes, types):
super(RLBufferSample, self).__init__()
self._capacity = capcity
count = 5
self.count = Parameter(Tensor(5, ms.int32), name="count")
self.head = Parameter(Tensor(0, ms.int32), name="head")
self.input_x = Tensor(np.ones(shape=[count]), ms.bool_)
self.buffer_sample = P.BufferSample(self._capacity, batch_size, shapes, types)
self.index = Parameter(Tensor([0, 2, 4], ms.int32), name="index")
@ms_function
def construct(self, buffer):
return self.buffer_sample(buffer, self.index, self.count, self.head)
states = Tensor(np.arange(4*5).reshape(5, 4).astype(np.float32)/10.0)
actions = Tensor(np.arange(2*5).reshape(5, 2).astype(np.int32))
rewards = Tensor(np.ones((5, 1)).astype(np.int32))
states_ = Tensor(np.arange(4*5).reshape(5, 4).astype(np.float32))
b = [states, actions, rewards, states_]
s = Tensor(np.array([2, 2, 2, 2]), ms.float32)
a = Tensor(np.array([0, 0]), ms.int32)
r = Tensor(np.array([0]), ms.int32)
s_ = Tensor(np.array([3, 3, 3, 3]), ms.float32)
exp = [s, a, r, s_]
exp1 = [s_, a, r, s]
c = [Tensor(np.array([[6, 6, 6, 6], [6, 6, 6, 6]]), ms.float32),
Tensor(np.array([[6, 6], [6, 6]]), ms.int32),
Tensor(np.array([[6], [6]]), ms.int32),
Tensor(np.array([[6, 6, 6, 6], [6, 6, 6, 6]]), ms.float32)]
@ pytest.mark.level0
@ pytest.mark.platform_x86_cpu
@ pytest.mark.env_onecard
def test_BufferSample():
context.set_context(mode=context.PYNATIVE_MODE, device_target='CPU')
buffer_sample = RLBufferSample(capcity=5, batch_size=3, shapes=[(4,), (2,), (1,), (4,)], types=[
ms.float32, ms.int32, ms.int32, ms.float32])
ss, aa, rr, ss_ = buffer_sample(b)
expect_s = [[0, 0.1, 0.2, 0.3], [0.8, 0.9, 1.0, 1.1], [1.6, 1.7, 1.8, 1.9]]
expect_a = [[0, 1], [4, 5], [8, 9]]
expect_r = [[1], [1], [1]]
expect_s_ = [[0, 1, 2, 3], [8, 9, 10, 11], [16, 17, 18, 19]]
np.testing.assert_almost_equal(ss.asnumpy(), expect_s)
np.testing.assert_almost_equal(aa.asnumpy(), expect_a)
np.testing.assert_almost_equal(rr.asnumpy(), expect_r)
np.testing.assert_almost_equal(ss_.asnumpy(), expect_s_)
@ pytest.mark.level0
@ pytest.mark.platform_x86_cpu
@ pytest.mark.env_onecard
def test_BufferGet():
context.set_context(mode=context.PYNATIVE_MODE, device_target='CPU')
buffer_get = RLBufferGet(capcity=5, shapes=[(4,), (2,), (1,), (4,)], types=[
ms.float32, ms.int32, ms.int32, ms.float32])
ss, aa, rr, ss_ = buffer_get(b, 1)
expect_s = [0.4, 0.5, 0.6, 0.7]
expect_a = [2, 3]
expect_r = [1]
expect_s_ = [4, 5, 6, 7]
np.testing.assert_almost_equal(ss.asnumpy(), expect_s)
np.testing.assert_almost_equal(aa.asnumpy(), expect_a)
np.testing.assert_almost_equal(rr.asnumpy(), expect_r)
np.testing.assert_almost_equal(ss_.asnumpy(), expect_s_)
@ pytest.mark.level0
@ pytest.mark.platform_x86_cpu
@ pytest.mark.env_onecard
def test_BufferAppend():
context.set_context(mode=context.PYNATIVE_MODE, device_target='CPU')
buffer_append = RLBufferAppend(capcity=5, shapes=[(4,), (2,), (1,), (4,)], types=[
ms.float32, ms.int32, ms.int32, ms.float32])
buffer_append(b, exp)
buffer_append(b, exp)
buffer_append(b, exp)
buffer_append(b, exp)
buffer_append(b, exp)
buffer_append(b, exp1)
expect_s = [[3, 3, 3, 3], [2, 2, 2, 2], [2, 2, 2, 2], [2, 2, 2, 2], [2, 2, 2, 2]]
expect_a = [[0, 0], [0, 0], [0, 0], [0, 0], [0, 0]]
expect_r = [[0], [0], [0], [0], [0]]
expect_s_ = [[2, 2, 2, 2], [3, 3, 3, 3], [3, 3, 3, 3], [3, 3, 3, 3], [3, 3, 3, 3]]
np.testing.assert_almost_equal(b[0].asnumpy(), expect_s)
np.testing.assert_almost_equal(b[1].asnumpy(), expect_a)
np.testing.assert_almost_equal(b[2].asnumpy(), expect_r)
np.testing.assert_almost_equal(b[3].asnumpy(), expect_s_)
buffer_append(b, exp1)
buffer_append(b, c)
buffer_append(b, c)
expect_s2 = [[6, 6, 6, 6], [3, 3, 3, 3], [6, 6, 6, 6], [6, 6, 6, 6], [6, 6, 6, 6]]
expect_a2 = [[6, 6], [0, 0], [6, 6], [6, 6], [6, 6]]
expect_r2 = [[6], [0], [6], [6], [6]]
expect_s2_ = [[6, 6, 6, 6], [2, 2, 2, 2], [6, 6, 6, 6], [6, 6, 6, 6], [6, 6, 6, 6]]
np.testing.assert_almost_equal(b[0].asnumpy(), expect_s2)
np.testing.assert_almost_equal(b[1].asnumpy(), expect_a2)
np.testing.assert_almost_equal(b[2].asnumpy(), expect_r2)
np.testing.assert_almost_equal(b[3].asnumpy(), expect_s2_)