!16988 upload unify cpu ops info library code

From: @chengxb7532
Reviewed-by: @linqingke,@guoqi1024
Signed-off-by: @guoqi1024
This commit is contained in:
mindspore-ci-bot 2021-05-27 16:58:24 +08:00 committed by Gitee
commit 5ed742a893
100 changed files with 2303 additions and 1173 deletions

View File

@ -40,22 +40,7 @@ class AdamCPUKernel : public CPUKernel {
TypeId dtype_{kTypeUnknown};
};
MS_REG_CPU_KERNEL(Adam,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
AdamCPUKernel)
MS_REG_CPU_KERNEL(Adam, KernelAttr(), AdamCPUKernel)
} // namespace kernel
} // namespace mindspore

View File

@ -40,10 +40,8 @@ class ArgmaxCPUKernel : public CPUKernel {
size_t dim_axis_;
};
MS_REG_CPU_KERNEL_T(Argmax, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt32),
ArgmaxCPUKernel, float);
MS_REG_CPU_KERNEL_T(Argmax, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt32),
ArgmaxCPUKernel, float16);
MS_REG_CPU_KERNEL_T(Argmax, KernelAttr(), ArgmaxCPUKernel, float);
MS_REG_CPU_KERNEL_T(Argmax, KernelAttr(), ArgmaxCPUKernel, float16);
} // namespace kernel
} // namespace mindspore

View File

@ -42,14 +42,8 @@ class ArgMinWithValueCPUKernel : public CPUKernel {
size_t dim_axis_;
};
MS_REG_CPU_KERNEL_T(
ArgMinWithValue,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
ArgMinWithValueCPUKernel, float);
MS_REG_CPU_KERNEL_T(
ArgMinWithValue,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
ArgMinWithValueCPUKernel, float16);
MS_REG_CPU_KERNEL_T(ArgMinWithValue, KernelAttr(), ArgMinWithValueCPUKernel, float);
MS_REG_CPU_KERNEL_T(ArgMinWithValue, KernelAttr(), ArgMinWithValueCPUKernel, float16);
} // namespace kernel
} // namespace mindspore

View File

@ -63,43 +63,18 @@ class ArithmeticCPUKernel : public CPUKernel {
TypeId target_dtype_{kTypeUnknown};
};
MS_REG_CPU_KERNEL_T(
Sub, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
ArithmeticCPUKernel, int);
MS_REG_CPU_KERNEL_T(
Sub, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ArithmeticCPUKernel, float);
MS_REG_CPU_KERNEL_T(
Sub, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
ArithmeticCPUKernel, int64_t);
MS_REG_CPU_KERNEL_T(
Pow, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
ArithmeticCPUKernel, int);
MS_REG_CPU_KERNEL_T(
Pow, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ArithmeticCPUKernel, float);
MS_REG_CPU_KERNEL_T(
Pow, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
ArithmeticCPUKernel, int64_t);
MS_REG_CPU_KERNEL_T(
RealDiv, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
ArithmeticCPUKernel, int);
MS_REG_CPU_KERNEL_T(
RealDiv,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ArithmeticCPUKernel, float);
MS_REG_CPU_KERNEL_T(
RealDiv, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
ArithmeticCPUKernel, int64_t);
MS_REG_CPU_KERNEL_T(
Div, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
ArithmeticCPUKernel, int64_t);
MS_REG_CPU_KERNEL_T(
Div, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
ArithmeticCPUKernel, int);
MS_REG_CPU_KERNEL_T(
Div, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ArithmeticCPUKernel, float);
MS_REG_CPU_KERNEL_T(Sub, KernelAttr(), ArithmeticCPUKernel, int32_t);
MS_REG_CPU_KERNEL_T(Sub, KernelAttr(), ArithmeticCPUKernel, float);
MS_REG_CPU_KERNEL_T(Sub, KernelAttr(), ArithmeticCPUKernel, int64_t);
MS_REG_CPU_KERNEL_T(Pow, KernelAttr(), ArithmeticCPUKernel, int32_t);
MS_REG_CPU_KERNEL_T(Pow, KernelAttr(), ArithmeticCPUKernel, float);
MS_REG_CPU_KERNEL_T(Pow, KernelAttr(), ArithmeticCPUKernel, int64_t);
MS_REG_CPU_KERNEL_T(RealDiv, KernelAttr(), ArithmeticCPUKernel, int32_t);
MS_REG_CPU_KERNEL_T(RealDiv, KernelAttr(), ArithmeticCPUKernel, float);
MS_REG_CPU_KERNEL_T(RealDiv, KernelAttr(), ArithmeticCPUKernel, int64_t);
MS_REG_CPU_KERNEL_T(Div, KernelAttr(), ArithmeticCPUKernel, int32_t);
MS_REG_CPU_KERNEL_T(Div, KernelAttr(), ArithmeticCPUKernel, float);
MS_REG_CPU_KERNEL_T(Div, KernelAttr(), ArithmeticCPUKernel, int64_t);
MS_REG_CPU_KERNEL_T(
FloorDiv, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
ArithmeticCPUKernel, int64_t);
@ -139,9 +114,6 @@ MS_REG_CPU_KERNEL_T(
MS_REG_CPU_KERNEL_T(
AssignAdd, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
ArithmeticCPUKernel, int64_t);
MS_REG_CPU_KERNEL_T(
Mul, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
ArithmeticCPUKernel, int);
MS_REG_CPU_KERNEL_T(
SquaredDifference,
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),

View File

@ -37,10 +37,7 @@ class BiasAddCPUKernel : public CPUKernel {
std::vector<size_t> input_shape_;
std::vector<size_t> bias_shape_;
};
MS_REG_CPU_KERNEL(
BiasAdd,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
BiasAddCPUKernel);
MS_REG_CPU_KERNEL(BiasAdd, KernelAttr(), BiasAddCPUKernel);
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_BIAS_ADD_CPU_KERNEL_H_

View File

@ -36,8 +36,7 @@ class BiasAddGradCPUKernel : public CPUKernel {
private:
std::vector<size_t> input_shape_;
};
MS_REG_CPU_KERNEL(BiasAddGrad, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
BiasAddGradCPUKernel);
MS_REG_CPU_KERNEL(BiasAddGrad, KernelAttr(), BiasAddGradCPUKernel);
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_BIASADDGRADCPUKERNEL_H_

View File

@ -34,8 +34,8 @@ void Cast(const S *in, T *out, size_t size) {
template <typename S, typename T>
void CastCPUKernel<S, T>::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
source_dtype = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
target_dtype = AnfAlgo::GetOutputDeviceDataType(kernel_node, 0);
source_dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
target_dtype_ = AnfAlgo::GetOutputDeviceDataType(kernel_node, 0);
}
template <typename S, typename T>

View File

@ -35,309 +35,166 @@ class CastCPUKernel : public CPUKernel {
const std::vector<AddressPtr> &outputs) override;
private:
TypeId source_dtype{kTypeUnknown};
TypeId target_dtype{kTypeUnknown};
TypeId source_dtype_{kTypeUnknown};
TypeId target_dtype_{kTypeUnknown};
};
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeFloat16), CastCPUKernel,
bool, float16);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeFloat32), CastCPUKernel,
bool, float);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeFloat64), CastCPUKernel,
bool, double);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt8), CastCPUKernel,
bool, int8_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt16), CastCPUKernel,
bool, int16_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt32), CastCPUKernel,
bool, int32_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt64), CastCPUKernel,
bool, int64_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeUInt8), CastCPUKernel,
bool, uint8_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeUInt16), CastCPUKernel,
bool, uint16_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeUInt32), CastCPUKernel,
bool, uint32_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeUInt64), CastCPUKernel,
bool, uint64_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), CastCPUKernel,
bool, bool);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint8_t, uint8_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint8_t, uint16_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint8_t, uint32_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint8_t, uint64_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint8_t, int8_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint8_t, int16_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint8_t, int32_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint8_t, int64_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint8_t, float16);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint8_t, float);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint8_t, double);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint8_t, bool);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
CastCPUKernel, float16, float16);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat32),
CastCPUKernel, float16, float);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat64),
CastCPUKernel, float16, double);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt8), CastCPUKernel,
float16, int8_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt16),
CastCPUKernel, float16, int16_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt32),
CastCPUKernel, float16, int32_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt64),
CastCPUKernel, float16, int64_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeUInt8),
CastCPUKernel, float16, uint8_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeUInt16),
CastCPUKernel, float16, uint16_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeUInt32),
CastCPUKernel, float16, uint32_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeUInt64),
CastCPUKernel, float16, uint64_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool), CastCPUKernel,
float16, bool);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint16_t, uint8_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint16_t, uint16_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint16_t, uint32_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint16_t, uint64_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint16_t, int8_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint16_t, int16_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint16_t, int32_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint16_t, int64_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint16_t, float16);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint16_t, float);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint16_t, double);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint16_t, bool);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat16),
CastCPUKernel, float, float16);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
CastCPUKernel, float, float);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat64),
CastCPUKernel, float, double);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt8), CastCPUKernel,
float, int8_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt16),
CastCPUKernel, float, int16_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt32),
CastCPUKernel, float, int32_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt64),
CastCPUKernel, float, int64_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeUInt8),
CastCPUKernel, float, uint8_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeUInt16),
CastCPUKernel, float, uint16_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeUInt32),
CastCPUKernel, float, uint32_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeUInt64),
CastCPUKernel, float, uint64_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool), CastCPUKernel,
float, bool);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint32_t, uint8_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint32_t, uint16_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint32_t, uint32_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint32_t, uint64_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint32_t, int8_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint32_t, int16_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint32_t, int32_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint32_t, int64_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint32_t, float16);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint32_t, float);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint32_t, double);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint32_t, bool);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat16),
CastCPUKernel, double, float16);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat32),
CastCPUKernel, double, float);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
CastCPUKernel, double, double);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeInt8), CastCPUKernel,
double, int8_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeInt16),
CastCPUKernel, double, int16_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeInt32),
CastCPUKernel, double, int32_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeInt64),
CastCPUKernel, double, int64_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeUInt8),
CastCPUKernel, double, uint8_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeUInt16),
CastCPUKernel, double, uint16_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeUInt32),
CastCPUKernel, double, uint32_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeUInt64),
CastCPUKernel, double, uint64_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeBool), CastCPUKernel,
double, bool);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint64_t, uint8_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint64_t, uint16_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint64_t, uint32_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint64_t, uint64_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint64_t, int8_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint64_t, int16_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint64_t, int32_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint64_t, int64_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint64_t, float16);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint64_t, float);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint64_t, double);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint64_t, bool);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeFloat16), CastCPUKernel,
int8_t, float16);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeFloat32), CastCPUKernel,
int8_t, float);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeFloat64), CastCPUKernel,
int8_t, double);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8), CastCPUKernel,
int8_t, int8_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt16), CastCPUKernel,
int8_t, int16_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt32), CastCPUKernel,
int8_t, int32_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt64), CastCPUKernel,
int8_t, int64_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeUInt8), CastCPUKernel,
int8_t, uint8_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeUInt16), CastCPUKernel,
int8_t, uint16_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeUInt32), CastCPUKernel,
int8_t, uint32_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeUInt64), CastCPUKernel,
int8_t, uint64_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeBool), CastCPUKernel,
int8_t, bool);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int8_t, uint8_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int8_t, uint16_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int8_t, uint32_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int8_t, uint64_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int8_t, int8_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int8_t, int16_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int8_t, int32_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int8_t, int64_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int8_t, float16);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int8_t, float);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int8_t, double);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int8_t, bool);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeFloat16),
CastCPUKernel, int16_t, float16);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeFloat32),
CastCPUKernel, int16_t, float);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeFloat64),
CastCPUKernel, int16_t, double);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt8), CastCPUKernel,
int16_t, int8_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16), CastCPUKernel,
int16_t, int16_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt32), CastCPUKernel,
int16_t, int32_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt64), CastCPUKernel,
int16_t, int64_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeUInt8), CastCPUKernel,
int16_t, uint8_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeUInt16), CastCPUKernel,
int16_t, uint16_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeUInt32), CastCPUKernel,
int16_t, uint32_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeUInt64), CastCPUKernel,
int16_t, uint64_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeBool), CastCPUKernel,
int16_t, bool);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int16_t, uint8_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int16_t, uint16_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int16_t, uint32_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int16_t, uint64_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int16_t, int8_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int16_t, int16_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int16_t, int32_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int16_t, int64_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int16_t, float16);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int16_t, float);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int16_t, double);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int16_t, bool);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
CastCPUKernel, int32_t, float16);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
CastCPUKernel, int32_t, float);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64),
CastCPUKernel, int32_t, double);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt8), CastCPUKernel,
int32_t, int8_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt16), CastCPUKernel,
int32_t, int16_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), CastCPUKernel,
int32_t, int32_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64), CastCPUKernel,
int32_t, int64_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt8), CastCPUKernel,
int32_t, uint8_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt16), CastCPUKernel,
int32_t, uint16_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt32), CastCPUKernel,
int32_t, uint32_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt64), CastCPUKernel,
int32_t, uint64_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool), CastCPUKernel,
int32_t, bool);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int32_t, uint8_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int32_t, uint16_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int32_t, uint32_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int32_t, uint64_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int32_t, int8_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int32_t, int16_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int32_t, int32_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int32_t, int64_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int32_t, float16);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int32_t, float);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int32_t, double);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int32_t, bool);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16),
CastCPUKernel, int64_t, float16);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32),
CastCPUKernel, int64_t, float);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64),
CastCPUKernel, int64_t, double);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt8), CastCPUKernel,
int64_t, int8_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt16), CastCPUKernel,
int64_t, int16_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32), CastCPUKernel,
int64_t, int32_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), CastCPUKernel,
int64_t, int64_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt8), CastCPUKernel,
int64_t, uint8_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt16), CastCPUKernel,
int64_t, uint16_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt32), CastCPUKernel,
int64_t, uint32_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt64), CastCPUKernel,
int64_t, uint64_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool), CastCPUKernel,
int64_t, bool);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int64_t, uint8_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int64_t, uint16_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int64_t, uint32_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int64_t, uint64_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int64_t, int8_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int64_t, int16_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int64_t, int32_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int64_t, int64_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int64_t, float16);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int64_t, float);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int64_t, double);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int64_t, bool);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeFloat16),
CastCPUKernel, uint8_t, float16);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeFloat32),
CastCPUKernel, uint8_t, float);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeFloat64),
CastCPUKernel, uint8_t, double);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeInt8), CastCPUKernel,
uint8_t, int8_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeInt16), CastCPUKernel,
uint8_t, int16_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeInt32), CastCPUKernel,
uint8_t, int32_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeInt64), CastCPUKernel,
uint8_t, int64_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8), CastCPUKernel,
uint8_t, uint8_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt16), CastCPUKernel,
uint8_t, uint16_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt32), CastCPUKernel,
uint8_t, uint32_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt64), CastCPUKernel,
uint8_t, uint64_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeBool), CastCPUKernel,
uint8_t, bool);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, float16, uint8_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, float16, uint16_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, float16, uint32_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, float16, uint64_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, float16, int8_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, float16, int16_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, float16, int32_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, float16, int64_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, float16, float16);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, float16, float);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, float16, double);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, float16, bool);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeFloat16),
CastCPUKernel, uint16_t, float16);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeFloat32),
CastCPUKernel, uint16_t, float);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeFloat64),
CastCPUKernel, uint16_t, double);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeInt8), CastCPUKernel,
uint16_t, int8_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeInt16), CastCPUKernel,
uint16_t, int16_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeInt32), CastCPUKernel,
uint16_t, int32_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeInt64), CastCPUKernel,
uint16_t, int64_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt8), CastCPUKernel,
uint16_t, uint8_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16),
CastCPUKernel, uint16_t, uint16_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt32),
CastCPUKernel, uint16_t, uint32_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt64),
CastCPUKernel, uint16_t, uint64_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeBool), CastCPUKernel,
uint16_t, bool);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, float, uint8_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, float, uint16_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, float, uint32_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, float, uint64_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, float, int8_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, float, int16_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, float, int32_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, float, int64_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, float, float16);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, float, float);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, float, double);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, float, bool);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeFloat16),
CastCPUKernel, uint32_t, float16);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeFloat32),
CastCPUKernel, uint32_t, float);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeFloat64),
CastCPUKernel, uint32_t, double);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt8), CastCPUKernel,
uint32_t, int8_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt16), CastCPUKernel,
uint32_t, int16_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt32), CastCPUKernel,
uint32_t, int32_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt64), CastCPUKernel,
uint32_t, int64_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt8), CastCPUKernel,
uint32_t, uint8_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt16),
CastCPUKernel, uint32_t, uint16_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
CastCPUKernel, uint32_t, uint32_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt64),
CastCPUKernel, uint32_t, uint64_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeBool), CastCPUKernel,
uint32_t, bool);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, double, uint8_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, double, uint16_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, double, uint32_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, double, uint64_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, double, int8_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, double, int16_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, double, int32_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, double, int64_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, double, float16);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, double, float);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, double, double);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, double, bool);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, bool, uint8_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, bool, uint16_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, bool, uint32_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, bool, uint64_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, bool, int8_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, bool, int16_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, bool, int32_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, bool, int64_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, bool, float16);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, bool, float);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, bool, double);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, bool, bool);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeFloat16),
CastCPUKernel, uint64_t, float16);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeFloat32),
CastCPUKernel, uint64_t, float);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeFloat64),
CastCPUKernel, uint64_t, double);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeInt8), CastCPUKernel,
uint64_t, int8_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeInt16), CastCPUKernel,
uint64_t, int16_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeInt32), CastCPUKernel,
uint64_t, int32_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeInt64), CastCPUKernel,
uint64_t, int64_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt8), CastCPUKernel,
uint64_t, uint8_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt16),
CastCPUKernel, uint64_t, uint16_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt32),
CastCPUKernel, uint64_t, uint32_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
CastCPUKernel, uint64_t, uint64_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeBool), CastCPUKernel,
uint64_t, bool);
} // namespace kernel
} // namespace mindspore

View File

@ -39,36 +39,16 @@ class ConcatCPUKernel : public CPUKernel {
CNodeWeakPtr node_wpt_;
};
MS_REG_CPU_KERNEL_T(
Concat, KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ConcatCPUKernel, float);
MS_REG_CPU_KERNEL_T(Concat,
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
ConcatCPUKernel, int8_t)
MS_REG_CPU_KERNEL_T(Concat,
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
ConcatCPUKernel, int16_t)
MS_REG_CPU_KERNEL_T(Concat,
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
ConcatCPUKernel, int)
MS_REG_CPU_KERNEL_T(Concat,
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
ConcatCPUKernel, int64_t)
MS_REG_CPU_KERNEL_T(Concat,
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
ConcatCPUKernel, uint8_t)
MS_REG_CPU_KERNEL_T(Concat,
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16),
ConcatCPUKernel, uint16_t)
MS_REG_CPU_KERNEL_T(Concat,
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
ConcatCPUKernel, uint32_t)
MS_REG_CPU_KERNEL_T(Concat,
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
ConcatCPUKernel, uint64_t)
MS_REG_CPU_KERNEL_T(Concat,
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
ConcatCPUKernel, bool)
MS_REG_CPU_KERNEL_T(Concat, KernelAttr(), ConcatCPUKernel, float);
MS_REG_CPU_KERNEL_T(Concat, KernelAttr(), ConcatCPUKernel, int8_t)
MS_REG_CPU_KERNEL_T(Concat, KernelAttr(), ConcatCPUKernel, int16_t)
MS_REG_CPU_KERNEL_T(Concat, KernelAttr(), ConcatCPUKernel, int32_t)
MS_REG_CPU_KERNEL_T(Concat, KernelAttr(), ConcatCPUKernel, int64_t)
MS_REG_CPU_KERNEL_T(Concat, KernelAttr(), ConcatCPUKernel, uint8_t)
MS_REG_CPU_KERNEL_T(Concat, KernelAttr(), ConcatCPUKernel, uint16_t)
MS_REG_CPU_KERNEL_T(Concat, KernelAttr(), ConcatCPUKernel, uint32_t)
MS_REG_CPU_KERNEL_T(Concat, KernelAttr(), ConcatCPUKernel, uint64_t)
MS_REG_CPU_KERNEL_T(Concat, KernelAttr(), ConcatCPUKernel, bool)
} // namespace kernel
} // namespace mindspore

View File

@ -23,6 +23,7 @@
#include <vector>
#include "backend/kernel_compiler/kernel.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/kernel_compiler/common_utils.h"
#include "ir/anf.h"
using mindspore::kernel::Address;

View File

@ -17,12 +17,14 @@
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
#include <memory>
#include <set>
#include <string>
#include "runtime/device/kernel_info.h"
namespace mindspore {
namespace kernel {
const std::set<std::string> same_op_name = {"Concat", "Pack", "Stack", "Split", "Transpose", "Unpack", "AddN"};
CPUKernelFactory &CPUKernelFactory::GetInstance() {
static CPUKernelFactory instance;
return instance;
@ -48,6 +50,58 @@ std::shared_ptr<CPUKernel> CPUKernelFactory::Create(const std::string &kernel_na
return nullptr;
}
void CPUKernelFactory::SetKernelAttrs(const std::shared_ptr<kernel::OpInfo> op_info,
std::vector<KernelAttr> *kernel_attrs) {
auto inputs_ptr = op_info->inputs_ptr();
auto outputs_ptr = op_info->outputs_ptr();
auto first_input_dtypes = inputs_ptr[0]->dtypes();
auto input_formats = inputs_ptr[0]->formats();
for (size_t i = 0; i < first_input_dtypes.size(); i++) {
KernelAttr kernel_attr;
kernel_attr.AddInputAttr(kernel::DtypeToTypeId(first_input_dtypes[i]), input_formats[i]);
for (size_t j = 1; j < inputs_ptr.size(); j++) {
auto input_dtypes = inputs_ptr[j]->dtypes();
input_formats = inputs_ptr[j]->formats();
kernel_attr.AddInputAttr(kernel::DtypeToTypeId(input_dtypes[i]), input_formats[i]);
}
for (size_t j = 0; j < outputs_ptr.size(); j++) {
auto output_dtypes = outputs_ptr[j]->dtypes();
auto output_formats = outputs_ptr[j]->formats();
kernel_attr.AddOutputAttr(kernel::DtypeToTypeId(output_dtypes[i]), output_formats[i]);
}
if (same_op_name.count(op_info->op_name()) != 0) {
kernel_attr.SetAllSameAttr(true);
}
kernel_attrs->emplace_back(kernel_attr);
}
}
void CPUKernelFactory::UpdateKernelAttrs(const std::string &kernel_name, const std::vector<KernelAttr> &kernel_attrs) {
size_t attr_size = kernel_attrs.size();
std::vector<std::pair<KernelAttr, CPUKernelCreator>> attr_creators(attr_size);
auto iter = name_to_attr_creator_.find(kernel_name);
if (iter == name_to_attr_creator_.end()) {
MS_LOG(ERROR) << "CPUKernelFactory has not registered operator: " << kernel_name;
return;
}
if (attr_size <= iter->second.size()) {
for (size_t i = 0; i < attr_size; i++) {
auto creator = name_to_attr_creator_.find(kernel_name)->second[i].second;
attr_creators[i] = std::make_pair(kernel_attrs[i], creator);
}
} else {
MS_LOG(INFO) << "attr size is not equal creators size " << kernel_name << " attr_size = " << attr_size
<< " creator_size = " << iter->second.size();
auto single_creator = name_to_attr_creator_.find(kernel_name)->second[0].second;
for (size_t i = 0; i < attr_size; i++) {
attr_creators[i] = std::make_pair(kernel_attrs[i], single_creator);
}
}
name_to_attr_creator_[kernel_name] = attr_creators;
}
std::pair<bool, size_t> CPUKernelFactory::CPUKernelAttrCheck(const std::string &kernel_name,
const KernelBuildInfo &kernel_info) {
auto iter = name_to_attr_creator_.find(kernel_name);
@ -55,10 +109,18 @@ std::pair<bool, size_t> CPUKernelFactory::CPUKernelAttrCheck(const std::string &
MS_LOG(INFO) << "Not registered CPU kernel: op[" << kernel_name << "]!";
return std::make_pair(false, 0);
}
auto creators = iter->second;
for (size_t index = 0; index < creators.size(); ++index) {
auto attr_creator = creators[index];
if (CPUKernelSingleAttrCheck(attr_creator.first, kernel_info)) {
auto kernel_attrs = GetSupportedKernelAttrList(kernel_name);
if (kernel_attrs[0].GetInputSize() == 0 && kernel_attrs[0].GetOutputSize() == 0) {
auto op_info_ptr = mindspore::kernel::OpLib::FindOp(kernel_name, kernel::OpImplyType::kCPU);
if (op_info_ptr == nullptr) {
MS_LOG(ERROR) << "Not find op[" << kernel_name << "] in cpu";
}
kernel_attrs.clear();
SetKernelAttrs(op_info_ptr, &kernel_attrs);
kernel::CPUKernelFactory::GetInstance().UpdateKernelAttrs(kernel_name, kernel_attrs);
}
for (size_t index = 0; index < kernel_attrs.size(); ++index) {
if (CPUKernelSingleAttrCheck(kernel_attrs[index], kernel_info)) {
return std::make_pair(true, index);
}
}

View File

@ -25,6 +25,7 @@
#include "utils/ms_utils.h"
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "backend/kernel_compiler/oplib/oplib.h"
#include "runtime/device/cpu/kernel_select_cpu.h"
namespace mindspore {
@ -36,6 +37,8 @@ class CPUKernelFactory {
static CPUKernelFactory &GetInstance();
void Register(const std::string &kernel_name, const KernelAttr &kernel_attr, CPUKernelCreator &&kernel_creator);
std::shared_ptr<CPUKernel> Create(const std::string &kernel_name, const CNodePtr &apply_kernel);
void SetKernelAttrs(const std::shared_ptr<kernel::OpInfo> op_info, std::vector<KernelAttr> *kernel_attrs);
void UpdateKernelAttrs(const std::string &kernel_name, const std::vector<KernelAttr> &kernel_attrs);
std::vector<KernelAttr> GetSupportedKernelAttrList(const std::string &kernel_name);
private:

View File

@ -46,14 +46,7 @@ class DropoutCPUKernel : public CPUKernel {
uint64_t tensor_size_ = 1;
};
MS_REG_CPU_KERNEL(
Dropout,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
DropoutCPUKernel);
MS_REG_CPU_KERNEL(
Dropout,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
DropoutCPUKernel);
MS_REG_CPU_KERNEL(Dropout, KernelAttr(), DropoutCPUKernel);
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_DROPOUT_CPU_KERNEL_H_

View File

@ -43,14 +43,7 @@ class DropoutGradCpuBwdKernel : public CPUKernel {
size_t num_count, float keep_prob);
};
MS_REG_CPU_KERNEL(
DropoutGrad,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
DropoutGradCpuBwdKernel);
MS_REG_CPU_KERNEL(
DropoutGrad,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
DropoutGradCpuBwdKernel);
MS_REG_CPU_KERNEL(DropoutGrad, KernelAttr(), DropoutGradCpuBwdKernel);
} // namespace kernel
} // namespace mindspore

View File

@ -33,10 +33,7 @@ class EqualCountCPUKernel : public CPUKernel {
const std::vector<AddressPtr> &outputs) override;
};
MS_REG_CPU_KERNEL(
EqualCount,
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
EqualCountCPUKernel);
MS_REG_CPU_KERNEL(EqualCount, KernelAttr(), EqualCountCPUKernel);
} // namespace kernel
} // namespace mindspore

View File

@ -44,41 +44,17 @@ class GatherV2CPUKernel : public CPUKernel {
int64_t axis_{0};
};
MS_REG_CPU_KERNEL_T(
Gather, KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
GatherV2CPUKernel, bool);
MS_REG_CPU_KERNEL_T(
Gather,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
GatherV2CPUKernel, float);
MS_REG_CPU_KERNEL_T(
Gather,
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64),
GatherV2CPUKernel, double);
MS_REG_CPU_KERNEL_T(
Gather, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt8),
GatherV2CPUKernel, int8_t);
MS_REG_CPU_KERNEL_T(
Gather, KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt16),
GatherV2CPUKernel, int16_t);
MS_REG_CPU_KERNEL_T(
Gather, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
GatherV2CPUKernel, int32_t);
MS_REG_CPU_KERNEL_T(
Gather, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64),
GatherV2CPUKernel, int64_t);
MS_REG_CPU_KERNEL_T(
Gather, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt8),
GatherV2CPUKernel, uint8_t);
MS_REG_CPU_KERNEL_T(
Gather, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt16),
GatherV2CPUKernel, uint16_t);
MS_REG_CPU_KERNEL_T(
Gather, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt32),
GatherV2CPUKernel, uint32_t);
MS_REG_CPU_KERNEL_T(
Gather, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt64),
GatherV2CPUKernel, uint64_t);
MS_REG_CPU_KERNEL_T(Gather, KernelAttr(), GatherV2CPUKernel, uint8_t);
MS_REG_CPU_KERNEL_T(Gather, KernelAttr(), GatherV2CPUKernel, uint16_t);
MS_REG_CPU_KERNEL_T(Gather, KernelAttr(), GatherV2CPUKernel, uint32_t);
MS_REG_CPU_KERNEL_T(Gather, KernelAttr(), GatherV2CPUKernel, uint64_t);
MS_REG_CPU_KERNEL_T(Gather, KernelAttr(), GatherV2CPUKernel, int8_t);
MS_REG_CPU_KERNEL_T(Gather, KernelAttr(), GatherV2CPUKernel, int16_t);
MS_REG_CPU_KERNEL_T(Gather, KernelAttr(), GatherV2CPUKernel, int32_t);
MS_REG_CPU_KERNEL_T(Gather, KernelAttr(), GatherV2CPUKernel, int64_t);
MS_REG_CPU_KERNEL_T(Gather, KernelAttr(), GatherV2CPUKernel, float);
MS_REG_CPU_KERNEL_T(Gather, KernelAttr(), GatherV2CPUKernel, double);
MS_REG_CPU_KERNEL_T(Gather, KernelAttr(), GatherV2CPUKernel, bool);
} // namespace kernel
} // namespace mindspore

View File

@ -39,76 +39,16 @@ class GatherDCPUKernel : public CPUKernel {
std::vector<size_t> output_shape_;
};
MS_REG_CPU_KERNEL_T_S(GatherD,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeFloat32),
GatherDCPUKernel, float, int32_t);
MS_REG_CPU_KERNEL_T_S(GatherD,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat32),
GatherDCPUKernel, float, int64_t);
MS_REG_CPU_KERNEL_T_S(GatherD,
KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeFloat16),
GatherDCPUKernel, float16, int32_t);
MS_REG_CPU_KERNEL_T_S(GatherD,
KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat16),
GatherDCPUKernel, float16, int64_t);
MS_REG_CPU_KERNEL_T_S(GatherD,
KernelAttr()
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt32),
GatherDCPUKernel, int32_t, int32_t);
MS_REG_CPU_KERNEL_T_S(GatherD,
KernelAttr()
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeInt32),
GatherDCPUKernel, int32_t, int64_t);
MS_REG_CPU_KERNEL_T_S(GatherD,
KernelAttr()
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt64),
GatherDCPUKernel, int64_t, int32_t);
MS_REG_CPU_KERNEL_T_S(GatherD,
KernelAttr()
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeInt64),
GatherDCPUKernel, int64_t, int64_t);
MS_REG_CPU_KERNEL_T_S(GatherD,
KernelAttr()
.AddInputAttr(kNumberTypeBool)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeBool),
GatherDCPUKernel, bool, int32_t);
MS_REG_CPU_KERNEL_T_S(GatherD,
KernelAttr()
.AddInputAttr(kNumberTypeBool)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeBool),
GatherDCPUKernel, bool, int64_t);
MS_REG_CPU_KERNEL_T_S(GatherD, KernelAttr(), GatherDCPUKernel, float, int32_t);
MS_REG_CPU_KERNEL_T_S(GatherD, KernelAttr(), GatherDCPUKernel, float, int64_t);
MS_REG_CPU_KERNEL_T_S(GatherD, KernelAttr(), GatherDCPUKernel, float16, int32_t);
MS_REG_CPU_KERNEL_T_S(GatherD, KernelAttr(), GatherDCPUKernel, float16, int64_t);
MS_REG_CPU_KERNEL_T_S(GatherD, KernelAttr(), GatherDCPUKernel, int32_t, int32_t);
MS_REG_CPU_KERNEL_T_S(GatherD, KernelAttr(), GatherDCPUKernel, int32_t, int64_t);
MS_REG_CPU_KERNEL_T_S(GatherD, KernelAttr(), GatherDCPUKernel, int64_t, int32_t);
MS_REG_CPU_KERNEL_T_S(GatherD, KernelAttr(), GatherDCPUKernel, int64_t, int64_t);
MS_REG_CPU_KERNEL_T_S(GatherD, KernelAttr(), GatherDCPUKernel, bool, int32_t);
MS_REG_CPU_KERNEL_T_S(GatherD, KernelAttr(), GatherDCPUKernel, bool, int64_t);
} // namespace kernel
} // namespace mindspore

View File

@ -40,44 +40,16 @@ class GatherDGradCPUKernel : public CPUKernel {
int32_t axis_;
};
MS_REG_CPU_KERNEL_T_S(
GatherDGrad,
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
GatherDGradCPUKernel, int32_t, int32_t);
MS_REG_CPU_KERNEL_T_S(
GatherDGrad,
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
GatherDGradCPUKernel, int32_t, int64_t);
MS_REG_CPU_KERNEL_T_S(
GatherDGrad,
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
GatherDGradCPUKernel, int32_t, float);
MS_REG_CPU_KERNEL_T_S(
GatherDGrad,
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
GatherDGradCPUKernel, int32_t, float16);
MS_REG_CPU_KERNEL_T_S(
GatherDGrad, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
GatherDGradCPUKernel, int32_t, bool);
MS_REG_CPU_KERNEL_T_S(
GatherDGrad,
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
GatherDGradCPUKernel, int64_t, int32_t);
MS_REG_CPU_KERNEL_T_S(
GatherDGrad,
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
GatherDGradCPUKernel, int64_t, int64_t);
MS_REG_CPU_KERNEL_T_S(
GatherDGrad,
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
GatherDGradCPUKernel, int64_t, float);
MS_REG_CPU_KERNEL_T_S(
GatherDGrad,
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
GatherDGradCPUKernel, int64_t, float16);
MS_REG_CPU_KERNEL_T_S(
GatherDGrad, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
GatherDGradCPUKernel, int64_t, bool);
MS_REG_CPU_KERNEL_T_S(GatherDGrad, KernelAttr(), GatherDGradCPUKernel, int32_t, int32_t);
MS_REG_CPU_KERNEL_T_S(GatherDGrad, KernelAttr(), GatherDGradCPUKernel, int32_t, int64_t);
MS_REG_CPU_KERNEL_T_S(GatherDGrad, KernelAttr(), GatherDGradCPUKernel, int32_t, float);
MS_REG_CPU_KERNEL_T_S(GatherDGrad, KernelAttr(), GatherDGradCPUKernel, int32_t, float16);
MS_REG_CPU_KERNEL_T_S(GatherDGrad, KernelAttr(), GatherDGradCPUKernel, int32_t, bool);
MS_REG_CPU_KERNEL_T_S(GatherDGrad, KernelAttr(), GatherDGradCPUKernel, int64_t, int32_t);
MS_REG_CPU_KERNEL_T_S(GatherDGrad, KernelAttr(), GatherDGradCPUKernel, int64_t, int64_t);
MS_REG_CPU_KERNEL_T_S(GatherDGrad, KernelAttr(), GatherDGradCPUKernel, int64_t, float);
MS_REG_CPU_KERNEL_T_S(GatherDGrad, KernelAttr(), GatherDGradCPUKernel, int64_t, float16);
MS_REG_CPU_KERNEL_T_S(GatherDGrad, KernelAttr(), GatherDGradCPUKernel, int64_t, bool);
} // namespace kernel
} // namespace mindspore

View File

@ -47,20 +47,7 @@ class GatherNdCPUKernel : public CPUKernel {
TypeId dtype_{kTypeUnknown};
};
MS_REG_CPU_KERNEL(
GatherNd, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
GatherNdCPUKernel);
MS_REG_CPU_KERNEL(
GatherNd, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64),
GatherNdCPUKernel);
MS_REG_CPU_KERNEL(
GatherNd,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
GatherNdCPUKernel);
MS_REG_CPU_KERNEL(
GatherNd,
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64),
GatherNdCPUKernel);
MS_REG_CPU_KERNEL(GatherNd, KernelAttr(), GatherNdCPUKernel);
} // namespace kernel
} // namespace mindspore

View File

@ -41,20 +41,15 @@ class HSigmoidCPUKernel : public CPUKernel {
uint64_t tensor_size_ = 1;
};
MS_REG_CPU_KERNEL_T(HSigmoid, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
HSigmoidCPUKernel, int8_t);
MS_REG_CPU_KERNEL_T(HSigmoid, KernelAttr(), HSigmoidCPUKernel, int8_t);
MS_REG_CPU_KERNEL_T(HSigmoid, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
HSigmoidCPUKernel, int16_t);
MS_REG_CPU_KERNEL_T(HSigmoid, KernelAttr(), HSigmoidCPUKernel, int16_t);
MS_REG_CPU_KERNEL_T(HSigmoid, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
HSigmoidCPUKernel, int);
MS_REG_CPU_KERNEL_T(HSigmoid, KernelAttr(), HSigmoidCPUKernel, int32_t);
MS_REG_CPU_KERNEL_T(HSigmoid, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
HSigmoidCPUKernel, int64_t);
MS_REG_CPU_KERNEL_T(HSigmoid, KernelAttr(), HSigmoidCPUKernel, int64_t);
MS_REG_CPU_KERNEL_T(HSigmoid, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
HSigmoidCPUKernel, float);
MS_REG_CPU_KERNEL_T(HSigmoid, KernelAttr(), HSigmoidCPUKernel, float);
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_TILE_CPU_KERNEL_H_

View File

@ -41,29 +41,11 @@ class HSigmoidGradCPUKernel : public CPUKernel {
uint64_t tensor_size_ = 1;
};
MS_REG_CPU_KERNEL_T(
HSigmoidGrad, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
HSigmoidGradCPUKernel, int8_t);
MS_REG_CPU_KERNEL_T(
HSigmoidGrad,
KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
HSigmoidGradCPUKernel, int16_t);
MS_REG_CPU_KERNEL_T(
HSigmoidGrad,
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
HSigmoidGradCPUKernel, int);
MS_REG_CPU_KERNEL_T(
HSigmoidGrad,
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
HSigmoidGradCPUKernel, int64_t);
MS_REG_CPU_KERNEL_T(
HSigmoidGrad,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
HSigmoidGradCPUKernel, float);
MS_REG_CPU_KERNEL_T(HSigmoidGrad, KernelAttr(), HSigmoidGradCPUKernel, int8_t);
MS_REG_CPU_KERNEL_T(HSigmoidGrad, KernelAttr(), HSigmoidGradCPUKernel, int16_t);
MS_REG_CPU_KERNEL_T(HSigmoidGrad, KernelAttr(), HSigmoidGradCPUKernel, int32_t);
MS_REG_CPU_KERNEL_T(HSigmoidGrad, KernelAttr(), HSigmoidGradCPUKernel, int64_t);
MS_REG_CPU_KERNEL_T(HSigmoidGrad, KernelAttr(), HSigmoidGradCPUKernel, float);
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_TILE_CPU_KERNEL_H_

View File

@ -41,20 +41,11 @@ class HSwishCPUKernel : public CPUKernel {
uint64_t tensor_size_ = 1;
};
MS_REG_CPU_KERNEL_T(HSwish, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8), HSwishCPUKernel,
int8_t);
MS_REG_CPU_KERNEL_T(HSwish, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
HSwishCPUKernel, int16_t);
MS_REG_CPU_KERNEL_T(HSwish, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
HSwishCPUKernel, int);
MS_REG_CPU_KERNEL_T(HSwish, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
HSwishCPUKernel, int64_t);
MS_REG_CPU_KERNEL_T(HSwish, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
HSwishCPUKernel, float);
MS_REG_CPU_KERNEL_T(HSwish, KernelAttr(), HSwishCPUKernel, int8_t);
MS_REG_CPU_KERNEL_T(HSwish, KernelAttr(), HSwishCPUKernel, int16_t);
MS_REG_CPU_KERNEL_T(HSwish, KernelAttr(), HSwishCPUKernel, int32_t);
MS_REG_CPU_KERNEL_T(HSwish, KernelAttr(), HSwishCPUKernel, int64_t);
MS_REG_CPU_KERNEL_T(HSwish, KernelAttr(), HSwishCPUKernel, float);
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_TILE_CPU_KERNEL_H_

View File

@ -41,29 +41,11 @@ class HSwishGradCPUKernel : public CPUKernel {
uint64_t tensor_size_ = 1;
};
MS_REG_CPU_KERNEL_T(
HSwishGrad, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
HSwishGradCPUKernel, int8_t);
MS_REG_CPU_KERNEL_T(
HSwishGrad,
KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
HSwishGradCPUKernel, int16_t);
MS_REG_CPU_KERNEL_T(
HSwishGrad,
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
HSwishGradCPUKernel, int);
MS_REG_CPU_KERNEL_T(
HSwishGrad,
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
HSwishGradCPUKernel, int64_t);
MS_REG_CPU_KERNEL_T(
HSwishGrad,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
HSwishGradCPUKernel, float);
MS_REG_CPU_KERNEL_T(HSwishGrad, KernelAttr(), HSwishGradCPUKernel, int8_t);
MS_REG_CPU_KERNEL_T(HSwishGrad, KernelAttr(), HSwishGradCPUKernel, int16_t);
MS_REG_CPU_KERNEL_T(HSwishGrad, KernelAttr(), HSwishGradCPUKernel, int32_t);
MS_REG_CPU_KERNEL_T(HSwishGrad, KernelAttr(), HSwishGradCPUKernel, int64_t);
MS_REG_CPU_KERNEL_T(HSwishGrad, KernelAttr(), HSwishGradCPUKernel, float);
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_TILE_CPU_KERNEL_H_

View File

@ -51,41 +51,7 @@ class IsFiniteCPUKernel : public CPUKernel {
TypeId input_dtype_{kTypeUnknown};
};
MS_REG_CPU_KERNEL(IsFinite, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
IsFiniteCPUKernel);
MS_REG_CPU_KERNEL(IsFinite, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeBool),
IsFiniteCPUKernel);
MS_REG_CPU_KERNEL(IsFinite, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeBool),
IsFiniteCPUKernel);
MS_REG_CPU_KERNEL(IsFinite, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
IsFiniteCPUKernel);
MS_REG_CPU_KERNEL(IsFinite, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool),
IsFiniteCPUKernel);
MS_REG_CPU_KERNEL(IsFinite, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool),
IsFiniteCPUKernel);
MS_REG_CPU_KERNEL(IsFinite, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool),
IsFiniteCPUKernel);
MS_REG_CPU_KERNEL(IsFinite, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeBool),
IsFiniteCPUKernel);
MS_REG_CPU_KERNEL(IsFinite, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeBool),
IsFiniteCPUKernel);
MS_REG_CPU_KERNEL(IsFinite, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeBool),
IsFiniteCPUKernel);
MS_REG_CPU_KERNEL(IsFinite, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeBool),
IsFiniteCPUKernel);
MS_REG_CPU_KERNEL(IsFinite, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeBool),
IsFiniteCPUKernel);
MS_REG_CPU_KERNEL(IsFinite, KernelAttr(), IsFiniteCPUKernel);
} // namespace kernel
} // namespace mindspore

View File

@ -46,25 +46,7 @@ class LayerNormCPUKernel : public CPUKernel {
size_t param_num_{1};
};
MS_REG_CPU_KERNEL(LayerNorm,
KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16),
LayerNormCPUKernel);
MS_REG_CPU_KERNEL(LayerNorm,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
LayerNormCPUKernel);
MS_REG_CPU_KERNEL(LayerNorm, KernelAttr(), LayerNormCPUKernel);
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_LAYER_NORM_CPU_KERNEL_H_

View File

@ -48,29 +48,7 @@ class LayerNormGradCPUKernel : public CPUKernel {
size_t param_size_{1};
};
MS_REG_CPU_KERNEL(LayerNormGrad,
KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16),
LayerNormGradCPUKernel);
MS_REG_CPU_KERNEL(LayerNormGrad,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
LayerNormGradCPUKernel);
MS_REG_CPU_KERNEL(LayerNormGrad, KernelAttr(), LayerNormGradCPUKernel);
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_LAYER_NORM_GRAD_CPU_KERNEL_H_

View File

@ -75,33 +75,12 @@ class MaximumCPUKernel : public CPUKernel {
const size_t max_dims{7};
};
MS_REG_CPU_KERNEL_T(
Maximum, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
MaximumCPUKernel, int32_t);
MS_REG_CPU_KERNEL_T(
Maximum,
KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
MaximumCPUKernel, uint32_t);
MS_REG_CPU_KERNEL_T(
Maximum,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
MaximumCPUKernel, float);
MS_REG_CPU_KERNEL_T(
Maximum, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
MaximumCPUKernel, int64_t);
MS_REG_CPU_KERNEL_T(
Maximum,
KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
MaximumCPUKernel, uint64_t);
MS_REG_CPU_KERNEL_T(
Maximum,
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
MaximumCPUKernel, double);
MS_REG_CPU_KERNEL_T(Maximum, KernelAttr(), MaximumCPUKernel, int32_t);
MS_REG_CPU_KERNEL_T(Maximum, KernelAttr(), MaximumCPUKernel, int64_t);
MS_REG_CPU_KERNEL_T(Maximum, KernelAttr(), MaximumCPUKernel, uint32_t);
MS_REG_CPU_KERNEL_T(Maximum, KernelAttr(), MaximumCPUKernel, uint64_t);
MS_REG_CPU_KERNEL_T(Maximum, KernelAttr(), MaximumCPUKernel, float);
MS_REG_CPU_KERNEL_T(Maximum, KernelAttr(), MaximumCPUKernel, double);
} // namespace kernel
} // namespace mindspore

View File

@ -47,59 +47,7 @@ class MaximumGradCPUKernel : public CPUKernel {
TypeId dtype_{kTypeUnknown};
};
MS_REG_CPU_KERNEL(MaximumGrad,
KernelAttr()
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt32),
MaximumGradCPUKernel);
MS_REG_CPU_KERNEL(MaximumGrad,
KernelAttr()
.AddInputAttr(kNumberTypeUInt32)
.AddInputAttr(kNumberTypeUInt32)
.AddInputAttr(kNumberTypeUInt32)
.AddOutputAttr(kNumberTypeUInt32)
.AddOutputAttr(kNumberTypeUInt32),
MaximumGradCPUKernel);
MS_REG_CPU_KERNEL(MaximumGrad,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
MaximumGradCPUKernel);
MS_REG_CPU_KERNEL(MaximumGrad,
KernelAttr()
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeInt64),
MaximumGradCPUKernel);
MS_REG_CPU_KERNEL(MaximumGrad,
KernelAttr()
.AddInputAttr(kNumberTypeUInt64)
.AddInputAttr(kNumberTypeUInt64)
.AddInputAttr(kNumberTypeUInt64)
.AddOutputAttr(kNumberTypeUInt64)
.AddOutputAttr(kNumberTypeUInt64),
MaximumGradCPUKernel);
MS_REG_CPU_KERNEL(MaximumGrad,
KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeFloat64),
MaximumGradCPUKernel);
MS_REG_CPU_KERNEL(MaximumGrad, KernelAttr(), MaximumGradCPUKernel);
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MaximumGrad_CPU_KERNEL_H_

View File

@ -75,33 +75,12 @@ class MinimumCPUKernel : public CPUKernel {
const size_t max_dims{7};
};
MS_REG_CPU_KERNEL_T(
Minimum, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
MinimumCPUKernel, int32_t);
MS_REG_CPU_KERNEL_T(
Minimum,
KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
MinimumCPUKernel, uint32_t);
MS_REG_CPU_KERNEL_T(
Minimum,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
MinimumCPUKernel, float);
MS_REG_CPU_KERNEL_T(
Minimum, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
MinimumCPUKernel, int64_t);
MS_REG_CPU_KERNEL_T(
Minimum,
KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
MinimumCPUKernel, uint64_t);
MS_REG_CPU_KERNEL_T(
Minimum,
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
MinimumCPUKernel, double);
MS_REG_CPU_KERNEL_T(Minimum, KernelAttr(), MinimumCPUKernel, int32_t);
MS_REG_CPU_KERNEL_T(Minimum, KernelAttr(), MinimumCPUKernel, uint32_t);
MS_REG_CPU_KERNEL_T(Minimum, KernelAttr(), MinimumCPUKernel, float);
MS_REG_CPU_KERNEL_T(Minimum, KernelAttr(), MinimumCPUKernel, int64_t);
MS_REG_CPU_KERNEL_T(Minimum, KernelAttr(), MinimumCPUKernel, uint64_t);
MS_REG_CPU_KERNEL_T(Minimum, KernelAttr(), MinimumCPUKernel, double);
} // namespace kernel
} // namespace mindspore

View File

@ -47,59 +47,7 @@ class MinimumGradCPUKernel : public CPUKernel {
TypeId dtype_{kTypeUnknown};
};
MS_REG_CPU_KERNEL(MinimumGrad,
KernelAttr()
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt32),
MinimumGradCPUKernel);
MS_REG_CPU_KERNEL(MinimumGrad,
KernelAttr()
.AddInputAttr(kNumberTypeUInt32)
.AddInputAttr(kNumberTypeUInt32)
.AddInputAttr(kNumberTypeUInt32)
.AddOutputAttr(kNumberTypeUInt32)
.AddOutputAttr(kNumberTypeUInt32),
MinimumGradCPUKernel);
MS_REG_CPU_KERNEL(MinimumGrad,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
MinimumGradCPUKernel);
MS_REG_CPU_KERNEL(MinimumGrad,
KernelAttr()
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeInt64),
MinimumGradCPUKernel);
MS_REG_CPU_KERNEL(MinimumGrad,
KernelAttr()
.AddInputAttr(kNumberTypeUInt64)
.AddInputAttr(kNumberTypeUInt64)
.AddInputAttr(kNumberTypeUInt64)
.AddOutputAttr(kNumberTypeUInt64)
.AddOutputAttr(kNumberTypeUInt64),
MinimumGradCPUKernel);
MS_REG_CPU_KERNEL(MinimumGrad,
KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeFloat64),
MinimumGradCPUKernel);
MS_REG_CPU_KERNEL(MinimumGrad, KernelAttr(), MinimumGradCPUKernel);
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MinimumGrad_CPU_KERNEL_H_

View File

@ -64,19 +64,7 @@ class MirrorPadCPUKernel : public CPUKernel {
int num_paddings_;
};
MS_REG_CPU_KERNEL(
MirrorPad,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16),
MirrorPadCPUKernel);
MS_REG_CPU_KERNEL(
MirrorPad,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32),
MirrorPadCPUKernel);
MS_REG_CPU_KERNEL(
MirrorPad, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32),
MirrorPadCPUKernel);
MS_REG_CPU_KERNEL(MirrorPad, KernelAttr(), MirrorPadCPUKernel);
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MIRROR_PAD_CPU_KERNEL_H_

View File

@ -80,20 +80,7 @@ class MirrorPadGradCPUKernel : public CPUKernel {
int num_paddings_;
};
MS_REG_CPU_KERNEL(
MirrorPadGrad,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16),
MirrorPadGradCPUKernel);
MS_REG_CPU_KERNEL(
MirrorPadGrad,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32),
MirrorPadGradCPUKernel);
MS_REG_CPU_KERNEL(
MirrorPadGrad,
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32),
MirrorPadGradCPUKernel);
MS_REG_CPU_KERNEL(MirrorPadGrad, KernelAttr(), MirrorPadGradCPUKernel);
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MIRROR_PAD_CPU_KERNEL_H_

View File

@ -33,10 +33,7 @@ class ConvCPUKernel : public MKLCPUKernel {
const std::vector<AddressPtr> &outputs) override;
};
MS_REG_CPU_KERNEL(
Conv2D,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ConvCPUKernel);
MS_REG_CPU_KERNEL(Conv2D, KernelAttr(), ConvCPUKernel);
MS_REG_CPU_KERNEL(
Conv3D,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),

View File

@ -18,6 +18,7 @@
#include <vector>
#include <memory>
#include "backend/kernel_compiler/cpu/arithmetic_cpu_kernel.h"
#include "backend/kernel_compiler/cpu/mkldnn/mkl_cpu_kernel.h"
namespace mindspore {
@ -36,9 +37,8 @@ class MulCPUKernel : public MKLCPUKernel {
bool need_swap_{false};
};
MS_REG_CPU_KERNEL(
Mul, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
MulCPUKernel);
MS_REG_CPU_KERNEL(Mul, KernelAttr(), MulCPUKernel);
MS_REG_CPU_KERNEL_T(Mul, KernelAttr(), ArithmeticCPUKernel, int32_t);
} // namespace kernel
} // namespace mindspore

View File

@ -38,13 +38,7 @@ class OneHotCPUKernel : public CPUKernel {
size_t axis_;
};
MS_REG_CPU_KERNEL(OneHot,
KernelAttr()
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
OneHotCPUKernel);
MS_REG_CPU_KERNEL(OneHot, KernelAttr(), OneHotCPUKernel);
} // namespace kernel
} // namespace mindspore

View File

@ -44,39 +44,17 @@ class PackCpuFwdKernel : public CPUKernel {
std::unique_ptr<T *[]> inputs_host_;
};
MS_REG_CPU_KERNEL_T(Stack,
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
PackCpuFwdKernel, int8_t)
MS_REG_CPU_KERNEL_T(Stack,
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
PackCpuFwdKernel, int16_t)
MS_REG_CPU_KERNEL_T(Stack,
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
PackCpuFwdKernel, int32_t)
MS_REG_CPU_KERNEL_T(Stack,
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
PackCpuFwdKernel, int64_t)
MS_REG_CPU_KERNEL_T(Stack,
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
PackCpuFwdKernel, uint8_t)
MS_REG_CPU_KERNEL_T(Stack,
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
PackCpuFwdKernel, bool)
MS_REG_CPU_KERNEL_T(Stack,
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16),
PackCpuFwdKernel, uint16_t)
MS_REG_CPU_KERNEL_T(Stack,
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
PackCpuFwdKernel, uint32_t)
MS_REG_CPU_KERNEL_T(Stack,
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
PackCpuFwdKernel, uint64_t)
MS_REG_CPU_KERNEL_T(
Stack, KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
PackCpuFwdKernel, float16)
MS_REG_CPU_KERNEL_T(
Stack, KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
PackCpuFwdKernel, float)
MS_REG_CPU_KERNEL_T(Stack, KernelAttr(), PackCpuFwdKernel, int8_t)
MS_REG_CPU_KERNEL_T(Stack, KernelAttr(), PackCpuFwdKernel, int16_t)
MS_REG_CPU_KERNEL_T(Stack, KernelAttr(), PackCpuFwdKernel, int32_t)
MS_REG_CPU_KERNEL_T(Stack, KernelAttr(), PackCpuFwdKernel, int64_t)
MS_REG_CPU_KERNEL_T(Stack, KernelAttr(), PackCpuFwdKernel, uint8_t)
MS_REG_CPU_KERNEL_T(Stack, KernelAttr(), PackCpuFwdKernel, uint16_t)
MS_REG_CPU_KERNEL_T(Stack, KernelAttr(), PackCpuFwdKernel, uint32_t)
MS_REG_CPU_KERNEL_T(Stack, KernelAttr(), PackCpuFwdKernel, uint64_t)
MS_REG_CPU_KERNEL_T(Stack, KernelAttr(), PackCpuFwdKernel, float16)
MS_REG_CPU_KERNEL_T(Stack, KernelAttr(), PackCpuFwdKernel, float)
MS_REG_CPU_KERNEL_T(Stack, KernelAttr(), PackCpuFwdKernel, bool)
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_PACK_CPU_KERNEL_H

View File

@ -48,11 +48,7 @@ class PadCPUKernel : public CPUKernel {
std::vector<size_t> output_shape_;
};
MS_REG_CPU_KERNEL(Pad, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), PadCPUKernel);
MS_REG_CPU_KERNEL(Pad, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), PadCPUKernel);
MS_REG_CPU_KERNEL(Pad, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), PadCPUKernel);
MS_REG_CPU_KERNEL(Pad, KernelAttr(), PadCPUKernel);
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_PAD_CPU_KERNEL_H_

View File

@ -41,12 +41,7 @@ class RangeCPUKernel : public CPUKernel {
int64_t delta_;
};
MS_REG_CPU_KERNEL(Range, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), RangeCPUKernel);
MS_REG_CPU_KERNEL(Range, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), RangeCPUKernel);
MS_REG_CPU_KERNEL(Range, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
RangeCPUKernel);
MS_REG_CPU_KERNEL(Range, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
RangeCPUKernel);
MS_REG_CPU_KERNEL(Range, KernelAttr(), RangeCPUKernel);
} // namespace kernel
} // namespace mindspore

View File

@ -41,47 +41,29 @@ class ReduceCPUKernel : public CPUKernel {
std::function<void(const T *, size_t, T *)> reduce_func_;
};
MS_REG_CPU_KERNEL_T(ReduceMean, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ReduceCPUKernel, float);
MS_REG_CPU_KERNEL_T(ReduceMean, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
ReduceCPUKernel, double);
MS_REG_CPU_KERNEL_T(ReduceMean, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
ReduceCPUKernel, int32_t);
MS_REG_CPU_KERNEL_T(ReduceMean, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
ReduceCPUKernel, int64_t);
MS_REG_CPU_KERNEL_T(ReduceMean, KernelAttr(), ReduceCPUKernel, float);
MS_REG_CPU_KERNEL_T(ReduceMean, KernelAttr(), ReduceCPUKernel, double);
MS_REG_CPU_KERNEL_T(ReduceMean, KernelAttr(), ReduceCPUKernel, int32_t);
MS_REG_CPU_KERNEL_T(ReduceMean, KernelAttr(), ReduceCPUKernel, int64_t);
MS_REG_CPU_KERNEL_T(ReduceMax, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ReduceCPUKernel, float);
MS_REG_CPU_KERNEL_T(ReduceMax, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
ReduceCPUKernel, double);
MS_REG_CPU_KERNEL_T(ReduceMax, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
ReduceCPUKernel, int32_t);
MS_REG_CPU_KERNEL_T(ReduceMax, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
ReduceCPUKernel, int64_t);
MS_REG_CPU_KERNEL_T(ReduceMax, KernelAttr(), ReduceCPUKernel, float);
MS_REG_CPU_KERNEL_T(ReduceMax, KernelAttr(), ReduceCPUKernel, double);
MS_REG_CPU_KERNEL_T(ReduceMax, KernelAttr(), ReduceCPUKernel, int32_t);
MS_REG_CPU_KERNEL_T(ReduceMax, KernelAttr(), ReduceCPUKernel, int64_t);
MS_REG_CPU_KERNEL_T(ReduceSum, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ReduceCPUKernel, float);
MS_REG_CPU_KERNEL_T(ReduceSum, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
ReduceCPUKernel, double);
MS_REG_CPU_KERNEL_T(ReduceSum, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
ReduceCPUKernel, int32_t);
MS_REG_CPU_KERNEL_T(ReduceSum, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
ReduceCPUKernel, int64_t);
MS_REG_CPU_KERNEL_T(ReduceSum, KernelAttr(), ReduceCPUKernel, float);
MS_REG_CPU_KERNEL_T(ReduceSum, KernelAttr(), ReduceCPUKernel, double);
MS_REG_CPU_KERNEL_T(ReduceSum, KernelAttr(), ReduceCPUKernel, int32_t);
MS_REG_CPU_KERNEL_T(ReduceSum, KernelAttr(), ReduceCPUKernel, int64_t);
MS_REG_CPU_KERNEL_T(ReduceMin, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ReduceCPUKernel, float);
MS_REG_CPU_KERNEL_T(ReduceMin, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
ReduceCPUKernel, double);
MS_REG_CPU_KERNEL_T(ReduceMin, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
ReduceCPUKernel, int32_t);
MS_REG_CPU_KERNEL_T(ReduceMin, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
ReduceCPUKernel, int64_t);
MS_REG_CPU_KERNEL_T(ReduceMin, KernelAttr(), ReduceCPUKernel, float);
MS_REG_CPU_KERNEL_T(ReduceMin, KernelAttr(), ReduceCPUKernel, double);
MS_REG_CPU_KERNEL_T(ReduceMin, KernelAttr(), ReduceCPUKernel, int32_t);
MS_REG_CPU_KERNEL_T(ReduceMin, KernelAttr(), ReduceCPUKernel, int64_t);
MS_REG_CPU_KERNEL_T(ReduceAll, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
ReduceCPUKernel, bool);
MS_REG_CPU_KERNEL_T(ReduceAll, KernelAttr(), ReduceCPUKernel, bool);
MS_REG_CPU_KERNEL_T(ReduceAny, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
ReduceCPUKernel, bool);
MS_REG_CPU_KERNEL_T(ReduceAny, KernelAttr(), ReduceCPUKernel, bool);
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_REDUCE_CPU_KERNEL_H_

View File

@ -57,24 +57,12 @@ class SplitCPUKernel : public CPUKernel {
TypeId dtype_{kTypeUnknown};
};
MS_REG_CPU_KERNEL_T(
Split, KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
SplitCPUKernel, float);
MS_REG_CPU_KERNEL_T(
Split, KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
SplitCPUKernel, float16);
MS_REG_CPU_KERNEL_T(
Split, KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
SplitCPUKernel, double);
MS_REG_CPU_KERNEL_T(Split,
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
SplitCPUKernel, int32_t);
MS_REG_CPU_KERNEL_T(Split,
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
SplitCPUKernel, uint32_t);
MS_REG_CPU_KERNEL_T(Split,
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
SplitCPUKernel, int64_t);
MS_REG_CPU_KERNEL_T(Split, KernelAttr(), SplitCPUKernel, float);
MS_REG_CPU_KERNEL_T(Split, KernelAttr(), SplitCPUKernel, float16);
MS_REG_CPU_KERNEL_T(Split, KernelAttr(), SplitCPUKernel, double);
MS_REG_CPU_KERNEL_T(Split, KernelAttr(), SplitCPUKernel, int32_t);
MS_REG_CPU_KERNEL_T(Split, KernelAttr(), SplitCPUKernel, uint32_t);
MS_REG_CPU_KERNEL_T(Split, KernelAttr(), SplitCPUKernel, int64_t);
} // namespace kernel
} // namespace mindspore

View File

@ -40,15 +40,9 @@ class TensorAddCPUKernel : public CPUKernel {
std::vector<size_t> output_shape_;
};
MS_REG_CPU_KERNEL_T(
Add, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
TensorAddCPUKernel, float);
MS_REG_CPU_KERNEL_T(
Add, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
TensorAddCPUKernel, int);
MS_REG_CPU_KERNEL_T(
Add, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
TensorAddCPUKernel, uint32_t);
MS_REG_CPU_KERNEL_T(Add, KernelAttr(), TensorAddCPUKernel, int32_t);
MS_REG_CPU_KERNEL_T(Add, KernelAttr(), TensorAddCPUKernel, uint32_t);
MS_REG_CPU_KERNEL_T(Add, KernelAttr(), TensorAddCPUKernel, float);
} // namespace kernel
} // namespace mindspore

View File

@ -57,25 +57,7 @@ class TileCPUKernel : public CPUKernel {
size_t input_size_;
};
MS_REG_CPU_KERNEL(Tile, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8), TileCPUKernel);
MS_REG_CPU_KERNEL(Tile, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16), TileCPUKernel);
MS_REG_CPU_KERNEL(Tile, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), TileCPUKernel);
MS_REG_CPU_KERNEL(Tile, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), TileCPUKernel);
MS_REG_CPU_KERNEL(Tile, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), TileCPUKernel);
MS_REG_CPU_KERNEL(Tile, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8), TileCPUKernel);
MS_REG_CPU_KERNEL(Tile, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16), TileCPUKernel);
MS_REG_CPU_KERNEL(Tile, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32), TileCPUKernel);
MS_REG_CPU_KERNEL(Tile, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64), TileCPUKernel);
MS_REG_CPU_KERNEL(Tile, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), TileCPUKernel);
MS_REG_CPU_KERNEL(Tile, KernelAttr(), TileCPUKernel);
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_TILE_CPU_KERNEL_H_

View File

@ -40,20 +40,7 @@ class TopKCPUKernel : public CPUKernel {
TypeId dtype_{kTypeUnknown};
};
MS_REG_CPU_KERNEL(TopK,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeInt32),
TopKCPUKernel)
MS_REG_CPU_KERNEL(TopK,
KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeInt32),
TopKCPUKernel)
MS_REG_CPU_KERNEL(TopK, KernelAttr(), TopKCPUKernel)
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_TOPK_CPU_KERNEL_H_

View File

@ -52,36 +52,7 @@ class TransposeCPUFwdKernel : public CPUKernel {
std::unordered_map<TypeId, TypeKernel> launch_map_;
TypeKernel launch_func_;
};
MS_REG_CPU_KERNEL(Transpose,
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
TransposeCPUFwdKernel);
MS_REG_CPU_KERNEL(Transpose,
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
TransposeCPUFwdKernel);
MS_REG_CPU_KERNEL(Transpose,
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
TransposeCPUFwdKernel);
MS_REG_CPU_KERNEL(Transpose,
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
TransposeCPUFwdKernel);
MS_REG_CPU_KERNEL(Transpose,
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
TransposeCPUFwdKernel);
MS_REG_CPU_KERNEL(Transpose,
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
TransposeCPUFwdKernel);
MS_REG_CPU_KERNEL(Transpose,
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16),
TransposeCPUFwdKernel);
MS_REG_CPU_KERNEL(Transpose,
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
TransposeCPUFwdKernel);
MS_REG_CPU_KERNEL(Transpose,
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
TransposeCPUFwdKernel);
MS_REG_CPU_KERNEL(Transpose,
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
TransposeCPUFwdKernel);
MS_REG_CPU_KERNEL(Transpose, KernelAttr(), TransposeCPUFwdKernel);
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_TRANSPOSE_CPU_KERNEL_H_

View File

@ -35,7 +35,8 @@ enum KernelType : int {
RT_KERNEL,
HCCL_KERNEL,
TBE_KERNEL,
HOST_KERNEL
HOST_KERNEL,
CPU_KERNEL,
};
namespace kernel {

View File

@ -25,7 +25,7 @@
namespace mindspore {
namespace kernel {
enum OpImplyType { kAKG = 0, kTBE = 1, kAICPU };
enum OpImplyType { kAKG = 0, kTBE = 1, kAICPU = 2, kCPU };
enum OpIOType { kInput = 0, kOutput };
class OpAttr {

View File

@ -50,6 +50,7 @@ constexpr auto kAiCore = "AiCore";
constexpr auto kCUDA = "CUDA";
constexpr auto kTbe = "TBE";
constexpr auto kAkg = "AKG";
constexpr auto kCpu = "CPU";
constexpr auto kName = "name";
constexpr auto kParamType = "param_type";
constexpr auto kDtype = "dtype";
@ -90,6 +91,9 @@ bool OpLib::RegOp(const std::string &json_string, const std::string &impl_path)
} else if (imply_type_string == kAiCPU) {
OpImplyType imply_type = kAICPU;
ret = DecodeOpInfo(op_json, imply_type, impl_path);
} else if (imply_type_string == kCpu) {
OpImplyType imply_type = kCPU;
ret = DecodeOpInfo(op_json, imply_type, impl_path);
} else {
MS_LOG(ERROR) << "Not support imply_type";
}

View File

@ -18,7 +18,11 @@
#include <string>
#include <memory>
#include <algorithm>
#include "backend/kernel_compiler/common_utils.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
#include "backend/kernel_compiler/kernel_build_info.h"
#include "backend/kernel_compiler/oplib/opinfo.h"
#include "backend/kernel_compiler/oplib/oplib.h"
#include "utils/trace_base.h"
namespace mindspore {
namespace device {
@ -35,19 +39,17 @@ bool IsInputNotCNode(const CNodePtr &kernel_node, size_t input_index) {
return false;
}
void GetOutputInferFormatsAndDtypes(const CNodePtr &kernel_node, std::vector<std::string> *output_formats,
std::vector<TypeId> *output_types) {
void GetOutputDtypes(const CNodePtr &kernel_node, std::vector<TypeId> *output_types) {
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
for (size_t output_index = 0; output_index < output_num; ++output_index) {
TypeId dtype = kTypeUnknown;
dtype = AnfAlgo::GetOutputInferDataType(kernel_node, output_index);
output_formats->emplace_back(kOpFormat_DEFAULT);
output_types->emplace_back(dtype);
}
}
void GetInputFormatsAndDtypes(const CNodePtr &kernel_node, std::vector<std::string> *input_formats,
std::vector<TypeId> *input_types, std::vector<size_t> *input_no_cnode_indexes) {
void GetInputDtypes(const CNodePtr &kernel_node, std::vector<TypeId> *input_types,
std::vector<size_t> *input_no_cnode_indexes) {
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
for (size_t input_index = 0; input_index < input_num; ++input_index) {
TypeId dtype = kTypeUnknown;
@ -57,7 +59,6 @@ void GetInputFormatsAndDtypes(const CNodePtr &kernel_node, std::vector<std::stri
} else {
dtype = AnfAlgo::GetPrevNodeOutputDeviceDataType(kernel_node, input_index);
}
input_formats->emplace_back(kOpFormat_DEFAULT);
input_types->emplace_back(dtype);
}
}
@ -86,16 +87,13 @@ bool InputDtypeMatch(TypeId InputAttr, TypeId input_type, bool strict) {
return false;
}
std::pair<int, int> GetOutputDtypeFormatMatchedNum(const KernelAttr &kernel_attr,
const std::vector<std::string> &output_formats,
const std::vector<TypeId> &output_types) {
int GetOutputDtypeMatchedNum(const KernelAttr &kernel_attr, const std::vector<TypeId> &output_types) {
if (kernel_attr.GetOutputSize() != output_types.size()) {
MS_LOG(DEBUG) << "required output num:" << kernel_attr.GetInputSize()
<< ", actual output num:" << output_types.size();
return std::make_pair(0, 0);
return 0;
}
int data_type_matched_num = 0;
int format_matched_num = 0;
auto output_num = output_types.size();
for (size_t i = 0; i < output_num; ++i) {
if (kernel_attr.GetOutputAttr(i).first != output_types[i]) {
@ -104,27 +102,17 @@ std::pair<int, int> GetOutputDtypeFormatMatchedNum(const KernelAttr &kernel_attr
} else {
data_type_matched_num++;
}
if (kernel_attr.GetOutputAttr(i).second != output_formats[i]) {
MS_LOG(DEBUG) << "required format:" << kernel_attr.GetOutputAttr(i).second
<< ", actual output format:" << output_formats[i];
} else {
format_matched_num++;
}
}
return std::make_pair(data_type_matched_num, format_matched_num);
return data_type_matched_num;
}
std::pair<int, int> GetInputDtypeFormatMatchedNum(const KernelAttr &kernel_attr,
const std::vector<std::string> &input_formats,
const std::vector<TypeId> &input_types,
const std::vector<size_t> &input_not_cnode_indexes, bool strict) {
int GetInputDtypeFormatMatchedNum(const KernelAttr &kernel_attr, const std::vector<TypeId> &input_types,
const std::vector<size_t> &input_not_cnode_indexes, bool strict) {
if (kernel_attr.GetInputSize() != input_types.size()) {
MS_LOG(DEBUG) << "required input num:" << kernel_attr.GetInputSize() << ", actual input num:" << input_types.size();
return std::make_pair(0, 0);
return 0;
}
int data_type_matched_num = 0;
int format_matched_num = 0;
auto input_num = input_types.size();
for (size_t i = 0; i < input_num; ++i) {
if (!InputDtypeMatch(kernel_attr.GetInputAttr(i).first, input_types[i], strict)) {
@ -132,10 +120,9 @@ std::pair<int, int> GetInputDtypeFormatMatchedNum(const KernelAttr &kernel_attr,
<< ", actual input dtype:" << input_types[i];
} else {
data_type_matched_num++;
format_matched_num++;
}
}
return std::make_pair(data_type_matched_num, format_matched_num);
return data_type_matched_num;
}
void ExpandKernelAttr(const CNodePtr &kernel_node, KernelAttr *kernel_attr) {
@ -197,12 +184,9 @@ void KernelNotSupportException(const AnfNodePtr &kernel_node, const std::vector<
}
} // namespace
bool SelectKernel(const CNodePtr &kernel_node, KernelAttr *selected_kernel_attr,
const std::vector<KernelAttr> &kernel_attrs, const std::vector<std::string> &input_formats,
const std::vector<TypeId> &input_types, const std::vector<size_t> &input_not_cnode_indexes,
const std::vector<std::string> &infer_output_formats, const std::vector<TypeId> &infer_output_types,
const std::vector<KernelAttr> &kernel_attrs, const std::vector<TypeId> &input_types,
const std::vector<size_t> &input_not_cnode_indexes, const std::vector<TypeId> &output_types,
std::pair<bool, bool> *matched, bool strict) {
int max_type_matched_num = -1;
int max_format_matched_num = -1;
for (auto kernel_attr : kernel_attrs) {
if (kernel_attr.GetAllSame()) {
ExpandKernelAttr(kernel_node, &kernel_attr);
@ -212,32 +196,14 @@ bool SelectKernel(const CNodePtr &kernel_node, KernelAttr *selected_kernel_attr,
MS_LOG(DEBUG) << "Output num is not equal!";
continue;
}
std::pair<int, int> input_type_format_matched_num =
GetInputDtypeFormatMatchedNum(kernel_attr, input_formats, input_types, input_not_cnode_indexes, strict);
std::pair<int, int> output_type_format_matched_num =
GetOutputDtypeFormatMatchedNum(kernel_attr, infer_output_formats, infer_output_types);
// Data type first
if (input_type_format_matched_num.first > max_type_matched_num) {
max_type_matched_num = input_type_format_matched_num.first;
max_format_matched_num = input_type_format_matched_num.second;
*selected_kernel_attr = kernel_attr;
} else if (input_type_format_matched_num.first == max_type_matched_num &&
input_type_format_matched_num.second > max_format_matched_num) {
max_format_matched_num = input_type_format_matched_num.second;
*selected_kernel_attr = kernel_attr;
} else if (input_type_format_matched_num.first == max_type_matched_num &&
input_type_format_matched_num.second == max_format_matched_num) {
if (output_type_format_matched_num.first == SizeToInt(infer_output_types.size()) &&
output_type_format_matched_num.second == SizeToInt(infer_output_types.size())) {
*selected_kernel_attr = kernel_attr;
}
}
int input_dtype_matched_num =
GetInputDtypeFormatMatchedNum(kernel_attr, input_types, input_not_cnode_indexes, strict);
int output_dtype_matched_num = GetOutputDtypeMatchedNum(kernel_attr, output_types);
// All formats and data types matched
if (input_type_format_matched_num.first == SizeToInt(input_types.size()) &&
input_type_format_matched_num.second == SizeToInt(input_types.size())) {
if (input_dtype_matched_num == SizeToInt(input_types.size())) {
*selected_kernel_attr = kernel_attr;
matched->first = true;
if (output_type_format_matched_num.first == SizeToInt(infer_output_types.size()) &&
output_type_format_matched_num.second == SizeToInt(infer_output_types.size())) {
if (output_dtype_matched_num == SizeToInt(output_types.size())) {
matched->second = true;
return true;
}
@ -249,43 +215,50 @@ void SetKernelInfo(const CNodePtr &kernel_node) {
std::vector<std::string> input_formats;
std::vector<TypeId> input_types;
std::vector<size_t> input_not_cnode_indexes;
std::vector<std::string> output_formats;
std::vector<std::string> selected_output_formats;
std::vector<TypeId> output_types;
std::vector<std::string> infer_output_formats;
std::vector<TypeId> infer_output_types;
std::vector<TypeId> selected_output_types;
MS_LOG(INFO) << "SetKernelInfo, CNode Name: " << AnfAlgo::GetCNodeName(kernel_node);
auto kernel_attrs =
kernel::CPUKernelFactory::GetInstance().GetSupportedKernelAttrList(AnfAlgo::GetCNodeName(kernel_node));
if (kernel_attrs.empty()) {
MS_LOG(EXCEPTION) << "Operator[" << AnfAlgo::GetCNodeName(kernel_node)
<< "] is not support. Trace: " << trace::DumpSourceLines(kernel_node);
if (kernel_attrs.empty() || (kernel_attrs[0].GetInputSize() == 0 && kernel_attrs[0].GetOutputSize() == 0)) {
MS_LOG(DEBUG) << "Operator[" << AnfAlgo::GetCNodeName(kernel_node) << "] will get ops attr info.";
std::string op_name = AnfAlgo::GetCNodeName(kernel_node);
auto op_info_ptr = mindspore::kernel::OpLib::FindOp(op_name, kernel::OpImplyType::kCPU);
if (op_info_ptr == nullptr) {
MS_LOG(ERROR) << "Not find op[" << op_name << "] in cpu";
}
kernel_attrs.clear();
kernel::CPUKernelFactory::GetInstance().SetKernelAttrs(op_info_ptr, &kernel_attrs);
kernel::CPUKernelFactory::GetInstance().UpdateKernelAttrs(op_name, kernel_attrs);
}
GetInputFormatsAndDtypes(kernel_node, &input_formats, &input_types, &input_not_cnode_indexes);
GetOutputInferFormatsAndDtypes(kernel_node, &infer_output_formats, &infer_output_types);
GetInputDtypes(kernel_node, &input_types, &input_not_cnode_indexes);
GetOutputDtypes(kernel_node, &output_types);
KernelAttr selected_kernel_attr;
std::pair<bool, bool> matched = std::make_pair(false, false);
if (!SelectKernel(kernel_node, &selected_kernel_attr, kernel_attrs, input_formats, input_types,
input_not_cnode_indexes, infer_output_formats, infer_output_types, &matched, true)) {
if (!SelectKernel(kernel_node, &selected_kernel_attr, kernel_attrs, input_types, input_not_cnode_indexes,
output_types, &matched, true)) {
if (AnfAlgo::GetCNodeName(kernel_node) == "Cast") {
KernelNotSupportException(kernel_node, input_types, infer_output_types);
KernelNotSupportException(kernel_node, input_types, output_types);
}
matched = std::make_pair(false, false);
SelectKernel(kernel_node, &selected_kernel_attr, kernel_attrs, input_formats, input_types, input_not_cnode_indexes,
infer_output_formats, infer_output_types, &matched, false);
SelectKernel(kernel_node, &selected_kernel_attr, kernel_attrs, input_types, input_not_cnode_indexes, output_types,
&matched, false);
if (!matched.first) {
KernelNotSupportException(kernel_node, input_types, infer_output_types);
KernelNotSupportException(kernel_node, input_types, output_types);
}
}
if (selected_kernel_attr.GetInputSize() > 0 &&
(matched.first || input_types.size() == input_not_cnode_indexes.size())) {
MS_LOG(INFO) << "Input format and dtype is matched";
GetOutputFormatsAndDtypes(kernel_node, selected_kernel_attr, &output_formats, &output_types);
for (size_t i = 0; i < selected_kernel_attr.GetInputSize(); ++i) {
input_types[SizeToInt(i)] = selected_kernel_attr.GetInputAttr(i).first;
GetOutputFormatsAndDtypes(kernel_node, selected_kernel_attr, &selected_output_formats, &selected_output_types);
for (size_t index = 0; index < selected_kernel_attr.GetInputSize(); index++) {
input_types[index] = selected_kernel_attr.GetInputAttr(index).first;
input_formats.emplace_back(selected_kernel_attr.GetInputAttr(index).second);
}
}
SetKernelBuildInfo(input_formats, input_types, output_formats, output_types, kernel_node.get());
SetKernelBuildInfo(input_formats, input_types, selected_output_formats, selected_output_types, kernel_node.get());
}
} // namespace cpu
} // namespace device

View File

@ -23,7 +23,7 @@ Examples:
from .primitive import Primitive, PrimitiveWithInfer, PrimitiveWithCheck, prim_attr_register
from .vm_impl_registry import get_vm_impl_fn, vm_impl_registry
from .op_info_register import op_info_register, AkgGpuRegOp, AkgAscendRegOp, AiCPURegOp, TBERegOp, DataType
from .op_info_register import op_info_register, AkgGpuRegOp, AkgAscendRegOp, AiCPURegOp, TBERegOp, CpuRegOp, DataType
from .primitive import constexpr
from . import composite, operations, functional
from . import signature
@ -36,7 +36,7 @@ __primitive__ = [
]
__all__ = ["get_vm_impl_fn", "vm_impl_registry",
"op_info_register", "AkgGpuRegOp", "AkgAscendRegOp", "AiCPURegOp", "TBERegOp", "DataType",
"op_info_register", "AkgGpuRegOp", "AkgAscendRegOp", "AiCPURegOp", "TBERegOp", "CpuRegOp", "DataType",
"constexpr"]
__all__.extend(__primitive__)
__all__.extend(composite.__all__)

View File

@ -16,6 +16,7 @@
import platform
from .aicpu import *
from .cpu import *
if "Windows" not in platform.system():
from .akg import *
from .tbe import *

View File

@ -0,0 +1,63 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""cpu ops"""
from .cast import _cast_cpu
from .mul import _mul_cpu
from .sub import _sub_cpu
from .pow import _pow_cpu
from .real_div import _real_div_cpu
from .div import _div_cpu
from .concat import _concat_cpu
from .split import _split_cpu
from .adam import _adam_cpu
from .arg_max import _arg_max_cpu
from .arg_min_with_value import _arg_min_with_value_cpu
from .bias_add import _bias_add_cpu
from .bias_add_grad import _bias_add_grad_cpu
from .dropout import _dropout_cpu
from .dropout_grad import _dropout_grad_cpu
from .gather_d import _gather_cpu
from .gather_d_grad import _gather_d_grad_cpu
from .gather_v2 import _gather_v2_cpu
from .gather_nd import _gather_nd_cpu
from .maximum import _maximum_cpu
from .maximum_grad import _maximum_grad_cpu
from .conv2d import _conv2d_cpu
from .hsigmoid import _hsigmoid_cpu
from .hsigmoid_grad import _hsigmoid_grad_cpu
from .hswish import _hswish_cpu
from .hswish_grad import _hswish_grad_cpu
from .is_finite import _is_finite_cpu
from .layer_norm import _layer_norm_cpu
from .layer_norm_grad import _layer_norm_grad_cpu
from .minimum import _minimum_cpu
from .minimum_grad import _minimum_grad_cpu
from .equal_count import _equal_count_cpu
from .mirror_pad import _mirror_pad_cpu
from .mirror_pad_grad import _mirror_pad_grad_cpu
from .stack import _stack_cpu
from .reduce_mean import _reduce_mean_cpu
from .reduce_max import _reduce_max_cpu
from .reduce_sum import _reduce_sum_cpu
from .reduce_min import _reduce_min_cpu
from .reduce_all import _reduce_all_cpu
from .reduce_any import _reduce_any_cpu
from .transpose import _transpose_cpu
from .tile import _tile_cpu
from .top_k import _top_k_cpu
from .add import _add_cpu
from .one_hot import _one_hot_cpu
from .pad import _pad_cpu
from .range import _range_cpu

View File

@ -0,0 +1,44 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Concat op"""
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
adam_op_info = CpuRegOp("Adam") \
.input(0, "var", "required") \
.input(1, "m", "required") \
.input(2, "v", "required") \
.input(3, "beta1_power", "required") \
.input(4, "beta2_power", "required") \
.input(5, "lr", "required") \
.input(6, "beta1", "required") \
.input(7, "beta2", "required") \
.input(8, "epsilon", "required") \
.input(9, "gradient", "required") \
.output(0, "output0", "required") \
.output(1, "output0", "required") \
.output(2, "output0", "required") \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default) \
.get_op_info()
@op_info_register(adam_op_info)
def _adam_cpu():
"""Adam cpu register"""
return

View File

@ -0,0 +1,32 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Add op"""
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
add_op_info = CpuRegOp("Add") \
.input(0, "x1", "required") \
.input(1, "x2", "required") \
.output(0, "y", "required") \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.U32_Default, DataType.U32_Default, DataType.U32_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.get_op_info()
@op_info_register(add_op_info)
def _add_cpu():
"""Add cpu register"""
return

View File

@ -0,0 +1,30 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Argmax op"""
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
arg_max_op_info = CpuRegOp("Argmax") \
.input(0, "x", "required") \
.output(0, "y", "required") \
.dtype_format(DataType.F32_Default, DataType.I32_Default) \
.dtype_format(DataType.F16_Default, DataType.I32_Default) \
.get_op_info()
@op_info_register(arg_max_op_info)
def _arg_max_cpu():
"""Argmax cpu register"""
return

View File

@ -0,0 +1,31 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ArgMinWithValue op"""
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
arg_min_with_value_op_info = CpuRegOp("ArgMinWithValue") \
.input(0, "x", "required") \
.output(0, "indice", "required") \
.output(1, "values", "required") \
.dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default) \
.dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default) \
.get_op_info()
@op_info_register(arg_min_with_value_op_info)
def _arg_min_with_value_cpu():
"""ArgMinWithValue cpu register"""
return

View File

@ -0,0 +1,30 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""BiasAdd op"""
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
bias_add_op_info = CpuRegOp("BiasAdd") \
.input(0, "x", "required") \
.input(1, "bias", "required") \
.output(0, "y", "required") \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.get_op_info()
@op_info_register(bias_add_op_info)
def _bias_add_cpu():
"""BiasAdd cpu register"""
return

View File

@ -0,0 +1,29 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""BiasAddGrad op"""
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
bias_add_grad_op_info = CpuRegOp("BiasAddGrad") \
.input(0, "x", "required") \
.output(0, "y", "required") \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
.get_op_info()
@op_info_register(bias_add_grad_op_info)
def _bias_add_grad_cpu():
"""BiasAddGrad cpu register"""
return

View File

@ -0,0 +1,171 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Cast op"""
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
cast_op_info = CpuRegOp("Cast") \
.input(0, "x", "required") \
.output(0, "y", "required") \
.dtype_format(DataType.U8_Default, DataType.U8_Default) \
.dtype_format(DataType.U8_Default, DataType.U16_Default) \
.dtype_format(DataType.U8_Default, DataType.U32_Default) \
.dtype_format(DataType.U8_Default, DataType.U64_Default) \
.dtype_format(DataType.U8_Default, DataType.I8_Default) \
.dtype_format(DataType.U8_Default, DataType.I16_Default) \
.dtype_format(DataType.U8_Default, DataType.I32_Default) \
.dtype_format(DataType.U8_Default, DataType.I64_Default) \
.dtype_format(DataType.U8_Default, DataType.F16_Default) \
.dtype_format(DataType.U8_Default, DataType.F32_Default) \
.dtype_format(DataType.U8_Default, DataType.F64_Default) \
.dtype_format(DataType.U8_Default, DataType.BOOL_Default) \
.dtype_format(DataType.U16_Default, DataType.U8_Default) \
.dtype_format(DataType.U16_Default, DataType.U16_Default) \
.dtype_format(DataType.U16_Default, DataType.U32_Default) \
.dtype_format(DataType.U16_Default, DataType.U64_Default) \
.dtype_format(DataType.U16_Default, DataType.I8_Default) \
.dtype_format(DataType.U16_Default, DataType.I16_Default) \
.dtype_format(DataType.U16_Default, DataType.I32_Default) \
.dtype_format(DataType.U16_Default, DataType.I64_Default) \
.dtype_format(DataType.U16_Default, DataType.F16_Default) \
.dtype_format(DataType.U16_Default, DataType.F32_Default) \
.dtype_format(DataType.U16_Default, DataType.F64_Default) \
.dtype_format(DataType.U16_Default, DataType.BOOL_Default) \
.dtype_format(DataType.U32_Default, DataType.U8_Default) \
.dtype_format(DataType.U32_Default, DataType.U16_Default) \
.dtype_format(DataType.U32_Default, DataType.U32_Default) \
.dtype_format(DataType.U32_Default, DataType.U64_Default) \
.dtype_format(DataType.U32_Default, DataType.I8_Default) \
.dtype_format(DataType.U32_Default, DataType.I16_Default) \
.dtype_format(DataType.U32_Default, DataType.I32_Default) \
.dtype_format(DataType.U32_Default, DataType.I64_Default) \
.dtype_format(DataType.U32_Default, DataType.F16_Default) \
.dtype_format(DataType.U32_Default, DataType.F32_Default) \
.dtype_format(DataType.U32_Default, DataType.F64_Default) \
.dtype_format(DataType.U32_Default, DataType.BOOL_Default) \
.dtype_format(DataType.U64_Default, DataType.U8_Default) \
.dtype_format(DataType.U64_Default, DataType.U16_Default) \
.dtype_format(DataType.U64_Default, DataType.U32_Default) \
.dtype_format(DataType.U64_Default, DataType.U64_Default) \
.dtype_format(DataType.U64_Default, DataType.I8_Default) \
.dtype_format(DataType.U64_Default, DataType.I16_Default) \
.dtype_format(DataType.U64_Default, DataType.I32_Default) \
.dtype_format(DataType.U64_Default, DataType.I64_Default) \
.dtype_format(DataType.U64_Default, DataType.F16_Default) \
.dtype_format(DataType.U64_Default, DataType.F32_Default) \
.dtype_format(DataType.U64_Default, DataType.F64_Default) \
.dtype_format(DataType.U64_Default, DataType.BOOL_Default) \
.dtype_format(DataType.I8_Default, DataType.U8_Default) \
.dtype_format(DataType.I8_Default, DataType.U16_Default) \
.dtype_format(DataType.I8_Default, DataType.U32_Default) \
.dtype_format(DataType.I8_Default, DataType.U64_Default) \
.dtype_format(DataType.I8_Default, DataType.I8_Default) \
.dtype_format(DataType.I8_Default, DataType.I16_Default) \
.dtype_format(DataType.I8_Default, DataType.I32_Default) \
.dtype_format(DataType.I8_Default, DataType.I64_Default) \
.dtype_format(DataType.I8_Default, DataType.F16_Default) \
.dtype_format(DataType.I8_Default, DataType.F32_Default) \
.dtype_format(DataType.I8_Default, DataType.F64_Default) \
.dtype_format(DataType.I8_Default, DataType.BOOL_Default) \
.dtype_format(DataType.I16_Default, DataType.U8_Default) \
.dtype_format(DataType.I16_Default, DataType.U16_Default) \
.dtype_format(DataType.I16_Default, DataType.U32_Default) \
.dtype_format(DataType.I16_Default, DataType.U64_Default) \
.dtype_format(DataType.I16_Default, DataType.I8_Default) \
.dtype_format(DataType.I16_Default, DataType.I16_Default) \
.dtype_format(DataType.I16_Default, DataType.I32_Default) \
.dtype_format(DataType.I16_Default, DataType.I64_Default) \
.dtype_format(DataType.I16_Default, DataType.F16_Default) \
.dtype_format(DataType.I16_Default, DataType.F32_Default) \
.dtype_format(DataType.I16_Default, DataType.F64_Default) \
.dtype_format(DataType.I16_Default, DataType.BOOL_Default) \
.dtype_format(DataType.I32_Default, DataType.U8_Default) \
.dtype_format(DataType.I32_Default, DataType.U16_Default) \
.dtype_format(DataType.I32_Default, DataType.U32_Default) \
.dtype_format(DataType.I32_Default, DataType.U64_Default) \
.dtype_format(DataType.I32_Default, DataType.I8_Default) \
.dtype_format(DataType.I32_Default, DataType.I16_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I32_Default, DataType.I64_Default) \
.dtype_format(DataType.I32_Default, DataType.F16_Default) \
.dtype_format(DataType.I32_Default, DataType.F32_Default) \
.dtype_format(DataType.I32_Default, DataType.F64_Default) \
.dtype_format(DataType.I32_Default, DataType.BOOL_Default) \
.dtype_format(DataType.I64_Default, DataType.U8_Default) \
.dtype_format(DataType.I64_Default, DataType.U16_Default) \
.dtype_format(DataType.I64_Default, DataType.U32_Default) \
.dtype_format(DataType.I64_Default, DataType.U64_Default) \
.dtype_format(DataType.I64_Default, DataType.I8_Default) \
.dtype_format(DataType.I64_Default, DataType.I16_Default) \
.dtype_format(DataType.I64_Default, DataType.I32_Default) \
.dtype_format(DataType.I64_Default, DataType.I64_Default) \
.dtype_format(DataType.I64_Default, DataType.F16_Default) \
.dtype_format(DataType.I64_Default, DataType.F32_Default) \
.dtype_format(DataType.I64_Default, DataType.F64_Default) \
.dtype_format(DataType.I64_Default, DataType.BOOL_Default) \
.dtype_format(DataType.F16_Default, DataType.U8_Default) \
.dtype_format(DataType.F16_Default, DataType.U16_Default) \
.dtype_format(DataType.F16_Default, DataType.U32_Default) \
.dtype_format(DataType.F16_Default, DataType.U64_Default) \
.dtype_format(DataType.F16_Default, DataType.I8_Default) \
.dtype_format(DataType.F16_Default, DataType.I16_Default) \
.dtype_format(DataType.F16_Default, DataType.I32_Default) \
.dtype_format(DataType.F16_Default, DataType.I64_Default) \
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F16_Default, DataType.F32_Default) \
.dtype_format(DataType.F16_Default, DataType.F64_Default) \
.dtype_format(DataType.F16_Default, DataType.BOOL_Default) \
.dtype_format(DataType.F32_Default, DataType.U8_Default) \
.dtype_format(DataType.F32_Default, DataType.U16_Default) \
.dtype_format(DataType.F32_Default, DataType.U32_Default) \
.dtype_format(DataType.F32_Default, DataType.U64_Default) \
.dtype_format(DataType.F32_Default, DataType.I8_Default) \
.dtype_format(DataType.F32_Default, DataType.I16_Default) \
.dtype_format(DataType.F32_Default, DataType.I32_Default) \
.dtype_format(DataType.F32_Default, DataType.I64_Default) \
.dtype_format(DataType.F32_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F32_Default, DataType.F64_Default) \
.dtype_format(DataType.F32_Default, DataType.BOOL_Default) \
.dtype_format(DataType.F64_Default, DataType.U8_Default) \
.dtype_format(DataType.F64_Default, DataType.U16_Default) \
.dtype_format(DataType.F64_Default, DataType.U32_Default) \
.dtype_format(DataType.F64_Default, DataType.U64_Default) \
.dtype_format(DataType.F64_Default, DataType.I8_Default) \
.dtype_format(DataType.F64_Default, DataType.I16_Default) \
.dtype_format(DataType.F64_Default, DataType.I32_Default) \
.dtype_format(DataType.F64_Default, DataType.I64_Default) \
.dtype_format(DataType.F64_Default, DataType.F16_Default) \
.dtype_format(DataType.F64_Default, DataType.F32_Default) \
.dtype_format(DataType.F64_Default, DataType.F64_Default) \
.dtype_format(DataType.F64_Default, DataType.BOOL_Default) \
.dtype_format(DataType.BOOL_Default, DataType.U8_Default) \
.dtype_format(DataType.BOOL_Default, DataType.U16_Default) \
.dtype_format(DataType.BOOL_Default, DataType.U32_Default) \
.dtype_format(DataType.BOOL_Default, DataType.U64_Default) \
.dtype_format(DataType.BOOL_Default, DataType.I8_Default) \
.dtype_format(DataType.BOOL_Default, DataType.I16_Default) \
.dtype_format(DataType.BOOL_Default, DataType.I32_Default) \
.dtype_format(DataType.BOOL_Default, DataType.I64_Default) \
.dtype_format(DataType.BOOL_Default, DataType.F16_Default) \
.dtype_format(DataType.BOOL_Default, DataType.F32_Default) \
.dtype_format(DataType.BOOL_Default, DataType.F64_Default) \
.dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \
.get_op_info()
@op_info_register(cast_op_info)
def _cast_cpu():
"""Cast Cpu register"""
return

View File

@ -0,0 +1,38 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Concat op"""
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
concat_op_info = CpuRegOp("Concat") \
.input(0, "x", "required") \
.output(0, "y", "required") \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.I8_Default, DataType.I8_Default) \
.dtype_format(DataType.I16_Default, DataType.I16_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I64_Default, DataType.I64_Default) \
.dtype_format(DataType.U8_Default, DataType.U8_Default) \
.dtype_format(DataType.U16_Default, DataType.U16_Default) \
.dtype_format(DataType.U32_Default, DataType.U32_Default) \
.dtype_format(DataType.U64_Default, DataType.U64_Default) \
.dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \
.get_op_info()
@op_info_register(concat_op_info)
def _concat_cpu():
"""Concat cpu register"""
return

View File

@ -0,0 +1,30 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Conv2D op"""
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
conv2d_op_info = CpuRegOp("Conv2D") \
.input(0, "x", "required") \
.input(1, "filter", "required") \
.output(0, "y", "required") \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.get_op_info()
@op_info_register(conv2d_op_info)
def _conv2d_cpu():
"""Conv2D cpu register"""
return

View File

@ -0,0 +1,31 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Div op"""
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
div_op_info = CpuRegOp("Div") \
.input(0, "x", "required") \
.input(1, "y", "required") \
.output(0, "output", "required") \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \
.get_op_info()
@op_info_register(div_op_info)
def _div_cpu():
"""Div cpu register"""
return

View File

@ -0,0 +1,30 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Dropout op"""
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
dropout_op_info = CpuRegOp("Dropout") \
.input(0, "x", "required") \
.output(0, "output0", "required") \
.output(1, "output1", "required") \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.get_op_info()
@op_info_register(dropout_op_info)
def _dropout_cpu():
"""Dropout cpu register"""
return

View File

@ -0,0 +1,30 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""DropoutGrad op"""
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
dropout_grad_op_info = CpuRegOp("DropoutGrad") \
.input(0, "x", "required") \
.input(1, "y", "required") \
.output(0, "output", "required") \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.get_op_info()
@op_info_register(dropout_grad_op_info)
def _dropout_grad_cpu():
"""DropoutGrad cpu register"""
return

View File

@ -0,0 +1,30 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""EqualCount op"""
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
equal_count_op_info = CpuRegOp("EqualCount") \
.input(0, "x1", "required") \
.input(1, "x2", "required") \
.output(0, "y", "required") \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
.get_op_info()
@op_info_register(equal_count_op_info)
def _equal_count_cpu():
"""EqualCount cpu register"""
return

View File

@ -0,0 +1,49 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""GatherD op"""
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
gather_d_op_info = CpuRegOp("GatherD") \
.input(0, "input", "required") \
.input(1, "dim", "required") \
.input(2, "index", "required") \
.output(0, "output", "required") \
.dtype_format(DataType.F32_Default, DataType.I32_Default, \
DataType.I32_Default, DataType.F32_Default) \
.dtype_format(DataType.F32_Default, DataType.I32_Default, \
DataType.I64_Default, DataType.F32_Default) \
.dtype_format(DataType.F16_Default, DataType.I32_Default, \
DataType.I32_Default, DataType.F16_Default) \
.dtype_format(DataType.F16_Default, DataType.I32_Default, \
DataType.I64_Default, DataType.F16_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default, \
DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default, \
DataType.I64_Default, DataType.I32_Default) \
.dtype_format(DataType.I64_Default, DataType.I32_Default, \
DataType.I32_Default, DataType.I64_Default) \
.dtype_format(DataType.I64_Default, DataType.I32_Default, \
DataType.I64_Default, DataType.I64_Default) \
.dtype_format(DataType.BOOL_Default, DataType.I32_Default, \
DataType.I32_Default, DataType.BOOL_Default) \
.dtype_format(DataType.BOOL_Default, DataType.I32_Default, \
DataType.I64_Default, DataType.BOOL_Default) \
.get_op_info()
@op_info_register(gather_d_op_info)
def _gather_cpu():
"""GatherD cpu register"""
return

View File

@ -0,0 +1,38 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""GatherDGrad op"""
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
gather_d_grad_op_info = CpuRegOp("GatherDGrad") \
.input(0, "index", "required") \
.input(1, "src", "required") \
.output(0, "output", "required") \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.I64_Default) \
.dtype_format(DataType.I32_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.I32_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.I32_Default, DataType.BOOL_Default, DataType.BOOL_Default) \
.dtype_format(DataType.I64_Default, DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \
.dtype_format(DataType.I64_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.I64_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.I64_Default, DataType.BOOL_Default, DataType.BOOL_Default) \
.get_op_info()
@op_info_register(gather_d_grad_op_info)
def _gather_d_grad_cpu():
"""GatherDGrad cpu register"""
return

View File

@ -0,0 +1,33 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""GatherNd op"""
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
gather_nd_op_info = CpuRegOp("GatherNd") \
.input(0, "x1", "required") \
.input(1, "x2", "required") \
.output(0, "y", "required") \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I64_Default, DataType.I32_Default, DataType.I64_Default) \
.dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default) \
.dtype_format(DataType.F64_Default, DataType.I32_Default, DataType.F64_Default) \
.get_op_info()
@op_info_register(gather_nd_op_info)
def _gather_nd_cpu():
"""GatherNd cpu register"""
return

View File

@ -0,0 +1,40 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""GatherV2 op"""
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
gather_v2_op_info = CpuRegOp("Gather") \
.input(0, "x", "required") \
.input(1, "indices", "required") \
.output(0, "y", "required") \
.dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default) \
.dtype_format(DataType.U16_Default, DataType.I32_Default, DataType.U16_Default) \
.dtype_format(DataType.U32_Default, DataType.I32_Default, DataType.U32_Default) \
.dtype_format(DataType.U64_Default, DataType.I32_Default, DataType.U64_Default) \
.dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I8_Default) \
.dtype_format(DataType.I16_Default, DataType.I32_Default, DataType.I16_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I64_Default, DataType.I32_Default, DataType.I64_Default) \
.dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default) \
.dtype_format(DataType.F64_Default, DataType.I32_Default, DataType.F64_Default) \
.dtype_format(DataType.BOOL_Default, DataType.I32_Default, DataType.BOOL_Default) \
.get_op_info()
@op_info_register(gather_v2_op_info)
def _gather_v2_cpu():
"""GatherV2 cpu register"""
return

View File

@ -0,0 +1,32 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""HSigmoid op"""
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
hsigmoid_op_info = CpuRegOp("HSigmoid") \
.input(0, "x") \
.output(0, "output") \
.dtype_format(DataType.I8_Default, DataType.I8_Default) \
.dtype_format(DataType.I16_Default, DataType.I16_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I64_Default, DataType.I64_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
.get_op_info()
@op_info_register(hsigmoid_op_info)
def _hsigmoid_cpu():
"""HSigmoid cpu register"""
return

View File

@ -0,0 +1,33 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""HSigmoidGrad op"""
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
hsigmoidgrad_op_info = CpuRegOp("HSigmoidGrad") \
.input(0, "y_grad") \
.input(1, "x") \
.output(0, "output") \
.dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.I8_Default) \
.dtype_format(DataType.I16_Default, DataType.I16_Default, DataType.I16_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.get_op_info()
@op_info_register(hsigmoidgrad_op_info)
def _hsigmoid_grad_cpu():
"""HSigmoidGrad cpu register"""
return

View File

@ -0,0 +1,32 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""HSwish op"""
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
hswish_op_info = CpuRegOp("HSwish") \
.input(0, "x") \
.output(0, "output") \
.dtype_format(DataType.I8_Default, DataType.I8_Default) \
.dtype_format(DataType.I16_Default, DataType.I16_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I64_Default, DataType.I64_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
.get_op_info()
@op_info_register(hswish_op_info)
def _hswish_cpu():
"""HSwish cpu register"""
return

View File

@ -0,0 +1,33 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""HSwishGrad op"""
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
hswish_grad_op_info = CpuRegOp("HSwishGrad") \
.input(0, "y_grad") \
.input(1, "x") \
.output(0, "output") \
.dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.I8_Default) \
.dtype_format(DataType.I16_Default, DataType.I16_Default, DataType.I16_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.get_op_info()
@op_info_register(hswish_grad_op_info)
def _hswish_grad_cpu():
"""HSwishGrad cpu register"""
return

View File

@ -0,0 +1,39 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""IsFinite op"""
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
is_finite_op_info = CpuRegOp("IsFinite") \
.input(0, "x", "required") \
.output(0, "y", "required") \
.dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \
.dtype_format(DataType.I8_Default, DataType.BOOL_Default) \
.dtype_format(DataType.I16_Default, DataType.BOOL_Default) \
.dtype_format(DataType.I32_Default, DataType.BOOL_Default) \
.dtype_format(DataType.I64_Default, DataType.BOOL_Default) \
.dtype_format(DataType.U8_Default, DataType.BOOL_Default) \
.dtype_format(DataType.U16_Default, DataType.BOOL_Default) \
.dtype_format(DataType.U32_Default, DataType.BOOL_Default) \
.dtype_format(DataType.U64_Default, DataType.BOOL_Default) \
.dtype_format(DataType.F16_Default, DataType.BOOL_Default) \
.dtype_format(DataType.F32_Default, DataType.BOOL_Default) \
.dtype_format(DataType.F64_Default, DataType.BOOL_Default) \
.get_op_info()
@op_info_register(is_finite_op_info)
def _is_finite_cpu():
"""IsFinite cpu register"""
return

View File

@ -0,0 +1,36 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""LayerNorm op"""
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
layer_norm_op_info = CpuRegOp("LayerNorm") \
.input(0, "x", "required") \
.input(1, "gamma", "required") \
.input(2, "beta", "required") \
.output(0, "y", "required") \
.output(1, "mean", "required") \
.output(2, "variance", "required") \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, \
DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, \
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.get_op_info()
@op_info_register(layer_norm_op_info)
def _layer_norm_cpu():
"""LayerNorm cpu register"""
return

View File

@ -0,0 +1,38 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""LayerNormGrad op"""
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
layer_norm_grad_op_info = CpuRegOp("LayerNormGrad") \
.input(0, "dy", "required") \
.input(1, "x", "required") \
.input(2, "variance", "required") \
.input(3, "mean", "required") \
.input(4, "gamma", "required") \
.output(0, "pd_x", "required") \
.output(1, "pd_gamma", "required") \
.output(2, "pd_beta", "required") \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.get_op_info()
@op_info_register(layer_norm_grad_op_info)
def _layer_norm_grad_cpu():
"""LayerNormGrad TBE register"""
return

View File

@ -0,0 +1,35 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Maximum op"""
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
maximum_op_info = CpuRegOp("Maximum") \
.input(0, "x1", "required") \
.input(1, "x2", "required") \
.output(0, "y", "required") \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \
.dtype_format(DataType.U32_Default, DataType.U32_Default, DataType.U32_Default) \
.dtype_format(DataType.U64_Default, DataType.U64_Default, DataType.U64_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F64_Default, DataType.F64_Default, DataType.F64_Default) \
.get_op_info()
@op_info_register(maximum_op_info)
def _maximum_cpu():
"""Maximum cpu register"""
return

View File

@ -0,0 +1,43 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""MaximumGrad op"""
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
maximum_grad_op_info = CpuRegOp("MaximumGrad") \
.input(0, "grads", "required") \
.input(1, "x1", "required") \
.input(2, "x2", "required") \
.output(0, "y1", "required") \
.output(1, "y2", "required") \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default,
DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default,
DataType.I64_Default, DataType.I64_Default) \
.dtype_format(DataType.U32_Default, DataType.U32_Default, DataType.U32_Default,
DataType.U32_Default, DataType.U32_Default) \
.dtype_format(DataType.U64_Default, DataType.U64_Default, DataType.U64_Default,
DataType.U64_Default, DataType.U64_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F64_Default, DataType.F64_Default, DataType.F64_Default,
DataType.F64_Default, DataType.F64_Default) \
.get_op_info()
@op_info_register(maximum_grad_op_info)
def _maximum_grad_cpu():
"""MaximumGrad cpu register"""
return

View File

@ -0,0 +1,36 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Minimum op"""
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
minimum_op_info = CpuRegOp("Minimum") \
.input(0, "x1", "required") \
.input(1, "x2", "required") \
.output(0, "y", "required") \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.U32_Default, DataType.U32_Default, DataType.U32_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \
.dtype_format(DataType.U64_Default, DataType.U64_Default, DataType.U64_Default) \
.dtype_format(DataType.F64_Default, DataType.F64_Default, DataType.F64_Default) \
.get_op_info()
@op_info_register(minimum_op_info)
def _minimum_cpu():
"""Minimum cpu register"""
return

View File

@ -0,0 +1,37 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""MinimumGrad op"""
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
minimum_grad_op_info = CpuRegOp("MinimumGrad") \
.input(0, "grads", "required") \
.input(1, "x1", "required") \
.input(2, "x2", "required") \
.output(0, "y1", "required") \
.output(1, "y2", "required") \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, \
DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, \
DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, \
DataType.F32_Default, DataType.F32_Default) \
.get_op_info()
@op_info_register(minimum_grad_op_info)
def _minimum_grad_cpu():
"""MinimumGrad cpu register"""
return

View File

@ -0,0 +1,31 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""MirrorPad op"""
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
mirror_pad_op_info = CpuRegOp("MirrorPad") \
.input(0, "x", "required") \
.input(1, "paddings", "required") \
.output(0, "y", "required") \
.dtype_format(DataType.F16_Default, DataType.I64_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.I64_Default, DataType.F32_Default) \
.dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.I32_Default) \
.get_op_info()
@op_info_register(mirror_pad_op_info)
def _mirror_pad_cpu():
"""MirrorPad cpu register"""
return

View File

@ -0,0 +1,31 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""MirrorPadGrad op"""
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
mirror_pad_grad_op_info = CpuRegOp("MirrorPadGrad") \
.input(0, "x", "required") \
.input(1, "paddings", "required") \
.output(0, "y", "required") \
.dtype_format(DataType.F16_Default, DataType.I64_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.I64_Default, DataType.F32_Default) \
.dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.I32_Default) \
.get_op_info()
@op_info_register(mirror_pad_grad_op_info)
def _mirror_pad_grad_cpu():
"""MirrorPadGrad cpu register"""
return

View File

@ -0,0 +1,30 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Mul op"""
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
mul_op_info = CpuRegOp("Mul") \
.input(0, "x", "required") \
.input(1, "y", "required") \
.output(0, "output", "required") \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
.get_op_info()
@op_info_register(mul_op_info)
def _mul_cpu():
"""Mul cpu register"""
return

View File

@ -0,0 +1,31 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""OneHot op"""
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
one_hot_op_info = CpuRegOp("OneHot") \
.input(0, "x", "required") \
.input(1, "on_value", "required") \
.input(2, "off_value", "required") \
.output(0, "y", "required") \
.dtype_format(DataType.I32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.get_op_info()
@op_info_register(one_hot_op_info)
def _one_hot_cpu():
"""OneHot cpu register"""
return

View File

@ -0,0 +1,31 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Pad op"""
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
pad_op_info = CpuRegOp("Pad") \
.input(0, "x", "required") \
.output(0, "y", "required") \
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
.get_op_info()
@op_info_register(pad_op_info)
def _pad_cpu():
"""Pad cpu register"""
return

View File

@ -0,0 +1,31 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Power op"""
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
pow_op_info = CpuRegOp("Pow") \
.input(0, "x", "required") \
.input(1, "y", "required") \
.output(0, "output", "required") \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \
.get_op_info()
@op_info_register(pow_op_info)
def _pow_cpu():
"""Pow cpu register"""
return

View File

@ -0,0 +1,32 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Range op"""
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
range_op_info = CpuRegOp("Range") \
.input(0, "x", "required") \
.output(0, "y", "required") \
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I64_Default, DataType.I64_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F64_Default, DataType.F64_Default) \
.get_op_info()
@op_info_register(range_op_info)
def _range_cpu():
"""Range cpu register"""
return

View File

@ -0,0 +1,31 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""RealDiv op"""
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
real_div_op_info = CpuRegOp("RealDiv") \
.input(0, "x", "required") \
.input(1, "y", "required") \
.output(0, "output", "required") \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \
.get_op_info()
@op_info_register(real_div_op_info)
def _real_div_cpu():
"""RealDiv cpu register"""
return

View File

@ -0,0 +1,29 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ReduceAll op"""
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
reduce_all_op_info = CpuRegOp("ReduceAll") \
.input(0, "x", "required") \
.output(0, "y", "required") \
.dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \
.get_op_info()
@op_info_register(reduce_all_op_info)
def _reduce_all_cpu():
"""ReduceAll cpu register"""
return

View File

@ -0,0 +1,29 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ReduceAny op"""
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
reduce_any_op_info = CpuRegOp("ReduceAny") \
.input(0, "x", "required") \
.output(0, "y", "required") \
.dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \
.get_op_info()
@op_info_register(reduce_any_op_info)
def _reduce_any_cpu():
"""ReduceAny cpu register"""
return

View File

@ -0,0 +1,32 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ReduceMax op"""
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
reduce_max_op_info = CpuRegOp("ReduceMax") \
.input(0, "x", "required") \
.output(0, "y", "required") \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F64_Default, DataType.F64_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I64_Default, DataType.I64_Default) \
.get_op_info()
@op_info_register(reduce_max_op_info)
def _reduce_max_cpu():
"""ReduceMax cpu register"""
return

View File

@ -0,0 +1,32 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ReduceMean op"""
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
reduce_mean_op_info = CpuRegOp("ReduceMean") \
.input(0, "x", "required") \
.output(0, "y", "required") \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F64_Default, DataType.F64_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I64_Default, DataType.I64_Default) \
.get_op_info()
@op_info_register(reduce_mean_op_info)
def _reduce_mean_cpu():
"""ReduceMean cpu register"""
return

View File

@ -0,0 +1,32 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ReduceMin op"""
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
reduce_min_op_info = CpuRegOp("ReduceMin") \
.input(0, "x", "required") \
.output(0, "y", "required") \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F64_Default, DataType.F64_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I64_Default, DataType.I64_Default) \
.get_op_info()
@op_info_register(reduce_min_op_info)
def _reduce_min_cpu():
"""ReduceMin cpu register"""
return

View File

@ -0,0 +1,32 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ReduceSum op"""
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
reduce_sum_op_info = CpuRegOp("ReduceSum") \
.input(0, "x", "required") \
.output(0, "y", "required") \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F64_Default, DataType.F64_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I64_Default, DataType.I64_Default) \
.get_op_info()
@op_info_register(reduce_sum_op_info)
def _reduce_sum_cpu():
"""ReduceSum cpu register"""
return

View File

@ -0,0 +1,34 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Split op"""
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
split_op_info = CpuRegOp("Split") \
.input(0, "x", "required") \
.output(0, "y", "required") \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F64_Default, DataType.F64_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.U32_Default, DataType.U32_Default) \
.dtype_format(DataType.I64_Default, DataType.I64_Default) \
.get_op_info()
@op_info_register(split_op_info)
def _split_cpu():
"""Split cpu register"""
return

View File

@ -0,0 +1,38 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Stack op"""
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
stack_op_info = CpuRegOp("Stack") \
.input(0, "x", "required") \
.output(0, "y", "required") \
.dtype_format(DataType.I8_Default, DataType.I8_Default) \
.dtype_format(DataType.I16_Default, DataType.I16_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I64_Default, DataType.I64_Default) \
.dtype_format(DataType.U8_Default, DataType.U8_Default) \
.dtype_format(DataType.U16_Default, DataType.U16_Default) \
.dtype_format(DataType.U32_Default, DataType.U32_Default) \
.dtype_format(DataType.U64_Default, DataType.U64_Default) \
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \
.get_op_info()
@op_info_register(stack_op_info)
def _stack_cpu():
"""Stack cpu register"""
return

View File

@ -0,0 +1,31 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Sub op"""
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
sub_op_info = CpuRegOp("Sub") \
.input(0, "x", "required") \
.input(1, "y", "required") \
.output(0, "output", "required") \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \
.get_op_info()
@op_info_register(sub_op_info)
def _sub_cpu():
"""Sub cpu register"""
return

View File

@ -0,0 +1,38 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Tile op"""
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
tile_op_info = CpuRegOp("Tile") \
.input(0, "x1", "required") \
.output(0, "y", "required") \
.dtype_format(DataType.I8_Default, DataType.I8_Default) \
.dtype_format(DataType.I16_Default, DataType.I16_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I64_Default, DataType.I64_Default) \
.dtype_format(DataType.U8_Default, DataType.U8_Default) \
.dtype_format(DataType.U16_Default, DataType.U16_Default) \
.dtype_format(DataType.U32_Default, DataType.U32_Default) \
.dtype_format(DataType.U64_Default, DataType.U64_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \
.get_op_info()
@op_info_register(tile_op_info)
def _tile_cpu():
"""Tile cpu register"""
return

View File

@ -0,0 +1,31 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""TopK op"""
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
top_k_op_info = CpuRegOp("TopK") \
.input(0, "input", "required") \
.input(1, "k", "required") \
.output(0, "values", "required") \
.output(1, "indices", "required") \
.dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default, DataType.I32_Default) \
.dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default, DataType.I32_Default) \
.get_op_info()
@op_info_register(top_k_op_info)
def _top_k_cpu():
"""TopK cpu register"""
return

View File

@ -0,0 +1,38 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Transpose op"""
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
transpose_op_info = CpuRegOp("Transpose") \
.input(0, "x", "required") \
.output(0, "y", "required") \
.dtype_format(DataType.I8_Default, DataType.I8_Default) \
.dtype_format(DataType.I16_Default, DataType.I16_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I64_Default, DataType.I64_Default) \
.dtype_format(DataType.U8_Default, DataType.U8_Default) \
.dtype_format(DataType.U16_Default, DataType.U16_Default) \
.dtype_format(DataType.U32_Default, DataType.U32_Default) \
.dtype_format(DataType.U64_Default, DataType.U64_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \
.get_op_info()
@op_info_register(transpose_op_info)
def _transpose_cpu():
"""TransposeD cpu register"""
return

View File

@ -213,6 +213,65 @@ class RegOp:
return op_info
class CpuRegOp(RegOp):
"""Class for Cpu op info register"""
def __init__(self, op_name):
super(CpuRegOp, self).__init__(op_name)
self.imply_type = "CPU"
def input(self, index=None, name=None, param_type=None, **kwargs):
"""
Register Cpu op input information.
Args:
index (int): Order of the input. Default: None.
name (str): Name of the input. Default: None.
param_type (str): Param type of the input. Default: None.
kwargs (dict): Other information of the input.
"""
param_list = [index, name, param_type]
key_list = ["index", "name", "param_type"]
fn_list = [self._is_int, self._is_string, self._is_string]
input_dict = self._check_param(param_list, key_list, fn_list, kwargs)
self.inputs.append(input_dict)
return self
def output(self, index=None, name=None, param_type=None, **kwargs):
"""
Register AiCPU op output information.
Args:
index (int): Order of the output. Default: None.
name (str): Name of the output. Default: None.
param_type (str): Param type of the output. Default: None.
kwargs (dict): Other information of the output.
"""
param_list = [index, name, param_type]
key_list = ["index", "name", "param_type"]
fn_list = [self._is_int, self._is_string, self._is_string]
output_dict = self._check_param(param_list, key_list, fn_list, kwargs)
self.outputs.append(output_dict)
return self
def attr(self, name=None, value_type=None, value=None, **kwargs):
"""
Register AiCPU op attribute information.
Args:
name (str): Name of the attribute. Default: None.
value_type (str): Value type of the attribute. Default: None.
value (str): Value of the attribute. Default: None.
kwargs (dict): Other information of the attribute.
"""
param_list = [name, value_type, value]
key_list = ["name", "type", "value"]
fn_list = [self._is_string]
attr_dict = self._check_param(param_list, key_list, fn_list, kwargs)
self.attr_.append(attr_dict)
return self
class AkgRegOp(RegOp):
"""Class for Akg op info register."""