forked from mindspore-Ecosystem/mindspore
!10829 Modify Split Kernel for CPU
From: @shaoxiangdong Reviewed-by: Signed-off-by:
This commit is contained in:
commit
4ceb9e633c
|
@ -13,111 +13,105 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "backend/kernel_compiler/cpu/split_cpu_kernel.h"
|
||||
#include "runtime/device/cpu/cpu_device_address.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
void SplitCPUKernel::InitKernel(const CNodePtr &kernel_node) {
|
||||
CheckParam(kernel_node);
|
||||
|
||||
axis_ = AnfAlgo::GetNodeAttr<int64_t>(kernel_node, AXIS);
|
||||
auto output_1_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
|
||||
if (axis_ < 0) {
|
||||
axis_ = axis_ + SizeToLong(output_1_shape.size());
|
||||
}
|
||||
axis_ += 4 - SizeToLong(output_1_shape.size());
|
||||
|
||||
auto output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
|
||||
for (size_t i = 0; i < output_num; i++) {
|
||||
auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, i);
|
||||
CPUKernelUtils::ExpandDimsTo4(&output_shape);
|
||||
output_shape_list_.push_back(output_shape);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void SplitCPUKernel<T>::InitKernel(const CNodePtr &kernel_node) {
|
||||
axis_ = AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "axis");
|
||||
output_num_ = AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "output_num");
|
||||
input_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
CPUKernelUtils::ExpandDimsTo4(&input_shape_);
|
||||
|
||||
dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0);
|
||||
CheckParam(kernel_node);
|
||||
Reshape();
|
||||
}
|
||||
|
||||
bool SplitCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> & /*workspace*/,
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
if (dtype_ == kNumberTypeInt32 || dtype_ == kNumberTypeInt) {
|
||||
return LaunchKernel<int32_t>(inputs, outputs);
|
||||
} else if (dtype_ == kNumberTypeInt64) {
|
||||
return LaunchKernel<int64_t>(inputs, outputs);
|
||||
} else if (dtype_ == kNumberTypeFloat32 || dtype_ == kNumberTypeFloat) {
|
||||
return LaunchKernel<float>(inputs, outputs);
|
||||
} else if (dtype_ == kNumberTypeFloat64) {
|
||||
return LaunchKernel<double>(inputs, outputs);
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Only support int, float, but actual data type is " << TypeIdLabel(dtype_);
|
||||
template <typename T>
|
||||
void SplitCPUKernel<T>::Reshape() {
|
||||
input_size_ = 1;
|
||||
dims_current_after_axis_ = 1;
|
||||
dims_after_axis_ = 1;
|
||||
axis_step_ = input_shape_[axis_] / output_num_;
|
||||
|
||||
for (int i = 0; i < SizeToInt(input_shape_.size()); i++) {
|
||||
input_size_ *= input_shape_[i];
|
||||
if (i > axis_) {
|
||||
dims_current_after_axis_ *= input_shape_[i];
|
||||
dims_after_axis_ *= input_shape_[i];
|
||||
}
|
||||
if (i == axis_) {
|
||||
dims_current_after_axis_ *= input_shape_[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool SplitCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs) {
|
||||
auto input_addr = reinterpret_cast<T *>(inputs[0]->addr);
|
||||
auto buff_size = inputs[0]->size;
|
||||
size_t dim0 = input_shape_[0];
|
||||
size_t dim1 = input_shape_[1];
|
||||
size_t dim2 = input_shape_[2];
|
||||
void SplitCPUKernel<T>::InitInputOutputSize(const CNodePtr &kernel_node) {
|
||||
CPUKernel::InitInputOutputSize(kernel_node);
|
||||
workspace_size_list_.emplace_back((sizeof(T *) * output_num_));
|
||||
}
|
||||
|
||||
if (axis_ == 3) {
|
||||
for (size_t i = 0; i < dim0; ++i) {
|
||||
for (size_t j = 0; j < dim1; ++j) {
|
||||
for (size_t k = 0; k < dim2; ++k) {
|
||||
CopyDataToOutput(outputs, i, j, k, &input_addr, &buff_size);
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if (axis_ == 2) {
|
||||
for (size_t i = 0; i < dim0; ++i) {
|
||||
for (size_t j = 0; j < dim1; ++j) {
|
||||
CopyDataToOutput(outputs, i, j, 0, &input_addr, &buff_size);
|
||||
}
|
||||
}
|
||||
} else if (axis_ == 1) {
|
||||
for (size_t i = 0; i < dim0; ++i) {
|
||||
CopyDataToOutput(outputs, i, 0, 0, &input_addr, &buff_size);
|
||||
}
|
||||
} else if (axis_ == 0) {
|
||||
CopyDataToOutput(outputs, 0, 0, 0, &input_addr, &buff_size);
|
||||
}
|
||||
template <typename T>
|
||||
bool SplitCPUKernel<T>::Launch(const std::vector<kernel::AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> &workspace,
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
LaunchKernel(inputs, workspace, outputs);
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void SplitCPUKernel::CopyDataToOutput(const std::vector<kernel::AddressPtr> &outputs, size_t dim0, size_t dim1,
|
||||
size_t dim2, T **input_addr, size_t *buff_size) {
|
||||
for (size_t i = 0; i < output_shape_list_.size(); ++i) {
|
||||
auto output_i_shape = output_shape_list_[i];
|
||||
auto output_i_addr = reinterpret_cast<float *>(outputs[i]->addr);
|
||||
|
||||
size_t num = CPUKernelUtils::GetElementNumOnAxis(output_i_shape, axis_);
|
||||
num *= output_i_shape[axis_];
|
||||
auto pos = CPUKernelUtils::CalcOffset(output_i_shape, dim0, dim1, dim2, 0);
|
||||
auto ret = memcpy_s(output_i_addr + pos, *buff_size, *input_addr, num * sizeof(T));
|
||||
if (ret != EOK) {
|
||||
MS_LOG(EXCEPTION) << "memcpy failed.";
|
||||
void SplitCPUKernel<T>::LaunchSplit(const T *input, T **output, size_t size) {
|
||||
auto task = [&](size_t start, size_t end) {
|
||||
for (size_t i = start; i < end; i++) {
|
||||
int num = i % dims_current_after_axis_ / dims_after_axis_;
|
||||
int block = num / axis_step_;
|
||||
int block_pos = i / dims_current_after_axis_ * axis_step_ * dims_after_axis_ +
|
||||
num % axis_step_ * dims_after_axis_ + i % dims_after_axis_;
|
||||
output[block][block_pos] = input[i];
|
||||
}
|
||||
*input_addr += num;
|
||||
*buff_size -= num * sizeof(T);
|
||||
}
|
||||
};
|
||||
CPUKernelUtils::ParallelFor(task, size);
|
||||
return;
|
||||
}
|
||||
|
||||
void SplitCPUKernel::CheckParam(const CNodePtr &kernel_node) {
|
||||
auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
|
||||
if (output_shape.size() > 4) {
|
||||
MS_LOG(EXCEPTION) << "Output dims is " << output_shape.size() << ", but SplitCPUKernel only support 4d or lower.";
|
||||
template <typename T>
|
||||
void SplitCPUKernel<T>::LaunchKernel(const std::vector<AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) {
|
||||
T *input = reinterpret_cast<T *>(inputs[0]->addr);
|
||||
T **output = reinterpret_cast<T **>(workspace[0]->addr);
|
||||
for (size_t i = 0; i < outputs.size(); i++) {
|
||||
output[i] = reinterpret_cast<T *>(outputs[i]->addr);
|
||||
}
|
||||
size_t size = static_cast<size_t>(inputs[0]->size / sizeof(T));
|
||||
LaunchSplit(input, output, size);
|
||||
return;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void SplitCPUKernel<T>::CheckParam(const CNodePtr &kernel_node) {
|
||||
auto input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
int64_t dims = SizeToLong(input_shape_.size());
|
||||
int64_t output_num = SizeToLong(AnfAlgo::GetOutputTensorNum(kernel_node));
|
||||
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
if (input_num != 1) {
|
||||
MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but SplitCPUKernel needs 1 input.";
|
||||
MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but Split needs 1 input.";
|
||||
}
|
||||
if (dims == 0) {
|
||||
MS_LOG(EXCEPTION) << "Input dims is " << dims << ", scalar is not supported.";
|
||||
}
|
||||
if (axis_ < -dims || axis_ >= dims) {
|
||||
MS_LOG(EXCEPTION) << "Attr axis_ " << axis_ << " must be in " << -dims << "~" << dims;
|
||||
}
|
||||
if (axis_ < 0) {
|
||||
axis_ += SizeToInt(input_shape_.size());
|
||||
}
|
||||
if (output_num_ > SizeToInt(input_shape_[axis_])) {
|
||||
MS_LOG(EXCEPTION) << "Attr output_num " << output_num_ << " must less than " << input_shape_[axis_];
|
||||
}
|
||||
if (output_num_ != output_num) {
|
||||
MS_LOG(EXCEPTION) << "Output num is " << output_num << ", but need " << output_num_;
|
||||
}
|
||||
}
|
||||
} // namespace kernel
|
||||
|
|
|
@ -17,14 +17,16 @@
|
|||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPLIT_CPU_KERNEL_H_
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <thread>
|
||||
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
|
||||
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
template <typename T>
|
||||
class SplitCPUKernel : public CPUKernel {
|
||||
public:
|
||||
SplitCPUKernel() : axis_(0) {}
|
||||
SplitCPUKernel() = default;
|
||||
~SplitCPUKernel() override = default;
|
||||
|
||||
void InitKernel(const CNodePtr &kernel_node) override;
|
||||
|
@ -32,26 +34,46 @@ class SplitCPUKernel : public CPUKernel {
|
|||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
|
||||
template <typename T>
|
||||
bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
|
||||
void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs);
|
||||
|
||||
void InitInputOutputSize(const CNodePtr &kernel_node) override;
|
||||
|
||||
private:
|
||||
static void CheckParam(const CNodePtr &kernel_node);
|
||||
template <typename T>
|
||||
void CopyDataToOutput(const std::vector<kernel::AddressPtr> &inputs, size_t dim0, size_t dim1, size_t dim2,
|
||||
T **output_addr, size_t *buff_size);
|
||||
void CheckParam(const CNodePtr &kernel_node);
|
||||
void Reshape();
|
||||
void LaunchSplit(const T *input, T **output, size_t size);
|
||||
int64_t axis_;
|
||||
int64_t output_num_;
|
||||
int64_t axis_step_;
|
||||
|
||||
size_t input_size_;
|
||||
size_t dims_after_axis_;
|
||||
size_t dims_current_after_axis_;
|
||||
|
||||
std::vector<std::vector<size_t>> output_shape_list_;
|
||||
std::vector<size_t> input_shape_;
|
||||
TypeId dtype_{kTypeUnknown};
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL(Split,
|
||||
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
SplitCPUKernel);
|
||||
MS_REG_CPU_KERNEL(Split,
|
||||
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
SplitCPUKernel);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
Split, KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
SplitCPUKernel, float);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
Split, KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
||||
SplitCPUKernel, float16);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
Split, KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
SplitCPUKernel, double);
|
||||
MS_REG_CPU_KERNEL_T(Split,
|
||||
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
SplitCPUKernel, int32_t);
|
||||
MS_REG_CPU_KERNEL_T(Split,
|
||||
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
|
||||
SplitCPUKernel, uint32_t);
|
||||
MS_REG_CPU_KERNEL_T(Split,
|
||||
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
||||
SplitCPUKernel, int64_t);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -80,6 +80,164 @@ def test_out2_axis1neg():
|
|||
assert np.allclose(outputs[1].asnumpy()[0, :, :], [[3., 4., 5.], [9., 10., 11.]])
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_out_float32():
|
||||
op = P.Split(5, 2)
|
||||
op_wrapper = OpNetWrapper(op)
|
||||
|
||||
input_x = Tensor(np.arange(192).astype(np.float32).reshape((2, 2, 2, 2, 2, 6)))
|
||||
outputs = op_wrapper(input_x)
|
||||
|
||||
assert np.allclose(outputs[0].asnumpy()[0, 0, 0, 0, 0, :], [0., 1., 2.])
|
||||
assert np.allclose(outputs[1].asnumpy()[0, 0, 0, 0, 0, :], [3., 4., 5.])
|
||||
|
||||
op = P.Split(5, 3)
|
||||
op_wrapper = OpNetWrapper(op)
|
||||
outputs = op_wrapper(input_x)
|
||||
|
||||
assert np.allclose(outputs[0].asnumpy()[0, 0, 0, 0, 0, :], [0., 1.])
|
||||
assert np.allclose(outputs[1].asnumpy()[0, 0, 0, 0, 0, :], [2., 3.])
|
||||
assert np.allclose(outputs[2].asnumpy()[0, 0, 0, 0, 0, :], [4., 5.])
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_out_float64():
|
||||
op = P.Split(5, 2)
|
||||
op_wrapper = OpNetWrapper(op)
|
||||
|
||||
input_x = Tensor(np.arange(192).astype(np.float64).reshape((2, 2, 2, 2, 2, 6)))
|
||||
outputs = op_wrapper(input_x)
|
||||
|
||||
assert np.allclose(outputs[0].asnumpy()[0, 0, 0, 0, 0, :], [0., 1., 2.])
|
||||
assert np.allclose(outputs[1].asnumpy()[0, 0, 0, 0, 0, :], [3., 4., 5.])
|
||||
|
||||
op = P.Split(5, 3)
|
||||
op_wrapper = OpNetWrapper(op)
|
||||
outputs = op_wrapper(input_x)
|
||||
|
||||
assert np.allclose(outputs[0].asnumpy()[0, 0, 0, 0, 0, :], [0., 1.])
|
||||
assert np.allclose(outputs[1].asnumpy()[0, 0, 0, 0, 0, :], [2., 3.])
|
||||
assert np.allclose(outputs[2].asnumpy()[0, 0, 0, 0, 0, :], [4., 5.])
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_out_float16():
|
||||
op = P.Split(-1, 2)
|
||||
op_wrapper = OpNetWrapper(op)
|
||||
|
||||
input_x = Tensor(np.arange(320).astype(np.float16).reshape((2, 2, 2, 2, 2, 10)))
|
||||
outputs = op_wrapper(input_x)
|
||||
|
||||
assert np.allclose(outputs[0].asnumpy()[0, 0, 0, 0, 0, :], [0., 1., 2., 3., 4.])
|
||||
assert np.allclose(outputs[1].asnumpy()[0, 0, 0, 0, 0, :], [5., 6., 7., 8., 9.])
|
||||
|
||||
op = P.Split(-1, 5)
|
||||
op_wrapper = OpNetWrapper(op)
|
||||
outputs = op_wrapper(input_x)
|
||||
|
||||
assert np.allclose(outputs[0].asnumpy()[0, 0, 0, 0, 0, :], [0., 1.])
|
||||
assert np.allclose(outputs[1].asnumpy()[0, 0, 0, 0, 0, :], [2., 3.])
|
||||
assert np.allclose(outputs[2].asnumpy()[0, 0, 0, 0, 0, :], [4., 5.])
|
||||
assert np.allclose(outputs[3].asnumpy()[0, 0, 0, 0, 0, :], [6., 7.])
|
||||
assert np.allclose(outputs[4].asnumpy()[0, 0, 0, 0, 0, :], [8., 9.])
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_out_int32():
|
||||
op = P.Split(5, 2)
|
||||
op_wrapper = OpNetWrapper(op)
|
||||
|
||||
input_x = Tensor(np.arange(192).astype(np.int32).reshape((2, 2, 2, 2, 2, 6)))
|
||||
outputs = op_wrapper(input_x)
|
||||
|
||||
assert np.allclose(outputs[0].asnumpy()[0, 0, 0, 0, 0, :], [0, 1, 2])
|
||||
assert np.allclose(outputs[1].asnumpy()[0, 0, 0, 0, 0, :], [3, 4, 5])
|
||||
|
||||
op = P.Split(5, 3)
|
||||
op_wrapper = OpNetWrapper(op)
|
||||
outputs = op_wrapper(input_x)
|
||||
|
||||
assert np.allclose(outputs[0].asnumpy()[1, 0, 0, 0, 0, :], [96, 97])
|
||||
assert np.allclose(outputs[1].asnumpy()[1, 0, 0, 0, 0, :], [98, 99])
|
||||
assert np.allclose(outputs[2].asnumpy()[1, 0, 0, 0, 0, :], [100, 101])
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_out_int64():
|
||||
op = P.Split(5, 2)
|
||||
op_wrapper = OpNetWrapper(op)
|
||||
|
||||
input_x = Tensor(np.arange(192).astype(np.int64).reshape((2, 2, 2, 2, 2, 6)))
|
||||
outputs = op_wrapper(input_x)
|
||||
|
||||
assert np.allclose(outputs[0].asnumpy()[0, 0, 0, 0, 0, :], [0, 1, 2])
|
||||
assert np.allclose(outputs[1].asnumpy()[0, 0, 0, 0, 0, :], [3, 4, 5])
|
||||
|
||||
op = P.Split(5, 3)
|
||||
op_wrapper = OpNetWrapper(op)
|
||||
outputs = op_wrapper(input_x)
|
||||
|
||||
assert np.allclose(outputs[0].asnumpy()[1, 0, 0, 0, 0, :], [96, 97])
|
||||
assert np.allclose(outputs[1].asnumpy()[1, 0, 0, 0, 0, :], [98, 99])
|
||||
assert np.allclose(outputs[2].asnumpy()[1, 0, 0, 0, 0, :], [100, 101])
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_out_uint32():
|
||||
op = P.Split(-1, 2)
|
||||
op_wrapper = OpNetWrapper(op)
|
||||
|
||||
input_x = Tensor(np.arange(320).astype(np.uint32).reshape((2, 2, 2, 2, 2, 10)))
|
||||
outputs = op_wrapper(input_x)
|
||||
|
||||
assert np.allclose(outputs[0].asnumpy()[0, 0, 0, 0, 0, :], [0, 1, 2, 3, 4])
|
||||
assert np.allclose(outputs[1].asnumpy()[0, 0, 0, 0, 0, :], [5, 6, 7, 8, 9])
|
||||
|
||||
op = P.Split(-1, 5)
|
||||
op_wrapper = OpNetWrapper(op)
|
||||
outputs = op_wrapper(input_x)
|
||||
|
||||
assert np.allclose(outputs[0].asnumpy()[1, 1, 1, 1, 1, :], [310, 311])
|
||||
assert np.allclose(outputs[1].asnumpy()[1, 1, 1, 1, 1, :], [312, 313])
|
||||
assert np.allclose(outputs[2].asnumpy()[1, 1, 1, 1, 1, :], [314, 315])
|
||||
assert np.allclose(outputs[3].asnumpy()[1, 1, 1, 1, 1, :], [316, 317])
|
||||
assert np.allclose(outputs[4].asnumpy()[1, 1, 1, 1, 1, :], [318, 319])
|
||||
|
||||
op = P.Split(-2, 2)
|
||||
op_wrapper = OpNetWrapper(op)
|
||||
outputs = op_wrapper(input_x)
|
||||
|
||||
assert np.allclose(outputs[0].asnumpy()[0, 0, 0, 0, :, 0], [0])
|
||||
assert np.allclose(outputs[1].asnumpy()[0, 0, 0, 0, :, 1], [11])
|
||||
assert np.allclose(outputs[0].asnumpy()[1, 0, 0, 0, :, 2], [162])
|
||||
assert np.allclose(outputs[1].asnumpy()[1, 0, 0, 0, :, 3], [173])
|
||||
assert np.allclose(outputs[0].asnumpy()[1, 1, 0, 0, :, 4], [244])
|
||||
assert np.allclose(outputs[1].asnumpy()[1, 1, 0, 0, :, 5], [255])
|
||||
assert np.allclose(outputs[0].asnumpy()[1, 1, 1, 0, :, 6], [286])
|
||||
assert np.allclose(outputs[1].asnumpy()[1, 1, 1, 0, :, 7], [297])
|
||||
assert np.allclose(outputs[0].asnumpy()[1, 1, 1, 1, :, 8], [308])
|
||||
assert np.allclose(outputs[1].asnumpy()[1, 1, 1, 1, :, 9], [319])
|
||||
|
||||
op = P.Split(-1, 1)
|
||||
op_wrapper = OpNetWrapper(op)
|
||||
input_x = Tensor(np.arange(1).astype(np.uint32))
|
||||
outputs = op_wrapper(input_x)
|
||||
|
||||
assert np.allclose(outputs[0].asnumpy(), [0])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_out1_axis0()
|
||||
test_out2_axis2()
|
||||
|
|
Loading…
Reference in New Issue