forked from mindspore-Ecosystem/mindspore

commit 680d319290 (parent 2151b927ba): update buffer sample gpu

BufferSample no longer takes an `indexes` input; both kernels now draw their own sample
indexes from an optional `seed` attribute. The CPU kernel shuffles the valid positions with
srand/random_shuffle; the GPU kernel sorts curand-generated uniform floats and uses the
resulting order as the random indexes.
CPU kernel header (buffer_sample_cpu_kernel.h, per its include guard):

@@ -16,6 +16,7 @@
 #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_RL_BUFFER_SAMPLE_CPU_KERNEL_H_
 #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_RL_BUFFER_SAMPLE_CPU_KERNEL_H_
+#include <stdlib.h>
 #include <memory>
 #include <string>
 #include <vector>
@@ -27,13 +28,14 @@ namespace mindspore {
 namespace kernel {
 class BufferCPUSampleKernel : public CPUKernel {
  public:
-  BufferCPUSampleKernel() : element_nums_(0), capacity_(0), batch_size_(0), exp_size_(0) {}
+  BufferCPUSampleKernel() : element_nums_(0), capacity_(0), batch_size_(0), exp_size_(0), seed_(0) {}

   ~BufferCPUSampleKernel() override = default;
   void Init(const CNodePtr &kernel_node) {
     auto shapes = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, "buffer_elements");
     auto types = AnfAlgo::GetNodeAttr<std::vector<TypePtr>>(kernel_node, "buffer_dtype");
     capacity_ = AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "capacity");
+    seed_ = AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "seed");
     batch_size_ = LongToSize(AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "batch_size"));
     element_nums_ = shapes.size();
     for (size_t i = 0; i < element_nums_; i++) {
@@ -45,8 +47,6 @@ class BufferCPUSampleKernel : public CPUKernel {
       output_size_list_.push_back(i * batch_size_);
       exp_size_ += i;
     }
-    // index
-    input_size_list_.push_back(sizeof(int) * batch_size_);
     // count and head
     input_size_list_.push_back(sizeof(int));
     input_size_list_.push_back(sizeof(int));
@@ -54,18 +54,29 @@ class BufferCPUSampleKernel : public CPUKernel {

   bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
               const std::vector<AddressPtr> &outputs) {
-    auto indexes_addr = GetDeviceAddress<int>(inputs, element_nums_);
-    auto count_addr = GetDeviceAddress<int>(inputs, element_nums_ + 1);
-    auto head_addr = GetDeviceAddress<int>(inputs, element_nums_ + 2);
+    auto count_addr = GetDeviceAddress<int>(inputs, element_nums_);
+    auto head_addr = GetDeviceAddress<int>(inputs, element_nums_ + 1);

     if ((head_addr[0] > 0 && SizeToLong(batch_size_) > capacity_) ||
         (head_addr[0] == 0 && SizeToLong(batch_size_) > count_addr[0])) {
       MS_LOG(ERROR) << "The batch size " << batch_size_ << " is larger than total buffer size "
                     << std::min(capacity_, IntToLong(count_addr[0]));
     }
+    // Generate random indexes
+    std::vector<size_t> indexes;
+    for (size_t i = 0; i < IntToSize(count_addr[0]); ++i) {
+      indexes.push_back(i);
+    }
+    if (seed_ == 0) {
+      std::srand(time(nullptr));
+    } else {
+      std::srand(seed_);
+    }
+    random_shuffle(indexes.begin(), indexes.end());

     auto task = [&](size_t start, size_t end) {
       for (size_t j = start; j < end; j++) {
-        int64_t index = IntToSize(indexes_addr[j]);
+        size_t index = indexes[j];
         for (size_t i = 0; i < element_nums_; i++) {
           auto buffer_addr = GetDeviceAddress<unsigned char>(inputs, i);
           auto output_addr = GetDeviceAddress<unsigned char>(outputs, i);
@@ -92,6 +103,7 @@ class BufferCPUSampleKernel : public CPUKernel {
   int64_t capacity_;
   size_t batch_size_;
   int64_t exp_size_;
+  int64_t seed_;
   std::vector<size_t> exp_element_list;
 };
 }  // namespace kernel
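On the CPU path the kernel now draws its own indexes: it seeds the C runtime RNG (time-based when the `seed` attribute is 0), shuffles the positions [0, count), and the gather loop consumes the first `batch_size` of them, so sampling is without replacement. A minimal numpy sketch of these semantics (function and variable names here are illustrative, not from the kernel):

    import numpy as np

    def cpu_style_sample(buffer_rows, count, batch_size, seed=0):
        # Mirrors BufferCPUSampleKernel::Launch: shuffle [0, count), take the head.
        rng = np.random.default_rng(None if seed == 0 else seed)  # seed 0 -> fresh entropy, like srand(time(nullptr))
        indexes = rng.permutation(count)                          # stand-in for random_shuffle
        return buffer_rows[indexes[:batch_size]]                  # the C++ does one such gather per buffer element

    rows = np.arange(20).reshape(10, 2)  # toy buffer: 10 valid entries of width 2
    print(cpu_style_sample(rows, count=10, batch_size=3, seed=42))

One portability note: std::random_shuffle is deprecated since C++14 and removed in C++17, so std::shuffle with an explicit engine would be the forward-compatible equivalent.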
CUDA implementation (rl_buffer_impl.cu, name inferred from the .cuh include path below):

@@ -91,6 +91,14 @@ __global__ void BufferSampleKernel(const size_t size, const size_t one_element,
   }
 }

+__global__ void SrandUniformFloat(const int size, curandState *globalState, const int seedc, float *out) {
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
+    curand_init(seedc, threadIdx.x, 0, &globalState[i]);
+    out[i] = curand_uniform(&globalState[i]);
+  }
+  __syncthreads();
+}
+
 void BufferAppend(const int64_t capacity, const size_t size, const int *index, const int exp_batch,
                   unsigned char *buffer, const unsigned char *exp, cudaStream_t cuda_stream) {
   BufferAppendKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(capacity, size, index, exp_batch, buffer, exp);
@@ -119,3 +127,7 @@ void BufferSample(const size_t size, const size_t one_element, const int *index,
                   unsigned char *out, cudaStream_t cuda_stream) {
   BufferSampleKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, one_element, index, buffer, out);
 }
+
+void RandomGen(const int size, curandState *globalState, const int &seedc, float *out, cudaStream_t stream) {
+  SrandUniformFloat<<<GET_BLOCKS(size), GET_THREADS, 0, stream>>>(size, globalState, seedc, out);
+}
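A GPU cannot cheaply shuffle in place, so the new code derives a permutation by sorting random keys: SrandUniformFloat writes one curand uniform float per valid slot, and the top-k pass in the kernel code below returns the slots ordered by key, which is a uniform random permutation. A numpy sketch of the trick (names illustrative):

    import numpy as np

    def permutation_by_sort(count, seed):
        # One random key per slot (SrandUniformFloat), then argsort (FastTopK with k = count).
        rng = np.random.default_rng(seed)
        keys = rng.random(count)
        return np.argsort(-keys)  # descending order, matching a top-k

    print(permutation_by_sort(count=8, seed=7))  # every slot appears exactly once

Since ties among continuous keys occur with probability zero, ordering by key is equivalent to a uniform shuffle, which yields duplicate-free (without-replacement) indexes.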
CUDA header (rl_buffer_impl.cuh):

@@ -16,7 +16,7 @@
 #ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_RL_BUFFER_IMPL_H_
 #define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_RL_BUFFER_IMPL_H_

-
+#include <curand_kernel.h>
 #include "runtime/device/gpu/cuda_common.h"
 void BufferAppend(const int64_t capacity, const size_t size, const int *index, const int exp_batch,
                   unsigned char *buffer, const unsigned char *exp, cudaStream_t cuda_stream);
@@ -29,5 +29,5 @@ void CheckBatchSize(const int *count, const int *head, const size_t batch_size,
                     cudaStream_t cuda_stream);
 void BufferSample(const size_t size, const size_t one_element, const int *index, const unsigned char *buffer,
                   unsigned char *out, cudaStream_t cuda_stream);
-
+void RandomGen(const int size, curandState *globalState, const int &seedc, float *out, cudaStream_t stream);
 #endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_ADAM_IMPL_H_
GPU kernel source (buffer_sample_gpu_kernel.cc, name inferred):

@@ -19,15 +19,18 @@
 #include <string>
 #include <vector>
 #include <algorithm>
+#include <limits>
+#include <chrono>

 #include "backend/kernel_compiler/common_utils.h"
 #include "backend/kernel_compiler/gpu/cuda_impl/rl/rl_buffer_impl.cuh"
+#include "backend/kernel_compiler/gpu/cuda_impl/topk_impl.cuh"
 #include "runtime/device/gpu/gpu_common.h"

 namespace mindspore {
 namespace kernel {

-BufferSampleKernel::BufferSampleKernel() : element_nums_(0), capacity_(0), batch_size_(0) {}
+BufferSampleKernel::BufferSampleKernel() : element_nums_(0), capacity_(0), batch_size_(0), seed_(0) {}

 BufferSampleKernel::~BufferSampleKernel() {}
@@ -44,6 +47,7 @@ bool BufferSampleKernel::Init(const CNodePtr &kernel_node) {
   auto shapes = GetAttr<std::vector<int64_t>>(kernel_node, "buffer_elements");
   auto types = GetAttr<std::vector<TypePtr>>(kernel_node, "buffer_dtype");
   capacity_ = GetAttr<int64_t>(kernel_node, "capacity");
+  seed_ = GetAttr<int64_t>(kernel_node, "seed");
   batch_size_ = LongToSize(GetAttr<int64_t>(kernel_node, "batch_size"));
   element_nums_ = shapes.size();
   for (size_t i = 0; i < element_nums_; i++) {
@@ -52,28 +56,52 @@ bool BufferSampleKernel::Init(const CNodePtr &kernel_node) {
     input_size_list_.push_back(capacity_ * element);
     output_size_list_.push_back(batch_size_ * element);
   }
-  // index
-  input_size_list_.push_back(sizeof(int) * batch_size_);
   // count and head
   input_size_list_.push_back(sizeof(int));
   input_size_list_.push_back(sizeof(int));
+  workspace_size_list_.push_back(capacity_ * sizeof(curandState));
+  workspace_size_list_.push_back(capacity_ * sizeof(float));
+  workspace_size_list_.push_back(capacity_ * sizeof(int));
+  workspace_size_list_.push_back(capacity_ * sizeof(float));
   return true;
 }

 void BufferSampleKernel::InitSizeLists() { return; }

-bool BufferSampleKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
+bool BufferSampleKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspaces,
                                 const std::vector<AddressPtr> &outputs, void *stream) {
-  int *index_addr = GetDeviceAddress<int>(inputs, element_nums_);
-  int *count_addr = GetDeviceAddress<int>(inputs, element_nums_ + 1);
-  int *head_addr = GetDeviceAddress<int>(inputs, element_nums_ + 2);
+  int *count_addr = GetDeviceAddress<int>(inputs, element_nums_);
+  int *head_addr = GetDeviceAddress<int>(inputs, element_nums_ + 1);
   auto cuda_stream = reinterpret_cast<cudaStream_t>(stream);
   CheckBatchSize(count_addr, head_addr, batch_size_, capacity_, cuda_stream);
+  int k_cut = 0;
+  CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
+                             cudaMemcpyAsync(&k_cut, count_addr, sizeof(int), cudaMemcpyDeviceToHost, cuda_stream),
+                             "sync dev to host failed");
+  // 1 Generate random floats
+  auto States = GetDeviceAddress<void *>(workspaces, 0);
+  auto random_f = GetDeviceAddress<float>(workspaces, 1);
+  auto indexes = GetDeviceAddress<int>(workspaces, 2);
+  auto useless_out = GetDeviceAddress<float>(workspaces, 3);
+  int seedc = 0;
+  if (seed_ == 0) {
+    generator_.seed(std::chrono::system_clock::now().time_since_epoch().count());
+    seedc = generator_();
+  } else {
+    seedc = seed_;
+  }
+
+  float init_k = std::numeric_limits<float>::lowest();
+  curandState *devStates = reinterpret_cast<curandState *>(States);
+  RandomGen(k_cut, devStates, seedc, random_f, cuda_stream);
+  // 2 Sort the random floats, and get the sorted indexes as the random indexes
+  FastTopK(1, k_cut, random_f, k_cut, useless_out, indexes, init_k, cuda_stream);
+  CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, cudaDeviceSynchronize(), "cudaDeviceSync failed, sample-topk");
   for (size_t i = 0; i < element_nums_; i++) {
     auto buffer_addr = GetDeviceAddress<unsigned char>(inputs, i);
     auto out_addr = GetDeviceAddress<unsigned char>(outputs, i);
     size_t size = batch_size_ * exp_element_list[i];
-    BufferSample(size, exp_element_list[i], index_addr, buffer_addr, out_addr, cuda_stream);
+    BufferSample(size, exp_element_list[i], indexes, buffer_addr, out_addr, cuda_stream);
   }
   return true;
 }
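End to end, the GPU Launch now: copies `count` back to the host as `k_cut`, picks a seed (a clock-seeded `std::mt19937` when the attribute is 0), generates `k_cut` uniform floats, top-k-sorts all of them so the index output is a random permutation, and gathers rows per buffer element. A host-side numpy sketch of the same flow (names illustrative):

    import time
    import numpy as np

    def gpu_style_sample(element_buffers, count, batch_size, seed=0):
        # Mirrors BufferSampleKernel::Launch: random keys -> topk indexes -> gather.
        seedc = seed if seed != 0 else int(time.time())  # stand-in for the mt19937 reseed
        rng = np.random.default_rng(seedc)
        random_f = rng.random(count)                     # RandomGen over the k_cut valid slots
        indexes = np.argsort(-random_f)                  # FastTopK(k = k_cut) index output
        return [buf[indexes[:batch_size]] for buf in element_buffers]  # BufferSample per element

    states = np.arange(20, dtype=np.float32).reshape(5, 4) / 10.0
    rewards = np.ones((5, 1), dtype=np.int32)
    s, r = gpu_style_sample([states, rewards], count=5, batch_size=3, seed=42)

Only the first batch_size entries of the permutation are consumed; the useless_out workspace exists because FastTopK must also write the sorted key values somewhere.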
GPU kernel header (buffer_sample_gpu_kernel.h, name inferred):

@@ -20,6 +20,7 @@
 #include <memory>
 #include <string>
 #include <vector>
+#include <random>
 #include "backend/kernel_compiler/gpu/gpu_kernel.h"
 #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
@@ -45,6 +46,8 @@ class BufferSampleKernel : public GpuKernel {
   size_t element_nums_;
   int64_t capacity_;
   size_t batch_size_;
+  int64_t seed_;
+  std::mt19937 generator_;
   std::vector<size_t> exp_element_list;
   std::vector<size_t> input_size_list_;
   std::vector<size_t> output_size_list_;
Python primitive (class BufferSample):

@@ -36,12 +36,12 @@ class BufferSample(PrimitiveWithInfer):
         batch_size (int64): The size of the sampled data, less than or equal to `capacity`.
         buffer_shape (tuple(shape)): The shape of a buffer.
         buffer_dtype (tuple(type)): The type of a buffer.
+        seed (int64): Random seed for sampling. Default: 0. With the default seed, a random one
+            is generated in the kernel. Set a number other than `0` to keep a specific seed.

     Inputs:
         - **data** (tuple(Parameter(Tensor))) - The tuple(Tensor) represents the replay buffer,
           each tensor is described by the `buffer_shape` and `buffer_type`.
-        - **indexes** (tuple(int32)) - The position list in replaybuffer,
-          the size equal to `batch_size`.
         - **count** (Parameter) - The count means the real available size of the buffer,
           data type: int32.
         - **head** (Parameter) - The position of the first data in the buffer, data type: int32.
@@ -69,8 +69,7 @@ class BufferSample(PrimitiveWithInfer):
                       Parameter(Tensor(np.ones((100, 1)).astype(np.int32)), name="reward"),
                       Parameter(Tensor(np.arange(100 * 4).reshape(100, 4).astype(np.float32)), name="state_")]
        >>> buffer_sample = ops.BufferSample(capacity, batch_size, shapes, types)
-       >>> indexes = Parameter(Tensor([0, 2, 4, 3, 8], ms.int32), name="index")
-       >>> output = buffer_sample(buffer, indexes, count, head)
+       >>> output = buffer_sample(buffer, count, head)
        >>> print(output)
        (Tensor(shape=[5, 4], dtype=Float32, value=
        [[ 0.00000000e+00, 1.00000000e+00, 2.00000000e+00, 3.00000000e+00],
@@ -99,7 +98,7 @@ class BufferSample(PrimitiveWithInfer):
     """

     @prim_attr_register
-    def __init__(self, capacity, batch_size, buffer_shape, buffer_dtype):
+    def __init__(self, capacity, batch_size, buffer_shape, buffer_dtype, seed=0):
         """Initialize BufferSample."""
         self.init_prim_io_names(inputs=["buffer"], outputs=["sample"])
         validator.check_value_type("shape of init data", buffer_shape, [tuple, list], self.name)
@@ -110,6 +109,7 @@ class BufferSample(PrimitiveWithInfer):
         self._n = len(buffer_shape)
         validator.check_int(self._batch_size, capacity, Rel.LE, "batchsize", self.name)
         self.add_prim_attr('capacity', capacity)
+        self.add_prim_attr('seed', seed)
         buffer_elements = []
         for shape in buffer_shape:
             buffer_elements.append(reduce(lambda x, y: x * y, shape))
@@ -119,17 +119,16 @@ class BufferSample(PrimitiveWithInfer):
         if context.get_context('device_target') == "Ascend":
             self.add_prim_attr('device_target', "CPU")

-    def infer_shape(self, data_shape, index_shape, count_shape, head_shape):
+    def infer_shape(self, data_shape, count_shape, head_shape):
         validator.check_value_type("shape of data", data_shape, [tuple, list], self.name)
         out_shapes = []
         for i in range(self._n):
             out_shapes.append((self._batch_size,) + self._buffer_shape[i])
         return tuple(out_shapes)

-    def infer_dtype(self, data_type, index_type, count_type, head_type):
+    def infer_dtype(self, data_type, count_type, head_type):
         validator.check_type_name("count type", count_type, (mstype.int32), self.name)
         validator.check_type_name("head type", head_type, (mstype.int32), self.name)
-        validator.check_type_name("index type", index_type, (mstype.int64, mstype.int32), self.name)
         return tuple(self._buffer_dtype)

 class BufferAppend(PrimitiveWithInfer):
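After this change a caller constructs the primitive with an optional `seed` and passes only the buffer plus `count` and `head`. A usage sketch assembled from the docstring example above (a sketch, assuming the op is importable as `P.BufferSample` as in the tests below):

    import numpy as np
    import mindspore as ms
    from mindspore import Parameter, Tensor
    from mindspore.ops import operations as P

    capacity, batch_size = 100, 5
    shapes, types = [(4,), (1,)], [ms.float32, ms.int32]
    buffer = [Parameter(Tensor(np.arange(100 * 4).reshape(100, 4).astype(np.float32)), name="states"),
              Parameter(Tensor(np.ones((100, 1)).astype(np.int32)), name="reward")]
    count = Parameter(Tensor(5, ms.int32), name="count")
    head = Parameter(Tensor(0, ms.int32), name="head")

    buffer_sample = P.BufferSample(capacity, batch_size, shapes, types, seed=42)  # nonzero seed -> reproducible
    output = buffer_sample(buffer, count, head)  # no indexes argument anymore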
Test updates (the nets no longer build indexes on the host):

@@ -45,9 +45,6 @@ class RLBuffer(nn.Cell):
         self.buffer_get = P.BufferGetItem(self._capacity, shapes, types)
         self.buffer_sample = P.BufferSample(
             self._capacity, batch_size, shapes, types)
-        self.dummy_tensor = Tensor(np.ones(shape=[batch_size]), ms.bool_)
-        self.rnd_choice_mask = P.RandomChoiceWithMask(count=batch_size)
-        self.reshape = P.Reshape()

     @ms_function
     def append(self, exps):
@@ -59,9 +56,7 @@ class RLBuffer(nn.Cell):

     @ms_function
     def sample(self):
-        index, _ = self.rnd_choice_mask(self.dummy_tensor)
-        index = self.reshape(index, (self._batch_size,))
-        return self.buffer_sample(self.buffer, index, self.count, self.head)
+        return self.buffer_sample(self.buffer, self.count, self.head)


 s = Tensor(np.array([2, 2, 2, 2]), ms.float32)
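For context, the removed host-side path emulated sampling by drawing indexes with RandomChoiceWithMask over an all-true mask before calling the op. Roughly, in numpy terms (illustrative stand-in, not the op itself):

    import numpy as np

    rng = np.random.default_rng(0)
    batch_size = 5
    dummy_mask = np.ones(batch_size, dtype=bool)  # the removed dummy_tensor
    index = rng.choice(np.flatnonzero(dummy_mask), size=batch_size, replace=False)
    # buffer_sample(buffer, index, count, head) used to consume these host-made indexes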
@@ -55,16 +55,13 @@ class RLBufferSample(nn.Cell):
     def __init__(self, capcity, batch_size, shapes, types):
         super(RLBufferSample, self).__init__()
         self._capacity = capcity
-        count = 5
         self.count = Parameter(Tensor(5, ms.int32), name="count")
         self.head = Parameter(Tensor(0, ms.int32), name="head")
-        self.input_x = Tensor(np.ones(shape=[count]), ms.bool_)
         self.buffer_sample = P.BufferSample(self._capacity, batch_size, shapes, types)
-        self.index = Parameter(Tensor([0, 2, 4], ms.int32), name="index")

     @ms_function
     def construct(self, buffer):
-        return self.buffer_sample(buffer, self.index, self.count, self.head)
+        return self.buffer_sample(buffer, self.count, self.head)


 states = Tensor(np.arange(4*5).reshape(5, 4).astype(np.float32)/10.0)
@@ -93,14 +90,7 @@ def test_BufferSample():
     buffer_sample = RLBufferSample(capcity=5, batch_size=3, shapes=[(4,), (2,), (1,), (4,)], types=[
                                    ms.float32, ms.int32, ms.int32, ms.float32])
     ss, aa, rr, ss_ = buffer_sample(b)
-    expect_s = [[0, 0.1, 0.2, 0.3], [0.8, 0.9, 1.0, 1.1], [1.6, 1.7, 1.8, 1.9]]
-    expect_a = [[0, 1], [4, 5], [8, 9]]
-    expect_r = [[1], [1], [1]]
-    expect_s_ = [[0, 1, 2, 3], [8, 9, 10, 11], [16, 17, 18, 19]]
-    np.testing.assert_almost_equal(ss.asnumpy(), expect_s)
-    np.testing.assert_almost_equal(aa.asnumpy(), expect_a)
-    np.testing.assert_almost_equal(rr.asnumpy(), expect_r)
-    np.testing.assert_almost_equal(ss_.asnumpy(), expect_s_)
+    print(ss, aa, rr, ss_)


 @pytest.mark.level0
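Because the drawn rows are now nondeterministic, the fixed expectations above were replaced by a bare print. If a stronger check is wanted without pinning the RNG, one option is to assert structural properties only, for example that every sampled row exists in the buffer and no slot is drawn twice. A sketch with a hypothetical helper (not part of this commit):

    import numpy as np

    def check_sample(buffer_np, sample_np):
        # Each sampled row must match exactly one buffer row; no slot may repeat.
        picked = [int(np.where((buffer_np == row).all(axis=1))[0][0]) for row in sample_np]
        assert len(set(picked)) == len(picked), "a slot was sampled twice"

    states = np.arange(20).reshape(5, 4).astype(np.float32) / 10.0
    check_sample(states, states[[0, 2, 4]])  # passes: three distinct rows from the buffer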
Second test file (same pattern):

@@ -56,9 +56,7 @@ class RLBuffer(nn.Cell):

     @ms_function
     def sample(self):
-        count = self.reshape(self.count, (1,))
-        index = self.randperm(count)
-        return self.buffer_sample(self.buffer, index, self.count, self.head)
+        return self.buffer_sample(self.buffer, self.count, self.head)


 s = Tensor(np.array([2, 2, 2, 2]), ms.float32)
@@ -55,17 +55,14 @@ class RLBufferSample(nn.Cell):
     def __init__(self, capcity, batch_size, shapes, types):
         super(RLBufferSample, self).__init__()
         self._capacity = capcity
-        count = 5
         self.count = Parameter(Tensor(5, ms.int32), name="count")
         self.head = Parameter(Tensor(0, ms.int32), name="head")
-        self.input_x = Tensor(np.ones(shape=[count]), ms.bool_)
         self.buffer_sample = P.BufferSample(
             self._capacity, batch_size, shapes, types)
-        self.index = Parameter(Tensor([0, 2, 4], ms.int32), name="index")

     @ms_function
     def construct(self, buffer):
-        return self.buffer_sample(buffer, self.index, self.count, self.head)
+        return self.buffer_sample(buffer, self.count, self.head)


 states = Tensor(np.arange(4*5).reshape(5, 4).astype(np.float32)/10.0)
@@ -94,14 +91,7 @@ def test_BufferSample():
     buffer_sample = RLBufferSample(capcity=5, batch_size=3, shapes=[(4,), (2,), (1,), (4,)], types=[
                                    ms.float32, ms.int32, ms.int32, ms.float32])
     ss, aa, rr, ss_ = buffer_sample(b)
-    expect_s = [[0, 0.1, 0.2, 0.3], [0.8, 0.9, 1.0, 1.1], [1.6, 1.7, 1.8, 1.9]]
-    expect_a = [[0, 1], [4, 5], [8, 9]]
-    expect_r = [[1], [1], [1]]
-    expect_s_ = [[0, 1, 2, 3], [8, 9, 10, 11], [16, 17, 18, 19]]
-    np.testing.assert_almost_equal(ss.asnumpy(), expect_s)
-    np.testing.assert_almost_equal(aa.asnumpy(), expect_a)
-    np.testing.assert_almost_equal(rr.asnumpy(), expect_r)
-    np.testing.assert_almost_equal(ss_.asnumpy(), expect_s_)
+    print(ss, aa, rr, ss_)


 @pytest.mark.level0