Add GPU Operator CompareAndBitpack
This commit is contained in:
parent
644b98c101
commit
c00fd21b17
|
@ -0,0 +1,87 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/compare_and_bitpack_impl.cuh"
|
||||
#include <limits>
|
||||
|
||||
|
||||
/**
 * Compares elements of `x` against a scalar threshold and packs the results, eight per output byte.
 *
 * For every output index `pos`, bit (7 - k) of output[pos] is set when x[8 * pos + k] > *threshold, i.e. the
 * first element of each group of eight ends up in the most significant bit.
 *
 * The loop advances by the total number of launched threads (blockDim.x * gridDim.x), so a 1-D launch of any
 * size covers all `output_num` bytes.
 *
 * @param x          Input values; indices [0, 8 * output_num) are read.
 * @param threshold  Pointer to the single threshold value.
 * @param output     Packed result; indices [0, output_num) are written.
 * @param output_num Number of output bytes to produce.
 */
template <typename T>
__global__ void CompareAndBitpack(const T *x, const T *threshold, uint8_t *output, const size_t output_num) {
  // Promote to size_t before multiplying: the built-in index variables are 32-bit, so the products would
  // otherwise be evaluated in 32-bit unsigned arithmetic and could wrap before being widened.
  const size_t stride = static_cast<size_t>(blockDim.x) * gridDim.x;
  const size_t first = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  // The threshold is the same for every comparison, so load it once per thread.
  const T limit = *threshold;
  for (size_t pos = first; pos < output_num; pos += stride) {
    const T *group = x + 8 * pos;
    unsigned int bits = 0;
#pragma unroll
    for (int k = 0; k < 8; ++k) {
      // Shift-then-OR leaves group[0] in bit 7 and group[7] in bit 0.
      bits = (bits << 1) | static_cast<unsigned int>(group[k] > limit);
    }
    output[pos] = static_cast<uint8_t>(bits);
  }
}
|
||||
|
||||
/**
 * Specialization for bool input: packs eight consecutive bool values of `x` into one output byte.
 *
 * Bit (7 - k) of output[pos] takes the value of x[8 * pos + k]. `threshold` is part of the signature but is
 * not read by this specialization.
 *
 * The loop advances by the total number of launched threads (blockDim.x * gridDim.x), so a 1-D launch of any
 * size covers all `output_num` bytes.
 */
template <>
__global__ void CompareAndBitpack<bool>(const bool *x, const bool *threshold,
                                        uint8_t *output, const size_t output_num) {
  // Promote to size_t before multiplying: the built-in index variables are 32-bit, so the products would
  // otherwise be evaluated in 32-bit unsigned arithmetic and could wrap before being widened.
  const size_t stride = static_cast<size_t>(blockDim.x) * gridDim.x;
  const size_t first = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  for (size_t pos = first; pos < output_num; pos += stride) {
    const bool *group = x + 8 * pos;
    unsigned int bits = 0;
#pragma unroll
    for (int k = 0; k < 8; ++k) {
      // Shift-then-OR leaves group[0] in bit 7 and group[7] in bit 0.
      bits = (bits << 1) | static_cast<unsigned int>(group[k]);
    }
    output[pos] = static_cast<uint8_t>(bits);
  }
}
|
||||
|
||||
/**
 * Host-side launcher for the CompareAndBitpack kernel.
 *
 * Enqueues the kernel on `cuda_stream` with the grid and block sizes given by CUDA_BLOCKS / CUDA_THREADS for
 * `device_id`. The call does not synchronize with the stream.
 *
 * @param x           Device pointer to the input values (8 * output_num elements are read).
 * @param threshold   Device pointer to the scalar threshold.
 * @param output      Device pointer receiving output_num packed bytes.
 * @param output_num  Number of output bytes.
 * @param device_id   Device whose launch limits are used to size the grid.
 * @param cuda_stream Stream on which the kernel is enqueued.
 */
template <typename T>
void CalCompareAndBitpack(const T *x, const T *threshold, uint8_t *output, const size_t output_num,
                          const uint32_t &device_id, cudaStream_t cuda_stream) {
  // Nothing to pack: skip the launch so the result does not depend on how CUDA_BLOCKS handles a zero count.
  if (output_num == 0) {
    return;
  }
  CompareAndBitpack<<<CUDA_BLOCKS(device_id, output_num), CUDA_THREADS(device_id), 0, cuda_stream>>>(
    x, threshold, output, output_num);
}
|
||||
|
||||
// Explicit instantiations of the launcher for every element type exported from this translation unit.
template CUDA_LIB_EXPORT void CalCompareAndBitpack<half>(const half *x, const half *threshold, uint8_t *output,
                                                         const size_t output_num, const uint32_t &device_id,
                                                         cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalCompareAndBitpack<float>(const float *x, const float *threshold, uint8_t *output,
                                                          const size_t output_num, const uint32_t &device_id,
                                                          cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalCompareAndBitpack<double>(const double *x, const double *threshold,
                                                           uint8_t *output, const size_t output_num,
                                                           const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalCompareAndBitpack<int8_t>(const int8_t *x, const int8_t *threshold,
                                                           uint8_t *output, const size_t output_num,
                                                           const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalCompareAndBitpack<int16_t>(const int16_t *x, const int16_t *threshold,
                                                            uint8_t *output, const size_t output_num,
                                                            const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalCompareAndBitpack<int32_t>(const int32_t *x, const int32_t *threshold,
                                                            uint8_t *output, const size_t output_num,
                                                            const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalCompareAndBitpack<int64_t>(const int64_t *x, const int64_t *threshold,
                                                            uint8_t *output, const size_t output_num,
                                                            const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalCompareAndBitpack<bool>(const bool *x, const bool *threshold, uint8_t *output,
                                                         const size_t output_num, const uint32_t &device_id,
                                                         cudaStream_t cuda_stream);
|
|
@ -0,0 +1,27 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_COMPARE_AND_BITPACK_IMPL_CUH_
|
||||
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_COMPARE_AND_BITPACK_IMPL_CUH_
|
||||
#include <curand_kernel.h>
|
||||
#include <vector>
|
||||
#include <random>
|
||||
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h"
|
||||
|
||||
/**
 * Launches the CompareAndBitpack kernel on `cuda_stream`.
 *
 * Packs the results of comparing `x` against `*threshold` into `output`, eight input elements per byte.
 *
 * @param x           Device pointer to the input values.
 * @param threshold   Device pointer to the scalar threshold.
 * @param output      Device pointer receiving `output_num` packed bytes.
 * @param output_num  Number of output bytes.
 * @param device_id   Device used to size the launch.
 * @param cuda_stream Stream on which the kernel is enqueued.
 */
template <typename T>
CUDA_LIB_EXPORT void CalCompareAndBitpack(const T *x, const T *threshold, uint8_t *output,
                                          const size_t output_num, const uint32_t &device_id,
                                          cudaStream_t cuda_stream);
|
||||
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_COMPARE_AND_BITPACK_IMPL_CUH_
|
|
@ -0,0 +1,130 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "plugin/device/gpu/kernel/math/compare_and_bitpack_gpu_kernel.h"
|
||||
#include <functional>
|
||||
#include <utility>
|
||||
#include <string>
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include "include/curand.h"
|
||||
#include "mindspore/core/ops/compareAndBitpack.h"
|
||||
#include "abstract/utils.h"
|
||||
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/compare_and_bitpack_impl.cuh"
|
||||
#include "include/common/utils/anfalgo.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
constexpr size_t kBitpack = 8;
|
||||
bool CompareAndBitpackGpuKernelMod::Init(const BaseOperatorPtr &base_operator,
|
||||
const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs) {
|
||||
kernel_name_ = base_operator->name();
|
||||
if (inputs.empty() || outputs.empty()) {
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_ << "', it got empty inputs or outputs, which is invalid.";
|
||||
return false;
|
||||
}
|
||||
kernel_ptr_ = std::make_shared<ops::CompareAndBitpack>(base_operator->GetPrim());
|
||||
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
|
||||
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
|
||||
if (!is_match) {
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_ << "', the kernel type should be in "
|
||||
<< "[int8, int16, int32, int64, float16, float32, float64, bool], but got: " << kernel_attr;
|
||||
return false;
|
||||
}
|
||||
kernel_func_ = func_list_[index].second;
|
||||
x_unit_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kIndex0).first);
|
||||
threshold_unit_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kIndex1).first);
|
||||
cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle();
|
||||
return true;
|
||||
}
|
||||
|
||||
int CompareAndBitpackGpuKernelMod::Resize(const BaseOperatorPtr &base_operator,
|
||||
const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs,
|
||||
const std::map<uint32_t, tensor::TensorPtr> &) {
|
||||
ResetResource();
|
||||
for (const auto &input : inputs) {
|
||||
auto input_shape = input->GetShapeVector();
|
||||
if (!IsValidShape(input_shape)) {
|
||||
return KRET_UNKNOWN_SHAPE;
|
||||
}
|
||||
}
|
||||
auto x_long_shape = inputs.at(kIndex0)->GetShapeVector();
|
||||
std::vector<size_t> x_shape;
|
||||
(void)std::transform(x_long_shape.begin(), x_long_shape.end(), std::back_inserter(x_shape), LongToSize);
|
||||
for (size_t i = 0; i < x_shape.size(); i++) {
|
||||
x_count_ *= x_shape[i];
|
||||
}
|
||||
y_count_ = x_count_ / kBitpack;
|
||||
size_t x_size = x_count_ * x_unit_size_;
|
||||
input_size_list_.emplace_back(x_size);
|
||||
size_t threshold_size = threshold_unit_size_;
|
||||
input_size_list_.emplace_back(threshold_size);
|
||||
size_t output_size = y_count_ * sizeof(uint8_t);
|
||||
output_size_list_.emplace_back(output_size);
|
||||
size_t workspace_size = 0;
|
||||
workspace_size_list_.emplace_back(workspace_size);
|
||||
return KRET_OK;
|
||||
}
|
||||
|
||||
// Restores the per-shape state to its defaults: clears the size lists, resets the element count of `x` to 1
// (the identity for the product computed in Resize) and clears the null-input flag.
void CompareAndBitpackGpuKernelMod::ResetResource() noexcept {
  workspace_size_list_.clear();
  output_size_list_.clear();
  input_size_list_.clear();
  x_count_ = 1;
  is_null_input_ = false;
}
|
||||
|
||||
template <typename T>
|
||||
bool CompareAndBitpackGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
|
||||
const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) {
|
||||
T *x = GetDeviceAddress<T>(inputs, kIndex0);
|
||||
T *threshold = GetDeviceAddress<T>(inputs, kIndex1);
|
||||
uint8_t *y = GetDeviceAddress<uint8_t>(outputs, kIndex0);
|
||||
CalCompareAndBitpack(x, threshold, y, y_count_, device_id_, reinterpret_cast<cudaStream_t>(cuda_stream_));
|
||||
return true;
|
||||
}
|
||||
|
||||
// Supported dtype signatures paired with the matching typed launch function.
// Both inputs share one dtype and the output is always uint8. Init selects an entry by the index returned from
// MatchKernelAttr, so the order here is the order reported by GetOpSupport.
std::vector<std::pair<KernelAttr, CompareAndBitpackGpuKernelMod::CompareAndBitpackFunc>>
  CompareAndBitpackGpuKernelMod::func_list_ = {
    // Signed integer inputs.
    {KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeUInt8),
     &CompareAndBitpackGpuKernelMod::LaunchKernel<int8_t>},
    {KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeUInt8),
     &CompareAndBitpackGpuKernelMod::LaunchKernel<int16_t>},
    {KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt8),
     &CompareAndBitpackGpuKernelMod::LaunchKernel<int32_t>},
    {KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt8),
     &CompareAndBitpackGpuKernelMod::LaunchKernel<int64_t>},
    // Floating-point inputs.
    {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeUInt8),
     &CompareAndBitpackGpuKernelMod::LaunchKernel<float>},
    {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeUInt8),
     &CompareAndBitpackGpuKernelMod::LaunchKernel<half>},
    {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeUInt8),
     &CompareAndBitpackGpuKernelMod::LaunchKernel<double>},
    // Boolean inputs.
    {KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeUInt8),
     &CompareAndBitpackGpuKernelMod::LaunchKernel<bool>}};
|
||||
|
||||
std::vector<KernelAttr> CompareAndBitpackGpuKernelMod::GetOpSupport() {
|
||||
std::vector<KernelAttr> support_list;
|
||||
(void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
|
||||
[](const std::pair<KernelAttr, CompareAndBitpackFunc> &pair) { return pair.first; });
|
||||
return support_list;
|
||||
}
|
||||
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, CompareAndBitpack, CompareAndBitpackGpuKernelMod);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,78 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_COMPARE_AND_BITPACK_GPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_COMPARE_AND_BITPACK_GPU_KERNEL_H_
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <map>
|
||||
#include "plugin/device/gpu/kernel/gpu_kernel.h"
|
||||
#include "plugin/factory/ms_factory.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
class CompareAndBitpackGpuKernelMod : public NativeGpuKernelMod {
|
||||
public:
|
||||
CompareAndBitpackGpuKernelMod() { ResetResource(); }
|
||||
~CompareAndBitpackGpuKernelMod() override = default;
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs, void *cuda_stream) override {
|
||||
if (is_null_input_) {
|
||||
return true;
|
||||
}
|
||||
cuda_stream_ = cuda_stream;
|
||||
return kernel_func_(this, inputs, workspace, outputs);
|
||||
}
|
||||
|
||||
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs) override;
|
||||
|
||||
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;
|
||||
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
void ResetResource() noexcept;
|
||||
|
||||
void CheckCompareAndBitpackShape();
|
||||
|
||||
template <typename T>
|
||||
bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs);
|
||||
using CompareAndBitpackFunc =
|
||||
std::function<bool(CompareAndBitpackGpuKernelMod *, const std::vector<kernel::AddressPtr> &,
|
||||
const std::vector<kernel::AddressPtr> &, const std::vector<kernel::AddressPtr> &)>;
|
||||
|
||||
private:
|
||||
size_t x_unit_size_{1};
|
||||
size_t threshold_unit_size_{1};
|
||||
bool is_null_input_{false};
|
||||
size_t x_count_{};
|
||||
size_t y_count_{};
|
||||
void *cuda_stream_{nullptr};
|
||||
BaseOperatorPtr kernel_ptr_{nullptr};
|
||||
cudnnHandle_t cudnn_handle_{};
|
||||
curandGenerator_t curand_generator_{nullptr};
|
||||
CompareAndBitpackFunc kernel_func_{};
|
||||
static std::vector<std::pair<KernelAttr, CompareAndBitpackFunc>> func_list_;
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_COMPARE_AND_BITPACK_GPU_KERNEL_H_
|
|
@ -45,10 +45,11 @@ abstract::ShapePtr CompareAndBitpackInferShape(const PrimitivePtr &primitive,
|
|||
(void)CheckAndConvertUtils::CheckInteger("x's rank'", x_rank, kNotEqual, kShapeSize_, primitive->name());
|
||||
|
||||
// check the innermost dimension of `x`'s shape is divisible by 8.
|
||||
(void)CheckAndConvertUtils::Check("x innermost dimension % 8", x_shape[x_rank - 1] % divisible_num, kEqual, 0,
|
||||
primitive->name());
|
||||
|
||||
ShapeVector out_shape;
|
||||
if (x_shape[x_rank - 1] != -1) {
|
||||
(void)CheckAndConvertUtils::Check("x innermost dimension % 8", x_shape[x_rank - 1] % divisible_num, kEqual, 0,
|
||||
primitive->name());
|
||||
}
|
||||
std::vector<int64_t> out_shape;
|
||||
for (int dim = 0; dim < x_rank - 1; dim = dim + 1) {
|
||||
(void)out_shape.emplace_back(x_shape[dim]);
|
||||
}
|
||||
|
|
|
@ -524,7 +524,7 @@ __all__ = [
|
|||
"Custom",
|
||||
"LuSolve",
|
||||
"CholeskyInverse",
|
||||
"Cummax",
|
||||
"Cummax"
|
||||
]
|
||||
|
||||
__sponge__ = [
|
||||
|
|
|
@ -6873,7 +6873,7 @@ class CompareAndBitpack(Primitive):
|
|||
ValueError: If the innermost dimension of `x`'s shape is not divisible by 8.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``CPU``
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> x = Tensor(np.array([1, 2, 3, 4, 5, 6, 7, 8]), mindspore.float32)
|
||||
|
|
|
@ -0,0 +1,54 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import mindspore.nn as nn
|
||||
import mindspore.context as context
|
||||
from mindspore import Tensor
|
||||
from mindspore.ops.operations.math_ops import CompareAndBitpack
|
||||
import mindspore.common.dtype as mstype
|
||||
|
||||
|
||||
class NetCompareAndBitpack(nn.Cell):
    """Minimal cell that applies the CompareAndBitpack primitive to its two inputs."""

    def __init__(self):
        super().__init__()
        self.compare_and_bitpack = CompareAndBitpack()

    def construct(self, x, threshold):
        """Return the result of CompareAndBitpack for `x` and `threshold`."""
        packed = self.compare_and_bitpack(x, threshold)
        return packed
|
||||
|
||||
|
||||
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_compare_and_bitpack_graph():
    """
    Feature: Compare and bitpack
    Description: test case for CompareAndBitpack of float16
    Expectation: The results are as expected
    """
    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
    # Only 7 and 8 exceed the threshold of 6, so the two lowest bits are set: 0b00000011 == 3.
    values = np.array([1, 2, 3, 4, 5, 6, 7, 8], dtype=np.float16)
    out_expect = np.array([3], dtype=np.uint8)

    net = NetCompareAndBitpack()
    output = net(Tensor(values), Tensor(6, dtype=mstype.float16))

    actual = output.asnumpy()
    assert np.all(actual == out_expect)
    assert output.shape == out_expect.shape
    assert actual.dtype == 'uint8'
|
Loading…
Reference in New Issue