forked from mindspore-Ecosystem/mindspore
add some gpu ops after removing akg
This commit is contained in:
parent
7c393c0375
commit
de63ee4690
|
@ -50,6 +50,75 @@ struct EqualFunc <float> {
|
|||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct GreaterEqualFunc {
|
||||
__device__ __host__ __forceinline__ bool operator()(const T &lhs, const T &rhs) { return lhs >= rhs ? true : false; }
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GreaterEqualFunc <half> {
|
||||
__device__ __host__ __forceinline__ bool operator()(const half &lhs, const half &rhs) {
|
||||
return std::abs(__half2float(lhs) - __half2float(rhs)) < 1e-9 ?
|
||||
true : (__half2float(lhs) > __half2float(rhs) ? true : false);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GreaterEqualFunc <float> {
|
||||
__device__ __host__ __forceinline__ bool operator()(const float &lhs, const float &rhs) {
|
||||
return std::abs(lhs - rhs) < 1e-9 ? true : (lhs > rhs ? true : false);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct LessEqualFunc {
|
||||
__device__ __host__ __forceinline__ bool operator()(const T &lhs, const T &rhs) { return lhs <= rhs ? true : false; }
|
||||
};
|
||||
|
||||
template <>
|
||||
struct LessEqualFunc <half> {
|
||||
__device__ __host__ __forceinline__ bool operator()(const half &lhs, const half &rhs) {
|
||||
return std::abs(__half2float(lhs) - __half2float(rhs)) < 1e-9 ?
|
||||
true : (__half2float(lhs) < __half2float(rhs) ? true : false);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct LessEqualFunc <float> {
|
||||
__device__ __host__ __forceinline__ bool operator()(const float &lhs, const float &rhs) {
|
||||
return std::abs(lhs - rhs) < 1e-9 ? true : (lhs < rhs ? true : false);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct NotEqualFunc {
|
||||
__device__ __host__ __forceinline__ bool operator()(const T &lhs, const T &rhs) { return lhs == rhs ? false : true; }
|
||||
};
|
||||
|
||||
template <>
|
||||
struct NotEqualFunc <half> {
|
||||
__device__ __host__ __forceinline__ bool operator()(const half &lhs, const half &rhs) {
|
||||
return std::abs(__half2float(lhs) - __half2float(rhs)) < 1e-9 ? false : true;
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct NotEqualFunc <float> {
|
||||
__device__ __host__ __forceinline__ bool operator()(const float &lhs, const float &rhs) {
|
||||
return std::abs(lhs - rhs) < 1e-9 ? false : true;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct LogicalAndFunc {
|
||||
__device__ __host__ __forceinline__ bool operator()(const T &lhs, const T &rhs) { return lhs && rhs; }
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct LogicalOrFunc {
|
||||
__device__ __host__ __forceinline__ bool operator()(const T &lhs, const T &rhs) { return lhs || rhs; }
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct MinimumFunc {
|
||||
__device__ __host__ __forceinline__ T operator()(const T &lhs, const T &rhs) { return lhs < rhs ? lhs : rhs; }
|
||||
|
@ -329,6 +398,16 @@ void ElewiseCmp(const int &nums, enum BroadcastOpType op, const T *x0, const T *
|
|||
return ElewiseCmpKernel<T, LessFunc<T>><<<(nums + 255) / 256, 256, 0, stream>>>(nums, x0, x1, y);
|
||||
case BROADCAST_TYPE_EQUAL:
|
||||
return ElewiseCmpKernel<T, EqualFunc<T>><<<(nums + 255) / 256, 256, 0, stream>>>(nums, x0, x1, y);
|
||||
case BROADCAST_TYPE_GREATER_EQUAL:
|
||||
return ElewiseCmpKernel<T, GreaterEqualFunc<T>><<<(nums + 255) / 256, 256, 0, stream>>>(nums, x0, x1, y);
|
||||
case BROADCAST_TYPE_LESS_EQUAL:
|
||||
return ElewiseCmpKernel<T, LessEqualFunc<T>><<<(nums + 255) / 256, 256, 0, stream>>>(nums, x0, x1, y);
|
||||
case BROADCAST_TYPE_NOT_EQUAL:
|
||||
return ElewiseCmpKernel<T, NotEqualFunc<T>><<<(nums + 255) / 256, 256, 0, stream>>>(nums, x0, x1, y);
|
||||
case BROADCAST_TYPE_LOGICAL_AND:
|
||||
return ElewiseCmpKernel<T, LogicalAndFunc<T>><<<(nums + 255) / 256, 256, 0, stream>>>(nums, x0, x1, y);
|
||||
case BROADCAST_TYPE_LOGICAL_OR:
|
||||
return ElewiseCmpKernel<T, LogicalOrFunc<T>><<<(nums + 255) / 256, 256, 0, stream>>>(nums, x0, x1, y);
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
@ -348,7 +427,10 @@ template void ElewiseCmp(const int &nums, enum BroadcastOpType op, const uint8_t
|
|||
cudaStream_t stream);
|
||||
template void ElewiseCmp(const int &nums, enum BroadcastOpType op, const int64_t *x0, const int64_t *x1, bool *y,
|
||||
cudaStream_t stream);
|
||||
|
||||
template void ElewiseCmp(const int &nums, enum BroadcastOpType op, const int16_t *x0, const int16_t *x1, bool *y,
|
||||
cudaStream_t stream);
|
||||
template void ElewiseCmp(const int &nums, enum BroadcastOpType op, const bool *x0, const bool *x1, bool *y,
|
||||
cudaStream_t stream);
|
||||
// Element-wise ArithMetic
|
||||
template <typename T, typename Func>
|
||||
__global__ void ElewiseArithKernel(const int nums, const T *x0, const T *x1, T *y) {
|
||||
|
@ -426,7 +508,10 @@ template void ElewiseArith(const int &nums, enum BroadcastOpType op, const uint8
|
|||
cudaStream_t stream);
|
||||
template void ElewiseArith(const int &nums, enum BroadcastOpType op, const int64_t *x0, const int64_t *x1, int64_t *y,
|
||||
cudaStream_t stream);
|
||||
|
||||
template void ElewiseArith(const int &nums, enum BroadcastOpType op, const int16_t *x0, const int16_t *x1, int16_t *y,
|
||||
cudaStream_t stream);
|
||||
template void ElewiseArith(const int &nums, enum BroadcastOpType op, const bool *x0, const bool *x1, bool *y,
|
||||
cudaStream_t stream);
|
||||
// Broadcast comparison
|
||||
__device__ __forceinline__ size_t Index(const size_t &index, const size_t &dim) { return dim == 1 ? 0 : index; }
|
||||
|
||||
|
@ -489,6 +574,31 @@ void BroadcastCmp(const std::vector<size_t> &x0_dims, const std::vector<size_t>
|
|||
x0_dims[0], x0_dims[1], x0_dims[2], x0_dims[3], x0_dims[4], x0_dims[5], x0_dims[6], x1_dims[0], x1_dims[1],
|
||||
x1_dims[2], x1_dims[3], x1_dims[4], x1_dims[5], x1_dims[6], y_dims[0], y_dims[1], y_dims[2], y_dims[3],
|
||||
y_dims[4], y_dims[5], y_dims[6], x0, x1, y);
|
||||
case BROADCAST_TYPE_GREATER_EQUAL:
|
||||
return BroadcastCmpKernel<T, GreaterEqualFunc<T>><<<(size + 255) / 256, 256, 0, stream>>>(
|
||||
x0_dims[0], x0_dims[1], x0_dims[2], x0_dims[3], x0_dims[4], x0_dims[5], x0_dims[6], x1_dims[0], x1_dims[1],
|
||||
x1_dims[2], x1_dims[3], x1_dims[4], x1_dims[5], x1_dims[6], y_dims[0], y_dims[1], y_dims[2], y_dims[3],
|
||||
y_dims[4], y_dims[5], y_dims[6], x0, x1, y);
|
||||
case BROADCAST_TYPE_LESS_EQUAL:
|
||||
return BroadcastCmpKernel<T, LessEqualFunc<T>><<<(size + 255) / 256, 256, 0, stream>>>(
|
||||
x0_dims[0], x0_dims[1], x0_dims[2], x0_dims[3], x0_dims[4], x0_dims[5], x0_dims[6], x1_dims[0], x1_dims[1],
|
||||
x1_dims[2], x1_dims[3], x1_dims[4], x1_dims[5], x1_dims[6], y_dims[0], y_dims[1], y_dims[2], y_dims[3],
|
||||
y_dims[4], y_dims[5], y_dims[6], x0, x1, y);
|
||||
case BROADCAST_TYPE_NOT_EQUAL:
|
||||
return BroadcastCmpKernel<T, NotEqualFunc<T>><<<(size + 255) / 256, 256, 0, stream>>>(
|
||||
x0_dims[0], x0_dims[1], x0_dims[2], x0_dims[3], x0_dims[4], x0_dims[5], x0_dims[6], x1_dims[0], x1_dims[1],
|
||||
x1_dims[2], x1_dims[3], x1_dims[4], x1_dims[5], x1_dims[6], y_dims[0], y_dims[1], y_dims[2], y_dims[3],
|
||||
y_dims[4], y_dims[5], y_dims[6], x0, x1, y);
|
||||
case BROADCAST_TYPE_LOGICAL_AND:
|
||||
return BroadcastCmpKernel<T, LogicalAndFunc<T>><<<(size + 255) / 256, 256, 0, stream>>>(
|
||||
x0_dims[0], x0_dims[1], x0_dims[2], x0_dims[3], x0_dims[4], x0_dims[5], x0_dims[6], x1_dims[0], x1_dims[1],
|
||||
x1_dims[2], x1_dims[3], x1_dims[4], x1_dims[5], x1_dims[6], y_dims[0], y_dims[1], y_dims[2], y_dims[3],
|
||||
y_dims[4], y_dims[5], y_dims[6], x0, x1, y);
|
||||
case BROADCAST_TYPE_LOGICAL_OR:
|
||||
return BroadcastCmpKernel<T, LogicalOrFunc<T>><<<(size + 255) / 256, 256, 0, stream>>>(
|
||||
x0_dims[0], x0_dims[1], x0_dims[2], x0_dims[3], x0_dims[4], x0_dims[5], x0_dims[6], x1_dims[0], x1_dims[1],
|
||||
x1_dims[2], x1_dims[3], x1_dims[4], x1_dims[5], x1_dims[6], y_dims[0], y_dims[1], y_dims[2], y_dims[3],
|
||||
y_dims[4], y_dims[5], y_dims[6], x0, x1, y);
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
@ -515,7 +625,12 @@ template void BroadcastCmp(const std::vector<size_t> &x0_dims, const std::vector
|
|||
template void BroadcastCmp(const std::vector<size_t> &x0_dims, const std::vector<size_t> &x1_dims,
|
||||
const std::vector<size_t> &y_dims, enum BroadcastOpType op, const int64_t *x0,
|
||||
const int64_t *x1, bool *y, cudaStream_t stream);
|
||||
|
||||
template void BroadcastCmp(const std::vector<size_t> &x0_dims, const std::vector<size_t> &x1_dims,
|
||||
const std::vector<size_t> &y_dims, enum BroadcastOpType op, const int16_t *x0,
|
||||
const int16_t *x1, bool *y, cudaStream_t stream);
|
||||
template void BroadcastCmp(const std::vector<size_t> &x0_dims, const std::vector<size_t> &x1_dims,
|
||||
const std::vector<size_t> &y_dims, enum BroadcastOpType op, const bool *x0,
|
||||
const bool *x1, bool *y, cudaStream_t stream);
|
||||
// Broadcast Arithmetic
|
||||
template <typename T, typename Func>
|
||||
__global__ void BroadcastArithKernel(const size_t l0, const size_t l1, const size_t l2, const size_t l3,
|
||||
|
@ -662,7 +777,12 @@ template void BroadcastArith(const std::vector<size_t> &x0_dims, const std::vect
|
|||
template void BroadcastArith(const std::vector<size_t> &x0_dims, const std::vector<size_t> &x1_dims,
|
||||
const std::vector<size_t> &y_dims, enum BroadcastOpType op, const int64_t *x0,
|
||||
const int64_t *x1, int64_t *y, cudaStream_t stream);
|
||||
|
||||
template void BroadcastArith(const std::vector<size_t> &x0_dims, const std::vector<size_t> &x1_dims,
|
||||
const std::vector<size_t> &y_dims, enum BroadcastOpType op, const int16_t *x0,
|
||||
const int16_t *x1, int16_t *y, cudaStream_t stream);
|
||||
template void BroadcastArith(const std::vector<size_t> &x0_dims, const std::vector<size_t> &x1_dims,
|
||||
const std::vector<size_t> &y_dims, enum BroadcastOpType op, const bool *x0,
|
||||
const bool *x1, bool *y, cudaStream_t stream);
|
||||
// BroadcastTo
|
||||
template <typename T>
|
||||
__global__ void BroadcastToKernel(const size_t i0, const size_t i1, const size_t i2, const size_t i3, const size_t o0,
|
||||
|
|
|
@ -41,6 +41,11 @@ enum BroadcastOpType {
|
|||
BROADCAST_TYPE_MOD = 15,
|
||||
BROADCAST_TYPE_FLOORMOD = 16,
|
||||
BROADCAST_TYPE_ATAN2 = 17,
|
||||
BROADCAST_TYPE_GREATER_EQUAL = 18,
|
||||
BROADCAST_TYPE_LESS_EQUAL = 19,
|
||||
BROADCAST_TYPE_NOT_EQUAL = 20,
|
||||
BROADCAST_TYPE_LOGICAL_AND = 21,
|
||||
BROADCAST_TYPE_LOGICAL_OR = 22,
|
||||
BROADCAST_TYPE_INVALID = 0xffffffff,
|
||||
};
|
||||
|
||||
|
|
|
@ -0,0 +1,40 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <vector>
|
||||
#include <iostream>
|
||||
|
||||
#include "backend/kernel_compiler/gpu/cuda_impl/logical_not_impl.cuh"
|
||||
#include "runtime/device/gpu/cuda_common.h"
|
||||
|
||||
template <typename T>
|
||||
struct LogicalNotFunc {
|
||||
__device__ __host__ __forceinline__ bool operator()(const T &x) { return !x; }
|
||||
};
|
||||
|
||||
template <typename T, typename Func>
|
||||
__global__ void LogicalNotKernel(const int nums, const T *x, bool *y) {
|
||||
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < nums; pos += blockDim.x * gridDim.x) {
|
||||
y[pos] = Func()(x[pos]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void LogicalNotImpl(const int &nums, const T *x, bool *y, cudaStream_t stream) {
|
||||
return LogicalNotKernel<T, LogicalNotFunc<T>><<<(nums + 255) / 256, 256, 0, stream>>>(nums, x, y);
|
||||
}
|
||||
|
||||
template void LogicalNotImpl(const int &nums, const bool *x, bool *y, cudaStream_t stream);
|
|
@ -0,0 +1,28 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_LOGICAL_NOT_H_
|
||||
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_LOGICAL_NOT_H_
|
||||
|
||||
#include <vector>
|
||||
#include "runtime/device/gpu/cuda_common.h"
|
||||
|
||||
|
||||
|
||||
template <typename T>
|
||||
void LogicalNotImpl(const int &nums, const T *x, bool *y, cudaStream_t stream);
|
||||
|
||||
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_LOGICAL_NOT_H_
|
|
@ -64,6 +64,17 @@ MS_REG_GPU_KERNEL_ONE(
|
|||
Atan2,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
BroadcastOpGpuKernel, double)
|
||||
MS_REG_GPU_KERNEL_ONE(
|
||||
Equal, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeBool),
|
||||
BroadcastOpGpuKernel, double)
|
||||
MS_REG_GPU_KERNEL_ONE(
|
||||
GreaterEqual,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeBool),
|
||||
BroadcastOpGpuKernel, double)
|
||||
MS_REG_GPU_KERNEL_ONE(
|
||||
LessEqual,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeBool),
|
||||
BroadcastOpGpuKernel, double)
|
||||
|
||||
// fp32
|
||||
MS_REG_GPU_KERNEL_ONE(
|
||||
|
@ -126,6 +137,18 @@ MS_REG_GPU_KERNEL_ONE(
|
|||
Atan2,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
BroadcastOpGpuKernel, float)
|
||||
MS_REG_GPU_KERNEL_ONE(
|
||||
GreaterEqual,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool),
|
||||
BroadcastOpGpuKernel, float)
|
||||
MS_REG_GPU_KERNEL_ONE(
|
||||
LessEqual,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool),
|
||||
BroadcastOpGpuKernel, float)
|
||||
MS_REG_GPU_KERNEL_ONE(
|
||||
NotEqual,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool),
|
||||
BroadcastOpGpuKernel, float)
|
||||
|
||||
// fp16
|
||||
MS_REG_GPU_KERNEL_ONE(
|
||||
|
@ -188,6 +211,18 @@ MS_REG_GPU_KERNEL_ONE(
|
|||
Atan2,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
||||
BroadcastOpGpuKernel, half)
|
||||
MS_REG_GPU_KERNEL_ONE(
|
||||
GreaterEqual,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool),
|
||||
BroadcastOpGpuKernel, half)
|
||||
MS_REG_GPU_KERNEL_ONE(
|
||||
LessEqual,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool),
|
||||
BroadcastOpGpuKernel, half)
|
||||
MS_REG_GPU_KERNEL_ONE(
|
||||
NotEqual,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool),
|
||||
BroadcastOpGpuKernel, half)
|
||||
|
||||
// int32
|
||||
MS_REG_GPU_KERNEL_ONE(
|
||||
|
@ -235,6 +270,16 @@ MS_REG_GPU_KERNEL_ONE(
|
|||
MS_REG_GPU_KERNEL_ONE(
|
||||
FloorMod, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
BroadcastOpGpuKernel, int)
|
||||
MS_REG_GPU_KERNEL_ONE(
|
||||
GreaterEqual,
|
||||
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
|
||||
BroadcastOpGpuKernel, int)
|
||||
MS_REG_GPU_KERNEL_ONE(
|
||||
LessEqual, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
|
||||
BroadcastOpGpuKernel, int)
|
||||
MS_REG_GPU_KERNEL_ONE(
|
||||
NotEqual, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
|
||||
BroadcastOpGpuKernel, int)
|
||||
|
||||
// int64
|
||||
MS_REG_GPU_KERNEL_ONE(
|
||||
|
@ -279,6 +324,16 @@ MS_REG_GPU_KERNEL_ONE(
|
|||
MS_REG_GPU_KERNEL_ONE(
|
||||
FloorMod, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
||||
BroadcastOpGpuKernel, int64_t)
|
||||
MS_REG_GPU_KERNEL_ONE(
|
||||
GreaterEqual,
|
||||
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool),
|
||||
BroadcastOpGpuKernel, int64_t)
|
||||
MS_REG_GPU_KERNEL_ONE(
|
||||
LessEqual, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool),
|
||||
BroadcastOpGpuKernel, int64_t)
|
||||
MS_REG_GPU_KERNEL_ONE(
|
||||
NotEqual, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool),
|
||||
BroadcastOpGpuKernel, int64_t)
|
||||
|
||||
// int8
|
||||
MS_REG_GPU_KERNEL_ONE(
|
||||
|
@ -287,6 +342,12 @@ MS_REG_GPU_KERNEL_ONE(
|
|||
MS_REG_GPU_KERNEL_ONE(
|
||||
Equal, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeBool),
|
||||
BroadcastOpGpuKernel, int8_t)
|
||||
MS_REG_GPU_KERNEL_ONE(
|
||||
GreaterEqual, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeBool),
|
||||
BroadcastOpGpuKernel, int8_t)
|
||||
MS_REG_GPU_KERNEL_ONE(
|
||||
LessEqual, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeBool),
|
||||
BroadcastOpGpuKernel, int8_t)
|
||||
|
||||
// uint8
|
||||
MS_REG_GPU_KERNEL_ONE(
|
||||
|
@ -295,5 +356,44 @@ MS_REG_GPU_KERNEL_ONE(
|
|||
MS_REG_GPU_KERNEL_ONE(
|
||||
Equal, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeBool),
|
||||
BroadcastOpGpuKernel, uint8_t)
|
||||
MS_REG_GPU_KERNEL_ONE(
|
||||
GreaterEqual,
|
||||
KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeBool),
|
||||
BroadcastOpGpuKernel, uint8_t)
|
||||
MS_REG_GPU_KERNEL_ONE(
|
||||
LessEqual, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeBool),
|
||||
BroadcastOpGpuKernel, uint8_t)
|
||||
MS_REG_GPU_KERNEL_ONE(
|
||||
NotEqual, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeBool),
|
||||
BroadcastOpGpuKernel, uint8_t)
|
||||
|
||||
// int16
|
||||
MS_REG_GPU_KERNEL_ONE(
|
||||
Equal, KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeBool),
|
||||
BroadcastOpGpuKernel, int16_t)
|
||||
MS_REG_GPU_KERNEL_ONE(
|
||||
NotEqual, KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeBool),
|
||||
BroadcastOpGpuKernel, int16_t)
|
||||
MS_REG_GPU_KERNEL_ONE(
|
||||
GreaterEqual,
|
||||
KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeBool),
|
||||
BroadcastOpGpuKernel, int16_t)
|
||||
MS_REG_GPU_KERNEL_ONE(
|
||||
LessEqual, KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeBool),
|
||||
BroadcastOpGpuKernel, int16_t)
|
||||
|
||||
// bool
|
||||
MS_REG_GPU_KERNEL_ONE(
|
||||
Equal, KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
|
||||
BroadcastOpGpuKernel, bool)
|
||||
MS_REG_GPU_KERNEL_ONE(
|
||||
NotEqual, KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
|
||||
BroadcastOpGpuKernel, bool)
|
||||
MS_REG_GPU_KERNEL_ONE(
|
||||
LogicalAnd, KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
|
||||
BroadcastOpGpuKernel, bool)
|
||||
MS_REG_GPU_KERNEL_ONE(
|
||||
LogicalOr, KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
|
||||
BroadcastOpGpuKernel, bool)
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -133,6 +133,11 @@ class BroadcastOpGpuKernel : public GpuKernel {
|
|||
{"Greater", BROADCAST_TYPE_GREATER},
|
||||
{"Less", BROADCAST_TYPE_LESS},
|
||||
{"Equal", BROADCAST_TYPE_EQUAL},
|
||||
{"GreaterEqual", BROADCAST_TYPE_GREATER_EQUAL},
|
||||
{"LessEqual", BROADCAST_TYPE_LESS_EQUAL},
|
||||
{"NotEqual", BROADCAST_TYPE_NOT_EQUAL},
|
||||
{"LogicalAnd", BROADCAST_TYPE_LOGICAL_AND},
|
||||
{"LogicalOr", BROADCAST_TYPE_LOGICAL_OR},
|
||||
};
|
||||
|
||||
auto iter = kBroadcastCmpTypeMap.find(kernel_name);
|
||||
|
|
|
@ -0,0 +1,23 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "backend/kernel_compiler/gpu/math/logical_not_gpu_kernel.h"
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
MS_REG_GPU_KERNEL_ONE(LogicalNot, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
|
||||
LogicalNotGpuKernel, bool)
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,78 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LOGICAL_NOT_GPU_KERNEL_H
|
||||
#define MINDSPORE_LOGICAL_NOT_GPU_KERNEL_H
|
||||
#include <cublas_v2.h>
|
||||
#include <cuda_runtime_api.h>
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include <functional>
|
||||
#include "backend/kernel_compiler/gpu/cuda_impl/logical_not_impl.cuh"
|
||||
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
|
||||
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
|
||||
#include "utils/convert_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
template <typename T>
|
||||
class LogicalNotGpuKernel : public GpuKernel {
|
||||
public:
|
||||
LogicalNotGpuKernel() { ResetResource(); }
|
||||
~LogicalNotGpuKernel() override = default;
|
||||
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
|
||||
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
|
||||
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
|
||||
auto input_addr = GetDeviceAddress<T>(inputs, 0);
|
||||
auto output_addr = GetDeviceAddress<bool>(outputs, 0);
|
||||
LogicalNotImpl(input_num_, input_addr, output_addr, reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
return true;
|
||||
}
|
||||
|
||||
bool Init(const CNodePtr &kernel_node) override {
|
||||
kernel_node_ = kernel_node;
|
||||
auto input_shape = AnfAlgo::GetInputRealDeviceShapeIfExist(kernel_node, 0);
|
||||
input_num_ = std::accumulate(input_shape.begin(), input_shape.end(), 1, std::multiplies<size_t>());
|
||||
InitSizeLists();
|
||||
return true;
|
||||
}
|
||||
|
||||
void ResetResource() noexcept override {
|
||||
input_num_ = 1;
|
||||
input_size_list_.clear();
|
||||
output_size_list_.clear();
|
||||
workspace_size_list_.clear();
|
||||
}
|
||||
|
||||
protected:
|
||||
void InitSizeLists() override {
|
||||
input_size_list_.push_back(input_num_ * sizeof(T));
|
||||
output_size_list_.push_back(input_num_ * sizeof(T));
|
||||
}
|
||||
|
||||
private:
|
||||
size_t input_num_;
|
||||
std::vector<size_t> input_size_list_;
|
||||
std::vector<size_t> output_size_list_;
|
||||
std::vector<size_t> workspace_size_list_;
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
#endif
|
|
@ -31,6 +31,7 @@ class NetEqual(Cell):
|
|||
def construct(self, x, y):
|
||||
return self.Equal(x, y)
|
||||
|
||||
|
||||
class NetEqualDynamic(Cell):
|
||||
def __init__(self):
|
||||
super(NetEqualDynamic, self).__init__()
|
||||
|
@ -42,6 +43,7 @@ class NetEqualDynamic(Cell):
|
|||
y_conv = self.conv(y)
|
||||
return self.Equal(x_conv, y_conv)
|
||||
|
||||
|
||||
class NetNotEqual(Cell):
|
||||
def __init__(self):
|
||||
super(NetNotEqual, self).__init__()
|
||||
|
@ -50,6 +52,7 @@ class NetNotEqual(Cell):
|
|||
def construct(self, x, y):
|
||||
return self.NotEqual(x, y)
|
||||
|
||||
|
||||
class NetGreaterEqual(Cell):
|
||||
def __init__(self):
|
||||
super(NetGreaterEqual, self).__init__()
|
||||
|
@ -69,12 +72,12 @@ def test_equal():
|
|||
expect0 = np.equal(x0_np, y0_np)
|
||||
x1_np = np.array([0, 1, 3]).astype(np.float32)
|
||||
x1 = Tensor(x1_np)
|
||||
y1_np = np.array([0, 1, -3]).astype(np.float32)
|
||||
y1_np = np.array([0]).astype(np.float32)
|
||||
y1 = Tensor(y1_np)
|
||||
expect1 = np.equal(x1_np, y1_np)
|
||||
x2_np = np.array([0, 1, 3]).astype(np.int32)
|
||||
x2 = Tensor(x2_np)
|
||||
y2_np = np.array([0, 1, -3]).astype(np.int32)
|
||||
y2_np = np.array([0]).astype(np.int32)
|
||||
y2 = Tensor(y2_np)
|
||||
expect2 = np.equal(x2_np, y2_np)
|
||||
x3_np = np.array([0, 1, 3]).astype(np.int16)
|
||||
|
@ -93,74 +96,45 @@ def test_equal():
|
|||
y5 = Tensor(y5_np)
|
||||
expect5 = np.equal(x5_np, y5_np)
|
||||
x6_np = np.array([0, 1, 4]).astype(np.int8)
|
||||
x6 = Tensor(x4_np)
|
||||
x6 = Tensor(x6_np)
|
||||
y6_np = np.array([0, 1, 3]).astype(np.int8)
|
||||
y6 = Tensor(y4_np)
|
||||
y6 = Tensor(y6_np)
|
||||
expect6 = np.equal(x6_np, y6_np)
|
||||
x7_np = np.array([0, 1, 4]).astype(np.int64)
|
||||
x7 = Tensor(x4_np)
|
||||
x7 = Tensor(x7_np)
|
||||
y7_np = np.array([0, 1, 3]).astype(np.int64)
|
||||
y7 = Tensor(y4_np)
|
||||
y7 = Tensor(y7_np)
|
||||
expect7 = np.equal(x7_np, y7_np)
|
||||
x8_np = np.array([0, 1, 4]).astype(np.float16)
|
||||
x8 = Tensor(x4_np)
|
||||
x8 = Tensor(x8_np)
|
||||
y8_np = np.array([0, 1, 3]).astype(np.float16)
|
||||
y8 = Tensor(y4_np)
|
||||
y8 = Tensor(y8_np)
|
||||
expect8 = np.equal(x8_np, y8_np)
|
||||
x9_np = np.array([0, 1, 4]).astype(np.float64)
|
||||
x9 = Tensor(x9_np)
|
||||
y9_np = np.array([0, 1, 3]).astype(np.float64)
|
||||
y9 = Tensor(y9_np)
|
||||
expect9 = np.equal(x9_np, y9_np)
|
||||
|
||||
x = [x0, x1, x2, x3, x4, x5, x6, x7, x8, x9]
|
||||
y = [y0, y1, y2, y3, y4, y5, y6, y7, y8, y9]
|
||||
expect = [expect0, expect1, expect2, expect3, expect4, expect5, expect6, expect7, expect8, expect9]
|
||||
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
|
||||
equal = NetEqual()
|
||||
output0 = equal(x0, y0)
|
||||
assert np.all(output0.asnumpy() == expect0)
|
||||
assert output0.shape == expect0.shape
|
||||
output1 = equal(x1, y1)
|
||||
assert np.all(output1.asnumpy() == expect1)
|
||||
assert output1.shape == expect1.shape
|
||||
output2 = equal(x2, y2)
|
||||
assert np.all(output2.asnumpy() == expect2)
|
||||
assert output2.shape == expect2.shape
|
||||
output3 = equal(x3, y3)
|
||||
assert np.all(output3.asnumpy() == expect3)
|
||||
assert output3.shape == expect3.shape
|
||||
output4 = equal(x4, y4)
|
||||
assert np.all(output4.asnumpy() == expect4)
|
||||
assert output4.shape == expect4.shape
|
||||
output5 = equal(x5, y5)
|
||||
assert np.all(output5.asnumpy() == expect5)
|
||||
assert output5.shape == expect5.shape
|
||||
|
||||
|
||||
for i, xi in enumerate(x):
|
||||
output = equal(xi, y[i])
|
||||
assert np.all(output.asnumpy() == expect[i])
|
||||
assert output.shape == expect[i].shape
|
||||
print('test [%d/%d] passed!' % (i, len(x)))
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
equal = NetEqual()
|
||||
output0 = equal(x0, y0)
|
||||
assert np.all(output0.asnumpy() == expect0)
|
||||
assert output0.shape == expect0.shape
|
||||
output1 = equal(x1, y1)
|
||||
assert np.all(output1.asnumpy() == expect1)
|
||||
assert output1.shape == expect1.shape
|
||||
output2 = equal(x2, y2)
|
||||
assert np.all(output2.asnumpy() == expect2)
|
||||
assert output2.shape == expect2.shape
|
||||
output3 = equal(x3, y3)
|
||||
assert np.all(output3.asnumpy() == expect3)
|
||||
assert output3.shape == expect3.shape
|
||||
output4 = equal(x4, y4)
|
||||
assert np.all(output4.asnumpy() == expect4)
|
||||
assert output4.shape == expect4.shape
|
||||
output5 = equal(x5, y5)
|
||||
assert np.all(output5.asnumpy() == expect5)
|
||||
assert output5.shape == expect5.shape
|
||||
output6 = equal(x6, y6)
|
||||
assert np.all(output6.asnumpy() == expect6)
|
||||
assert output6.shape == expect6.shape
|
||||
output7 = equal(x7, y7)
|
||||
assert np.all(output7.asnumpy() == expect7)
|
||||
assert output7.shape == expect7.shape
|
||||
output8 = equal(x8, y8)
|
||||
assert np.all(output8.asnumpy() == expect8)
|
||||
assert output8.shape == expect8.shape
|
||||
|
||||
for i, xi in enumerate(x):
|
||||
output = equal(xi, y[i])
|
||||
assert np.all(output.asnumpy() == expect[i])
|
||||
assert output.shape == expect[i].shape
|
||||
print('test [%d/%d] passed!' % (i, len(x)))
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
|
@ -178,44 +152,42 @@ def test_notequal():
|
|||
x3 = Tensor(np.array([[False, True], [True, False]]).astype(bool))
|
||||
y3 = Tensor(np.array([[True, False]]).astype(bool))
|
||||
expect3 = np.array([[True, True], [False, False]])
|
||||
x4 = Tensor(np.array([[1.2, 1], [1, 0]]).astype(np.float16))
|
||||
y4 = Tensor(np.array([[1, 2]]).astype(np.float16))
|
||||
expect4 = np.array([[True, True], [False, True]])
|
||||
x5 = Tensor(np.array([[2, 1], [1, 0]]).astype(np.int64))
|
||||
y5 = Tensor(np.array([[1, 2]]).astype(np.int64))
|
||||
expect5 = np.array([[True, True], [False, True]])
|
||||
x6 = Tensor(np.array([[2, 1], [1, 0]]).astype(np.int32))
|
||||
y6 = Tensor(np.array([[1, 2], [1, 2]]).astype(np.int32))
|
||||
expect6 = np.array([[True, True], [False, True]])
|
||||
|
||||
x = [x0, x1, x2, x3, x4, x5, x6]
|
||||
y = [y0, y1, y2, y3, y4, y5, y6]
|
||||
expect = [expect0, expect1, expect2, expect3, expect4, expect5, expect6]
|
||||
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
|
||||
notequal = NetNotEqual()
|
||||
output0 = notequal(x0, y0)
|
||||
assert np.all(output0.asnumpy() == expect0)
|
||||
assert output0.shape == expect0.shape
|
||||
output1 = notequal(x1, y1)
|
||||
assert np.all(output1.asnumpy() == expect1)
|
||||
assert output1.shape == expect1.shape
|
||||
output2 = notequal(x2, y2)
|
||||
assert np.all(output2.asnumpy() == expect2)
|
||||
assert output2.shape == expect2.shape
|
||||
output3 = notequal(x3, y3)
|
||||
assert np.all(output3.asnumpy() == expect3)
|
||||
assert output3.shape == expect3.shape
|
||||
for i, xi in enumerate(x):
|
||||
output = notequal(xi, y[i])
|
||||
assert np.all(output.asnumpy() == expect[i])
|
||||
assert output.shape == expect[i].shape
|
||||
print('test [%d/%d] passed!' % (i, len(x)))
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
notequal = NetNotEqual()
|
||||
output0 = notequal(x0, y0)
|
||||
assert np.all(output0.asnumpy() == expect0)
|
||||
assert output0.shape == expect0.shape
|
||||
output1 = notequal(x1, y1)
|
||||
assert np.all(output1.asnumpy() == expect1)
|
||||
assert output1.shape == expect1.shape
|
||||
output2 = notequal(x2, y2)
|
||||
assert np.all(output2.asnumpy() == expect2)
|
||||
assert output2.shape == expect2.shape
|
||||
output3 = notequal(x3, y3)
|
||||
assert np.all(output3.asnumpy() == expect3)
|
||||
assert output3.shape == expect3.shape
|
||||
|
||||
for i, xi in enumerate(x):
|
||||
output = notequal(xi, y[i])
|
||||
assert np.all(output.asnumpy() == expect[i])
|
||||
assert output.shape == expect[i].shape
|
||||
print('test [%d/%d] passed!' % (i, len(x)))
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_greaterqual():
|
||||
x0 = Tensor(np.array([[1.2, 1], [1, 0]]).astype(np.float32))
|
||||
y0 = Tensor(np.array([[1, 2]]).astype(np.float32))
|
||||
y0 = Tensor(np.array([[1, 2], [1, 2]]).astype(np.float32))
|
||||
expect0 = np.array([[True, False], [True, False]])
|
||||
x1 = Tensor(np.array([[2, 1], [1, 1]]).astype(np.int16))
|
||||
y1 = Tensor(np.array([[1, 2]]).astype(np.int16))
|
||||
|
@ -224,29 +196,41 @@ def test_greaterqual():
|
|||
y2 = Tensor(np.array([[1, 2]]).astype(np.uint8))
|
||||
expect2 = np.array([[True, False], [True, True]])
|
||||
|
||||
x3 = Tensor(np.array([[2, 1], [1, 2]]).astype(np.float64))
|
||||
y3 = Tensor(np.array([[1, 2]]).astype(np.float64))
|
||||
expect3 = np.array([[True, False], [True, True]])
|
||||
x4 = Tensor(np.array([[2, 1], [1, 2]]).astype(np.float16))
|
||||
y4 = Tensor(np.array([[1, 2]]).astype(np.float16))
|
||||
expect4 = np.array([[True, False], [True, True]])
|
||||
x5 = Tensor(np.array([[2, 1], [1, 1]]).astype(np.int64))
|
||||
y5 = Tensor(np.array([[1, 2]]).astype(np.int64))
|
||||
expect5 = np.array([[True, False], [True, False]])
|
||||
x6 = Tensor(np.array([[2, 1], [1, 1]]).astype(np.int32))
|
||||
y6 = Tensor(np.array([[1, 2]]).astype(np.int32))
|
||||
expect6 = np.array([[True, False], [True, False]])
|
||||
x7 = Tensor(np.array([[2, 1], [1, 1]]).astype(np.int8))
|
||||
y7 = Tensor(np.array([[1, 2]]).astype(np.int8))
|
||||
expect7 = np.array([[True, False], [True, False]])
|
||||
|
||||
x = [x0, x1, x2, x3, x4, x5, x6, x7]
|
||||
y = [y0, y1, y2, y3, y4, y5, y6, y7]
|
||||
expect = [expect0, expect1, expect2, expect3, expect4, expect5, expect6, expect7]
|
||||
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
|
||||
gequal = NetGreaterEqual()
|
||||
output0 = gequal(x0, y0)
|
||||
assert np.all(output0.asnumpy() == expect0)
|
||||
assert output0.shape == expect0.shape
|
||||
output1 = gequal(x1, y1)
|
||||
assert np.all(output1.asnumpy() == expect1)
|
||||
assert output1.shape == expect1.shape
|
||||
output2 = gequal(x2, y2)
|
||||
assert np.all(output2.asnumpy() == expect2)
|
||||
assert output2.shape == expect2.shape
|
||||
for i, xi in enumerate(x):
|
||||
output = gequal(xi, y[i])
|
||||
assert np.all(output.asnumpy() == expect[i])
|
||||
assert output.shape == expect[i].shape
|
||||
print('test [%d/%d] passed!' % (i, len(x)))
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
gequal = NetGreaterEqual()
|
||||
output0 = gequal(x0, y0)
|
||||
assert np.all(output0.asnumpy() == expect0)
|
||||
assert output0.shape == expect0.shape
|
||||
output1 = gequal(x1, y1)
|
||||
assert np.all(output1.asnumpy() == expect1)
|
||||
assert output1.shape == expect1.shape
|
||||
output2 = gequal(x2, y2)
|
||||
assert np.all(output2.asnumpy() == expect2)
|
||||
assert output2.shape == expect2.shape
|
||||
for i, xi in enumerate(x):
|
||||
output = gequal(xi, y[i])
|
||||
assert np.all(output.asnumpy() == expect[i])
|
||||
assert output.shape == expect[i].shape
|
||||
print('test [%d/%d] passed!' % (i, len(x)))
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
|
|
|
@ -36,29 +36,46 @@ class Net(Cell):
|
|||
@pytest.mark.env_onecard
|
||||
def test_lessequal():
|
||||
x = Tensor(np.array([[1, 2, 3]]).astype(np.float32))
|
||||
y = Tensor(np.array([[2]]).astype(np.float32))
|
||||
expect = [[True, True, False]]
|
||||
y = Tensor(np.array([[2, 2, 2]]).astype(np.float32))
|
||||
expect = np.array([[True, True, False]])
|
||||
x1 = Tensor(np.array([[1, 2, 3]]).astype(np.int16))
|
||||
y1 = Tensor(np.array([[2]]).astype(np.int16))
|
||||
expect = [[True, True, False]]
|
||||
expect1 = np.array([[True, True, False]])
|
||||
x2 = Tensor(np.array([[1, 2, 3]]).astype(np.uint8))
|
||||
y2 = Tensor(np.array([[2]]).astype(np.uint8))
|
||||
expect = [[True, True, False]]
|
||||
expect2 = np.array([[True, True, False]])
|
||||
x3 = Tensor(np.array([[1, 2, 3]]).astype(np.float64))
|
||||
y3 = Tensor(np.array([[2]]).astype(np.float64))
|
||||
expect3 = np.array([[True, True, False]])
|
||||
x4 = Tensor(np.array([[1, 2, 3]]).astype(np.float16))
|
||||
y4 = Tensor(np.array([[2]]).astype(np.float16))
|
||||
expect4 = np.array([[True, True, False]])
|
||||
x5 = Tensor(np.array([[1, 2, 3]]).astype(np.int64))
|
||||
y5 = Tensor(np.array([[2]]).astype(np.int64))
|
||||
expect5 = np.array([[True, True, False]])
|
||||
x6 = Tensor(np.array([[1, 2, 3]]).astype(np.int32))
|
||||
y6 = Tensor(np.array([[2, 2, 2]]).astype(np.int32))
|
||||
expect6 = np.array([[True, True, False]])
|
||||
x7 = Tensor(np.array([[1, 2, 3]]).astype(np.int8))
|
||||
y7 = Tensor(np.array([[2]]).astype(np.int8))
|
||||
expect7 = np.array([[True, True, False]])
|
||||
|
||||
x = [x, x1, x2, x3, x4, x5, x6, x7]
|
||||
y = [y, y1, y2, y3, y4, y5, y6, y7]
|
||||
expect = [expect, expect1, expect2, expect3, expect4, expect5, expect6, expect7]
|
||||
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
|
||||
lessequal = Net()
|
||||
output = lessequal(x, y)
|
||||
assert np.all(output.asnumpy() == expect)
|
||||
output = lessequal(x1, y1)
|
||||
assert np.all(output.asnumpy() == expect)
|
||||
output = lessequal(x2, y2)
|
||||
assert np.all(output.asnumpy() == expect)
|
||||
for i, xi in enumerate(x):
|
||||
output = lessequal(xi, y[i])
|
||||
assert np.all(output.asnumpy() == expect[i])
|
||||
assert output.shape == expect[i].shape
|
||||
print('test [%d/%d] passed!' % (i, len(x)))
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
lessequal = Net()
|
||||
output = lessequal(x, y)
|
||||
assert np.all(output.asnumpy() == expect)
|
||||
output = lessequal(x1, y1)
|
||||
assert np.all(output.asnumpy() == expect)
|
||||
output = lessequal(x2, y2)
|
||||
assert np.all(output.asnumpy() == expect)
|
||||
for i, xi in enumerate(x):
|
||||
output = lessequal(xi, y[i])
|
||||
assert np.all(output.asnumpy() == expect[i])
|
||||
assert output.shape == expect[i].shape
|
||||
print('test [%d/%d] passed!' % (i, len(x)))
|
||||
|
|
Loading…
Reference in New Issue