!49631 support more CPU data types for operators

Merge pull request !49631 from zong_shuai/support_cpu_type
i-robot 2023-03-06 01:17:32 +00:00 committed by Gitee
commit 1bfe24bd98
10 changed files with 73 additions and 15 deletions


@@ -370,12 +370,22 @@ std::shared_ptr<CpuKernelFunc> SpecializeArithLogComplexFunc() {
using ArithLogicCpuFuncCreator = std::function<std::shared_ptr<CpuKernelFunc>()>;
static std::map<std::string, std::vector<std::pair<KernelAttr, ArithLogicCpuFuncCreator>>> kernel_attr_lists = {
{kLess,
{{KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
{{KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
SpecializeArithLogFunc<bool>},
{KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeBool),
SpecializeArithLogFunc<int8_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeBool),
SpecializeArithLogFunc<uint8_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeBool),
SpecializeArithLogFunc<int16_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
SpecializeArithLogFunc<int>},
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool),
SpecializeArithLogFunc<float>},
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool),
SpecializeArithLogFunc<int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool),
SpecializeArithLogFunc<float16>},
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool),
SpecializeArithLogFunc<float>},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeBool),
SpecializeArithLogFunc<double>}}},
{kEqual,
@@ -491,7 +501,9 @@ static std::map<std::string, std::vector<std::pair<KernelAttr, ArithLogicCpuFunc
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeBool),
SpecializeArithLogFunc<double>}}},
{kLessEqual,
{{KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
{{KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
SpecializeArithLogFunc<bool>},
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
SpecializeArithLogFunc<int>},
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool),
SpecializeArithLogFunc<int64_t>},
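The table above maps each (input, input, output) dtype signature to a typed functor, so widening dtype support for Less and LessEqual is just a matter of adding rows. A minimal sketch (not part of the diff) of how such a table can be resolved at runtime, using hypothetical simplified stand-ins (SimpleAttr, Creator) instead of the real KernelAttr machinery:

#include <functional>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Hypothetical, simplified stand-ins for KernelAttr and the functor creator.
struct SimpleAttr {
  int in0, in1, out;  // dtype ids for the two inputs and the output
  bool operator==(const SimpleAttr &o) const {
    return in0 == o.in0 && in1 == o.in1 && out == o.out;
  }
};
using Creator = std::function<void()>;  // would build the typed kernel functor

// One entry list per op name, mirroring kernel_attr_lists above.
static std::map<std::string, std::vector<std::pair<SimpleAttr, Creator>>> table;

// Pick the creator whose registered signature matches the runtime dtypes.
Creator Resolve(const std::string &op, const SimpleAttr &runtime) {
  for (const auto &entry : table.at(op)) {
    if (entry.first == runtime) return entry.second;
  }
  return nullptr;  // no kernel registered for this dtype combination
}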


@@ -501,7 +501,10 @@ template <typename T>
void Abs(ArithmeticSelfCpuKernelFunc *content, const T *in, T *out, size_t size) {
if constexpr ((std::is_same_v<T, uint8_t>) || (std::is_same_v<T, uint16_t>) || (std::is_same_v<T, uint32_t>) ||
(std::is_same_v<T, uint64_t>)) {
MS_LOG(EXCEPTION) << "'Abs' cannot be instantiated.";
auto ret_code = memcpy_s(out, size * sizeof(T), in, size * sizeof(T));
if (ret_code != EOK) {
MS_LOG(EXCEPTION) << "For Abs, Failed to copy data, memcpy_s errorno: " << ret_code;
}
} else {
auto task = [&in, &out](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
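The replaced exception reflected that negation is meaningless for unsigned element types; since |x| == x for them, the new branch simply copies input to output. A minimal standalone sketch of the same idea (the diff enumerates uint8_t..uint64_t explicitly; this generic version uses std::is_unsigned_v, and plain memcpy instead of the bounds-checked memcpy_s):

#include <cstddef>
#include <cstring>
#include <type_traits>

template <typename T>
void AbsSketch(const T *in, T *out, std::size_t size) {
  if constexpr (std::is_unsigned_v<T>) {
    // |x| == x for unsigned types, so copying is the whole operation.
    std::memcpy(out, in, size * sizeof(T));
  } else {
    for (std::size_t i = 0; i < size; ++i) {
      out[i] = in[i] < T(0) ? -in[i] : in[i];
    }
  }
}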
@@ -750,6 +753,7 @@ void ArithmeticSelfCpuKernelFunc::LaunchKernelComplex(const std::vector<AddressP
{prim::kPrimTanh->name(), Tanh<T>},
{prim::kPrimAtanh->name(), Atanh<T>},
{prim::kPrimInv->name(), Inv<T>},
{prim::kPrimAbs->name(), Abs<T>},
{prim::kPrimSign->name(), ComplexSign<T>},
{prim::kPrimLog->name(), ComplexLog<T>},
{prim::kPrimExp->name(), ComplexExp<T>},
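Registering Abs<T> in the complex dispatch map means the complex path needs an absolute-value definition. A hedged sketch, assuming |z| is the complex magnitude written back into the complex output type, which matches the Complex<float> -> Complex<float> CUDA instantiations added further below:

#include <complex>
#include <cstddef>

template <typename T>
void ComplexAbsSketch(const std::complex<T> *in, std::complex<T> *out,
                      std::size_t size) {
  for (std::size_t i = 0; i < size; ++i) {
    // |a + bi| = sqrt(a^2 + b^2), stored with a zero imaginary part because
    // the input and output share the complex dtype.
    out[i] = std::complex<T>(std::abs(in[i]), T(0));
  }
}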
@@ -1035,8 +1039,14 @@ static std::map<std::string, std::vector<std::pair<KernelAttr, ArithFuncCreator>
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), CreateArithSelfFunc},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), CreateArithSelfFunc}}},
{kAbs,
{{KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), CreateArithSelfFunc},
{{KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), CreateArithSelfFunc},
{KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8), CreateArithSelfFunc},
{KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8), CreateArithSelfFunc},
{KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16), CreateArithSelfFunc},
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), CreateArithSelfFunc},
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), CreateArithSelfFunc},
{KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64), CreateArithSelfFunc},
{KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128), CreateArithSelfFunc},
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), CreateArithSelfFunc},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), CreateArithSelfFunc}}},
{kSqrt,


@@ -93,9 +93,6 @@ void MaximumCpuKernelMod::InitInputTensorAndScalar(size_t max_input_shape_size)
}
void MaximumCpuKernelMod::InitInputTensors(TypeId input_x_dtype, TypeId input_y_dtype) {
if (input_x_dtype == kNumberTypeBool && input_y_dtype == kNumberTypeBool) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', input tensor types can not be both bool.";
}
// Check if the shape needs to be broadcast
need_broadcast_ = IsBroadcast();
if (need_broadcast_) {
@@ -241,6 +238,14 @@ void MaximumCpuKernelMod::BroadcastArithTensors(const T *input_x, const T *input
}
const std::vector<std::pair<KernelAttr, MaximumCpuKernelMod::KernelRunFunc>> &MaximumCpuKernelMod::GetFuncList() const {
static const std::vector<std::pair<KernelAttr, MaximumCpuKernelMod::KernelRunFunc>> func_list = {
{KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
&MaximumCpuKernelMod::LaunchKernel<bool>},
{KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
&MaximumCpuKernelMod::LaunchKernel<int8_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
&MaximumCpuKernelMod::LaunchKernel<uint8_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
&MaximumCpuKernelMod::LaunchKernel<int16_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
&MaximumCpuKernelMod::LaunchKernel<int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
@@ -249,6 +254,8 @@ const std::vector<std::pair<KernelAttr, MaximumCpuKernelMod::KernelRunFunc>> &Ma
&MaximumCpuKernelMod::LaunchKernel<uint32_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
&MaximumCpuKernelMod::LaunchKernel<uint64_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
&MaximumCpuKernelMod::LaunchKernel<float16>},
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
&MaximumCpuKernelMod::LaunchKernel<float>},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
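Dropping the both-bool exception above is safe because elementwise maximum is well defined for bool: with false < true, std::max degenerates to logical OR. A minimal sketch of the scalar kernel body:

#include <algorithm>
#include <cstddef>

template <typename T>
void MaximumSketch(const T *x, const T *y, T *out, std::size_t size) {
  for (std::size_t i = 0; i < size; ++i) {
    // For T = bool this is x[i] || y[i], since false < true.
    out[i] = std::max(x[i], y[i]);
  }
}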


@@ -46,6 +46,9 @@ class MaximumCpuKernelMod : public NativeCpuKernelMod, public MatchKernelHelper<
const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const override;
protected:
std::vector<KernelAttr> GetOpSupport() override { return MatchKernelHelper::OpSupport(); }
private:
bool IsBroadcast() const;
size_t Index(const size_t &index, const size_t &dim) const;


@@ -444,6 +444,13 @@ __global__ void AbsKernel(const half *input, half *output, const size_t count) {
return;
}
template <>
__global__ void AbsKernel(const bool *input, bool *output, const size_t count) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
output[i] = input[i];
}
return;
}
template <>
__global__ void AbsKernel(const uint8_t *input, uint8_t *output, const size_t count) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
output[i] = input[i];
@@ -1423,6 +1430,8 @@ template CUDA_LIB_EXPORT void Sign<uint64_t>(const uint64_t *input, uint64_t *ou
cudaStream_t cuda_stream);
// complex64
template CUDA_LIB_EXPORT void Abs<Complex<float>>(const Complex<float> *input, Complex<float> *output,
const size_t count, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Real<float>(const Complex<float> *input, float *output, const size_t count,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Imag<float>(const Complex<float> *input, float *output, const size_t count,
@@ -1461,6 +1470,8 @@ template CUDA_LIB_EXPORT void Log1p<Complex<float>>(const Complex<float> *input,
const size_t count, cudaStream_t cuda_stream);
// complex128
template CUDA_LIB_EXPORT void Abs<Complex<double>>(const Complex<double> *input, Complex<double> *output,
const size_t count, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Real<double>(const Complex<double> *input, double *output, const size_t count,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Imag<double>(const Complex<double> *input, double *output, const size_t count,
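Two notes on this file: the new bool AbsKernel reuses the standard CUDA grid-stride loop (index blockIdx.x * blockDim.x + threadIdx.x, stride blockDim.x * gridDim.x), so a fixed launch grid covers any count; and because these kernels are explicitly instantiated, each newly supported dtype also needs its own instantiation line, which is what the two Abs<Complex<...>> exports add. A minimal C++ sketch (hypothetical file layout and names, not the real headers) of that explicit-instantiation pattern:

#include <cstddef>

// abs_impl.h (hypothetical) — declaration only; the body lives in one file.
template <typename T>
void AbsSketch(const T *input, T *output, std::size_t count);

// abs_impl.cc (hypothetical) — definition plus explicit instantiations. Only
// the dtypes listed here are compiled into the library, so every newly
// supported type needs an added instantiation.
template <typename T>
void AbsSketch(const T *input, T *output, std::size_t count) {
  for (std::size_t i = 0; i < count; ++i) {
    output[i] = input[i];  // body reduced to a copy for brevity
  }
}
template void AbsSketch<float>(const float *, float *, std::size_t);
template void AbsSketch<double>(const double *, double *, std::size_t);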


@@ -442,7 +442,21 @@ std::map<std::string, std::vector<std::pair<KernelAttr, UnaryOpGpuKernelMod::Una
{KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128),
&UnaryOpGpuKernelMod::LaunchKernel<utils::Complex<double>>}}},
{kAbs,
{{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
{{KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
&UnaryOpGpuKernelMod::LaunchKernel<bool>},
{KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
&UnaryOpGpuKernelMod::LaunchKernel<char>},
{KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
&UnaryOpGpuKernelMod::LaunchKernel<unsigned char>},
{KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
&UnaryOpGpuKernelMod::LaunchKernel<int16_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
&UnaryOpGpuKernelMod::LaunchKernel<int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64),
&UnaryOpGpuKernelMod::LaunchKernel<utils::Complex<float>>},
{KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128),
&UnaryOpGpuKernelMod::LaunchKernel<utils::Complex<double>>},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
&UnaryOpGpuKernelMod::LaunchKernel<double>},
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
&UnaryOpGpuKernelMod::LaunchKernel<float>},
@@ -576,7 +590,7 @@ bool UnaryOpGpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &in
{kSin, Sin<T>}, {kCos, Cos<T>}, {kACos, ACos<T>}, {kAcosh, Acosh<T>},
{kAsin, Asin<T>}, {kAsinh, Asinh<T>}, {kSquare, SquareOpt<T>}, {kReciprocal, ReciprocalOpt<T>},
{kRsqrt, Rsqrt<T>}, {kSign, Sign<T>}, {kAtan, Atan<T>}, {kSinh, Sinh<T>},
{kExpm1, Expm1<T>}, {kLog1p, Log1p<T>}};
{kExpm1, Expm1<T>}, {kLog1p, Log1p<T>}, {kAbs, Abs<T>}};
copy(func_map_complex.begin(), func_map_complex.end(), inserter(func_map, func_map.begin()));
} else {
std::map<std::string, std::function<void(const T *, T *, const size_t, cudaStream_t)>> func_map_normal = {
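A subtlety in the map merge above: std::inserter on a std::map calls map::insert, which never overwrites an existing key, so entries from func_map_complex only fill in names that func_map does not already contain. A self-contained illustration:

#include <algorithm>
#include <iostream>
#include <iterator>
#include <map>
#include <string>

int main() {
  std::map<std::string, int> dst = {{"Abs", 1}};
  std::map<std::string, int> src = {{"Abs", 2}, {"Sign", 3}};
  // std::inserter calls dst.insert(...), which keeps existing keys intact.
  std::copy(src.begin(), src.end(), std::inserter(dst, dst.begin()));
  std::cout << dst["Abs"] << " " << dst["Sign"] << "\n";  // prints: 1 3
}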


@@ -59,7 +59,8 @@ class AbsInfer : public abstract::OpInferBase {
const int64_t input_num = 1;
CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num, prim_name);
auto x_type = input_args[0]->BuildType();
(void)CheckAndConvertUtils::CheckTensorTypeValid("x", x_type, common_valid_types, prim->name());
(void)CheckAndConvertUtils::CheckTensorTypeValid("x", x_type, common_valid_types_with_complex_and_bool,
prim->name());
return x_type;
}
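The infer change widens the accepted dtype set for Abs from common_valid_types to a set that also includes complex and bool, while still returning the input type unchanged. A minimal sketch (hypothetical DType ids, not the MindSpore API) of this single-input validity check:

#include <set>
#include <stdexcept>
#include <string>

enum class DType { kBool, kInt32, kFloat32, kComplex64 };  // hypothetical ids
static const std::set<DType> kValidTypes = {DType::kBool, DType::kInt32,
                                            DType::kFloat32, DType::kComplex64};

// Reject any input whose dtype is outside the allowed set.
void CheckTypeValidSketch(const std::string &arg, DType dtype,
                          const std::string &prim_name) {
  if (kValidTypes.count(dtype) == 0) {
    throw std::invalid_argument("For '" + prim_name + "', the dtype of '" +
                                arg + "' is not supported.");
  }
}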


@@ -39,7 +39,7 @@ TypePtr LessInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePt
std::map<std::string, TypePtr> types;
(void)types.emplace("x", input_args[0]->BuildType());
(void)types.emplace("y", input_args[1]->BuildType());
(void)CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
(void)CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types_with_bool, prim->name());
return std::make_shared<TensorType>(kBool);
}
} // namespace


@@ -51,7 +51,7 @@ TypePtr LessEqualInferType(const PrimitivePtr &prim, const std::vector<AbstractB
std::map<std::string, TypePtr> types;
(void)types.emplace("x", input_args[0]->BuildType());
(void)types.emplace("y", input_args[1]->BuildType());
(void)CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
(void)CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types_with_bool, prim->name());
return std::make_shared<TensorType>(kBool);
}
} // namespace


@@ -75,7 +75,7 @@ class MaximumInfer : public abstract::OpInferBase {
<< type_x->ToString() << ", " << type_y->ToString() << "].";
}
}
(void)CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types_with_complex, prim->name());
(void)CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types_with_complex_and_bool, prim->name());
return type_x;
}
};
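Less, LessEqual, and Maximum use the stricter variant: every input must share one dtype, and that dtype must be in the allowed set (now including bool, and for Maximum also complex). A self-contained sketch of that combined check, again with hypothetical stand-ins:

#include <map>
#include <set>
#include <stdexcept>
#include <string>

enum class DType { kBool, kInt32, kFloat32 };  // hypothetical ids

// All named inputs must share one dtype, and it must be in the allowed set.
void CheckTypesSameSketch(const std::map<std::string, DType> &types,
                          const std::set<DType> &valid,
                          const std::string &prim_name) {
  if (types.empty()) return;
  const DType first = types.begin()->second;
  for (const auto &kv : types) {
    if (kv.second != first) {
      throw std::invalid_argument("For '" + prim_name +
                                  "', input dtypes must all be the same.");
    }
  }
  if (valid.count(first) == 0) {
    throw std::invalid_argument("For '" + prim_name +
                                "', this input dtype is not supported.");
  }
}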