forked from mindspore-Ecosystem/mindspore
!16988 upload unify cpu ops info library code
From: @chengxb7532 Reviewed-by: @linqingke,@guoqi1024 Signed-off-by: @guoqi1024
This commit is contained in:
commit
5ed742a893
|
@ -40,22 +40,7 @@ class AdamCPUKernel : public CPUKernel {
|
|||
TypeId dtype_{kTypeUnknown};
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL(Adam,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
AdamCPUKernel)
|
||||
MS_REG_CPU_KERNEL(Adam, KernelAttr(), AdamCPUKernel)
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -40,10 +40,8 @@ class ArgmaxCPUKernel : public CPUKernel {
|
|||
size_t dim_axis_;
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL_T(Argmax, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt32),
|
||||
ArgmaxCPUKernel, float);
|
||||
MS_REG_CPU_KERNEL_T(Argmax, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt32),
|
||||
ArgmaxCPUKernel, float16);
|
||||
MS_REG_CPU_KERNEL_T(Argmax, KernelAttr(), ArgmaxCPUKernel, float);
|
||||
MS_REG_CPU_KERNEL_T(Argmax, KernelAttr(), ArgmaxCPUKernel, float16);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -42,14 +42,8 @@ class ArgMinWithValueCPUKernel : public CPUKernel {
|
|||
size_t dim_axis_;
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
ArgMinWithValue,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
|
||||
ArgMinWithValueCPUKernel, float);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
ArgMinWithValue,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
|
||||
ArgMinWithValueCPUKernel, float16);
|
||||
MS_REG_CPU_KERNEL_T(ArgMinWithValue, KernelAttr(), ArgMinWithValueCPUKernel, float);
|
||||
MS_REG_CPU_KERNEL_T(ArgMinWithValue, KernelAttr(), ArgMinWithValueCPUKernel, float16);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -63,43 +63,18 @@ class ArithmeticCPUKernel : public CPUKernel {
|
|||
TypeId target_dtype_{kTypeUnknown};
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
Sub, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
ArithmeticCPUKernel, int);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
Sub, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
ArithmeticCPUKernel, float);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
Sub, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
||||
ArithmeticCPUKernel, int64_t);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
Pow, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
ArithmeticCPUKernel, int);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
Pow, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
ArithmeticCPUKernel, float);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
Pow, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
||||
ArithmeticCPUKernel, int64_t);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
RealDiv, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
ArithmeticCPUKernel, int);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
RealDiv,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
ArithmeticCPUKernel, float);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
RealDiv, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
||||
ArithmeticCPUKernel, int64_t);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
Div, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
||||
ArithmeticCPUKernel, int64_t);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
Div, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
ArithmeticCPUKernel, int);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
Div, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
ArithmeticCPUKernel, float);
|
||||
MS_REG_CPU_KERNEL_T(Sub, KernelAttr(), ArithmeticCPUKernel, int32_t);
|
||||
MS_REG_CPU_KERNEL_T(Sub, KernelAttr(), ArithmeticCPUKernel, float);
|
||||
MS_REG_CPU_KERNEL_T(Sub, KernelAttr(), ArithmeticCPUKernel, int64_t);
|
||||
MS_REG_CPU_KERNEL_T(Pow, KernelAttr(), ArithmeticCPUKernel, int32_t);
|
||||
MS_REG_CPU_KERNEL_T(Pow, KernelAttr(), ArithmeticCPUKernel, float);
|
||||
MS_REG_CPU_KERNEL_T(Pow, KernelAttr(), ArithmeticCPUKernel, int64_t);
|
||||
MS_REG_CPU_KERNEL_T(RealDiv, KernelAttr(), ArithmeticCPUKernel, int32_t);
|
||||
MS_REG_CPU_KERNEL_T(RealDiv, KernelAttr(), ArithmeticCPUKernel, float);
|
||||
MS_REG_CPU_KERNEL_T(RealDiv, KernelAttr(), ArithmeticCPUKernel, int64_t);
|
||||
MS_REG_CPU_KERNEL_T(Div, KernelAttr(), ArithmeticCPUKernel, int32_t);
|
||||
MS_REG_CPU_KERNEL_T(Div, KernelAttr(), ArithmeticCPUKernel, float);
|
||||
MS_REG_CPU_KERNEL_T(Div, KernelAttr(), ArithmeticCPUKernel, int64_t);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
FloorDiv, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
||||
ArithmeticCPUKernel, int64_t);
|
||||
|
@ -139,9 +114,6 @@ MS_REG_CPU_KERNEL_T(
|
|||
MS_REG_CPU_KERNEL_T(
|
||||
AssignAdd, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
||||
ArithmeticCPUKernel, int64_t);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
Mul, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
ArithmeticCPUKernel, int);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
SquaredDifference,
|
||||
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
|
|
|
@ -37,10 +37,7 @@ class BiasAddCPUKernel : public CPUKernel {
|
|||
std::vector<size_t> input_shape_;
|
||||
std::vector<size_t> bias_shape_;
|
||||
};
|
||||
MS_REG_CPU_KERNEL(
|
||||
BiasAdd,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
BiasAddCPUKernel);
|
||||
MS_REG_CPU_KERNEL(BiasAdd, KernelAttr(), BiasAddCPUKernel);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_BIAS_ADD_CPU_KERNEL_H_
|
||||
|
|
|
@ -36,8 +36,7 @@ class BiasAddGradCPUKernel : public CPUKernel {
|
|||
private:
|
||||
std::vector<size_t> input_shape_;
|
||||
};
|
||||
MS_REG_CPU_KERNEL(BiasAddGrad, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
BiasAddGradCPUKernel);
|
||||
MS_REG_CPU_KERNEL(BiasAddGrad, KernelAttr(), BiasAddGradCPUKernel);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_BIASADDGRADCPUKERNEL_H_
|
||||
|
|
|
@ -34,8 +34,8 @@ void Cast(const S *in, T *out, size_t size) {
|
|||
template <typename S, typename T>
|
||||
void CastCPUKernel<S, T>::InitKernel(const CNodePtr &kernel_node) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
source_dtype = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
|
||||
target_dtype = AnfAlgo::GetOutputDeviceDataType(kernel_node, 0);
|
||||
source_dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
|
||||
target_dtype_ = AnfAlgo::GetOutputDeviceDataType(kernel_node, 0);
|
||||
}
|
||||
|
||||
template <typename S, typename T>
|
||||
|
|
|
@ -35,309 +35,166 @@ class CastCPUKernel : public CPUKernel {
|
|||
const std::vector<AddressPtr> &outputs) override;
|
||||
|
||||
private:
|
||||
TypeId source_dtype{kTypeUnknown};
|
||||
TypeId target_dtype{kTypeUnknown};
|
||||
TypeId source_dtype_{kTypeUnknown};
|
||||
TypeId target_dtype_{kTypeUnknown};
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeFloat16), CastCPUKernel,
|
||||
bool, float16);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeFloat32), CastCPUKernel,
|
||||
bool, float);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeFloat64), CastCPUKernel,
|
||||
bool, double);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt8), CastCPUKernel,
|
||||
bool, int8_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt16), CastCPUKernel,
|
||||
bool, int16_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt32), CastCPUKernel,
|
||||
bool, int32_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt64), CastCPUKernel,
|
||||
bool, int64_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeUInt8), CastCPUKernel,
|
||||
bool, uint8_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeUInt16), CastCPUKernel,
|
||||
bool, uint16_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeUInt32), CastCPUKernel,
|
||||
bool, uint32_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeUInt64), CastCPUKernel,
|
||||
bool, uint64_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), CastCPUKernel,
|
||||
bool, bool);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint8_t, uint8_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint8_t, uint16_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint8_t, uint32_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint8_t, uint64_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint8_t, int8_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint8_t, int16_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint8_t, int32_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint8_t, int64_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint8_t, float16);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint8_t, float);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint8_t, double);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint8_t, bool);
|
||||
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
||||
CastCPUKernel, float16, float16);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat32),
|
||||
CastCPUKernel, float16, float);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat64),
|
||||
CastCPUKernel, float16, double);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt8), CastCPUKernel,
|
||||
float16, int8_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt16),
|
||||
CastCPUKernel, float16, int16_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt32),
|
||||
CastCPUKernel, float16, int32_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt64),
|
||||
CastCPUKernel, float16, int64_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeUInt8),
|
||||
CastCPUKernel, float16, uint8_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeUInt16),
|
||||
CastCPUKernel, float16, uint16_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeUInt32),
|
||||
CastCPUKernel, float16, uint32_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeUInt64),
|
||||
CastCPUKernel, float16, uint64_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool), CastCPUKernel,
|
||||
float16, bool);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint16_t, uint8_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint16_t, uint16_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint16_t, uint32_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint16_t, uint64_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint16_t, int8_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint16_t, int16_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint16_t, int32_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint16_t, int64_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint16_t, float16);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint16_t, float);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint16_t, double);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint16_t, bool);
|
||||
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat16),
|
||||
CastCPUKernel, float, float16);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
CastCPUKernel, float, float);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat64),
|
||||
CastCPUKernel, float, double);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt8), CastCPUKernel,
|
||||
float, int8_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt16),
|
||||
CastCPUKernel, float, int16_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt32),
|
||||
CastCPUKernel, float, int32_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt64),
|
||||
CastCPUKernel, float, int64_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeUInt8),
|
||||
CastCPUKernel, float, uint8_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeUInt16),
|
||||
CastCPUKernel, float, uint16_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeUInt32),
|
||||
CastCPUKernel, float, uint32_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeUInt64),
|
||||
CastCPUKernel, float, uint64_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool), CastCPUKernel,
|
||||
float, bool);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint32_t, uint8_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint32_t, uint16_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint32_t, uint32_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint32_t, uint64_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint32_t, int8_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint32_t, int16_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint32_t, int32_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint32_t, int64_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint32_t, float16);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint32_t, float);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint32_t, double);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint32_t, bool);
|
||||
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat16),
|
||||
CastCPUKernel, double, float16);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat32),
|
||||
CastCPUKernel, double, float);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
CastCPUKernel, double, double);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeInt8), CastCPUKernel,
|
||||
double, int8_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeInt16),
|
||||
CastCPUKernel, double, int16_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeInt32),
|
||||
CastCPUKernel, double, int32_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeInt64),
|
||||
CastCPUKernel, double, int64_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeUInt8),
|
||||
CastCPUKernel, double, uint8_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeUInt16),
|
||||
CastCPUKernel, double, uint16_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeUInt32),
|
||||
CastCPUKernel, double, uint32_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeUInt64),
|
||||
CastCPUKernel, double, uint64_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeBool), CastCPUKernel,
|
||||
double, bool);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint64_t, uint8_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint64_t, uint16_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint64_t, uint32_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint64_t, uint64_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint64_t, int8_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint64_t, int16_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint64_t, int32_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint64_t, int64_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint64_t, float16);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint64_t, float);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint64_t, double);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint64_t, bool);
|
||||
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeFloat16), CastCPUKernel,
|
||||
int8_t, float16);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeFloat32), CastCPUKernel,
|
||||
int8_t, float);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeFloat64), CastCPUKernel,
|
||||
int8_t, double);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8), CastCPUKernel,
|
||||
int8_t, int8_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt16), CastCPUKernel,
|
||||
int8_t, int16_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt32), CastCPUKernel,
|
||||
int8_t, int32_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt64), CastCPUKernel,
|
||||
int8_t, int64_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeUInt8), CastCPUKernel,
|
||||
int8_t, uint8_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeUInt16), CastCPUKernel,
|
||||
int8_t, uint16_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeUInt32), CastCPUKernel,
|
||||
int8_t, uint32_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeUInt64), CastCPUKernel,
|
||||
int8_t, uint64_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeBool), CastCPUKernel,
|
||||
int8_t, bool);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int8_t, uint8_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int8_t, uint16_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int8_t, uint32_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int8_t, uint64_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int8_t, int8_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int8_t, int16_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int8_t, int32_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int8_t, int64_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int8_t, float16);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int8_t, float);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int8_t, double);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int8_t, bool);
|
||||
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeFloat16),
|
||||
CastCPUKernel, int16_t, float16);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeFloat32),
|
||||
CastCPUKernel, int16_t, float);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeFloat64),
|
||||
CastCPUKernel, int16_t, double);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt8), CastCPUKernel,
|
||||
int16_t, int8_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16), CastCPUKernel,
|
||||
int16_t, int16_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt32), CastCPUKernel,
|
||||
int16_t, int32_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt64), CastCPUKernel,
|
||||
int16_t, int64_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeUInt8), CastCPUKernel,
|
||||
int16_t, uint8_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeUInt16), CastCPUKernel,
|
||||
int16_t, uint16_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeUInt32), CastCPUKernel,
|
||||
int16_t, uint32_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeUInt64), CastCPUKernel,
|
||||
int16_t, uint64_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeBool), CastCPUKernel,
|
||||
int16_t, bool);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int16_t, uint8_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int16_t, uint16_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int16_t, uint32_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int16_t, uint64_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int16_t, int8_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int16_t, int16_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int16_t, int32_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int16_t, int64_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int16_t, float16);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int16_t, float);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int16_t, double);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int16_t, bool);
|
||||
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
|
||||
CastCPUKernel, int32_t, float16);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
|
||||
CastCPUKernel, int32_t, float);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64),
|
||||
CastCPUKernel, int32_t, double);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt8), CastCPUKernel,
|
||||
int32_t, int8_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt16), CastCPUKernel,
|
||||
int32_t, int16_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), CastCPUKernel,
|
||||
int32_t, int32_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64), CastCPUKernel,
|
||||
int32_t, int64_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt8), CastCPUKernel,
|
||||
int32_t, uint8_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt16), CastCPUKernel,
|
||||
int32_t, uint16_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt32), CastCPUKernel,
|
||||
int32_t, uint32_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt64), CastCPUKernel,
|
||||
int32_t, uint64_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool), CastCPUKernel,
|
||||
int32_t, bool);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int32_t, uint8_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int32_t, uint16_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int32_t, uint32_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int32_t, uint64_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int32_t, int8_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int32_t, int16_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int32_t, int32_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int32_t, int64_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int32_t, float16);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int32_t, float);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int32_t, double);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int32_t, bool);
|
||||
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16),
|
||||
CastCPUKernel, int64_t, float16);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32),
|
||||
CastCPUKernel, int64_t, float);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64),
|
||||
CastCPUKernel, int64_t, double);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt8), CastCPUKernel,
|
||||
int64_t, int8_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt16), CastCPUKernel,
|
||||
int64_t, int16_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32), CastCPUKernel,
|
||||
int64_t, int32_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), CastCPUKernel,
|
||||
int64_t, int64_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt8), CastCPUKernel,
|
||||
int64_t, uint8_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt16), CastCPUKernel,
|
||||
int64_t, uint16_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt32), CastCPUKernel,
|
||||
int64_t, uint32_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt64), CastCPUKernel,
|
||||
int64_t, uint64_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool), CastCPUKernel,
|
||||
int64_t, bool);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int64_t, uint8_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int64_t, uint16_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int64_t, uint32_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int64_t, uint64_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int64_t, int8_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int64_t, int16_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int64_t, int32_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int64_t, int64_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int64_t, float16);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int64_t, float);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int64_t, double);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int64_t, bool);
|
||||
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeFloat16),
|
||||
CastCPUKernel, uint8_t, float16);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeFloat32),
|
||||
CastCPUKernel, uint8_t, float);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeFloat64),
|
||||
CastCPUKernel, uint8_t, double);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeInt8), CastCPUKernel,
|
||||
uint8_t, int8_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeInt16), CastCPUKernel,
|
||||
uint8_t, int16_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeInt32), CastCPUKernel,
|
||||
uint8_t, int32_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeInt64), CastCPUKernel,
|
||||
uint8_t, int64_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8), CastCPUKernel,
|
||||
uint8_t, uint8_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt16), CastCPUKernel,
|
||||
uint8_t, uint16_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt32), CastCPUKernel,
|
||||
uint8_t, uint32_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt64), CastCPUKernel,
|
||||
uint8_t, uint64_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeBool), CastCPUKernel,
|
||||
uint8_t, bool);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, float16, uint8_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, float16, uint16_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, float16, uint32_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, float16, uint64_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, float16, int8_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, float16, int16_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, float16, int32_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, float16, int64_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, float16, float16);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, float16, float);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, float16, double);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, float16, bool);
|
||||
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeFloat16),
|
||||
CastCPUKernel, uint16_t, float16);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeFloat32),
|
||||
CastCPUKernel, uint16_t, float);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeFloat64),
|
||||
CastCPUKernel, uint16_t, double);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeInt8), CastCPUKernel,
|
||||
uint16_t, int8_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeInt16), CastCPUKernel,
|
||||
uint16_t, int16_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeInt32), CastCPUKernel,
|
||||
uint16_t, int32_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeInt64), CastCPUKernel,
|
||||
uint16_t, int64_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt8), CastCPUKernel,
|
||||
uint16_t, uint8_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16),
|
||||
CastCPUKernel, uint16_t, uint16_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt32),
|
||||
CastCPUKernel, uint16_t, uint32_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt64),
|
||||
CastCPUKernel, uint16_t, uint64_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeBool), CastCPUKernel,
|
||||
uint16_t, bool);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, float, uint8_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, float, uint16_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, float, uint32_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, float, uint64_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, float, int8_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, float, int16_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, float, int32_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, float, int64_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, float, float16);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, float, float);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, float, double);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, float, bool);
|
||||
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeFloat16),
|
||||
CastCPUKernel, uint32_t, float16);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeFloat32),
|
||||
CastCPUKernel, uint32_t, float);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeFloat64),
|
||||
CastCPUKernel, uint32_t, double);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt8), CastCPUKernel,
|
||||
uint32_t, int8_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt16), CastCPUKernel,
|
||||
uint32_t, int16_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt32), CastCPUKernel,
|
||||
uint32_t, int32_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt64), CastCPUKernel,
|
||||
uint32_t, int64_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt8), CastCPUKernel,
|
||||
uint32_t, uint8_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt16),
|
||||
CastCPUKernel, uint32_t, uint16_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
|
||||
CastCPUKernel, uint32_t, uint32_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt64),
|
||||
CastCPUKernel, uint32_t, uint64_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeBool), CastCPUKernel,
|
||||
uint32_t, bool);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, double, uint8_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, double, uint16_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, double, uint32_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, double, uint64_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, double, int8_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, double, int16_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, double, int32_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, double, int64_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, double, float16);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, double, float);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, double, double);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, double, bool);
|
||||
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, bool, uint8_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, bool, uint16_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, bool, uint32_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, bool, uint64_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, bool, int8_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, bool, int16_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, bool, int32_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, bool, int64_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, bool, float16);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, bool, float);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, bool, double);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, bool, bool);
|
||||
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeFloat16),
|
||||
CastCPUKernel, uint64_t, float16);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeFloat32),
|
||||
CastCPUKernel, uint64_t, float);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeFloat64),
|
||||
CastCPUKernel, uint64_t, double);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeInt8), CastCPUKernel,
|
||||
uint64_t, int8_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeInt16), CastCPUKernel,
|
||||
uint64_t, int16_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeInt32), CastCPUKernel,
|
||||
uint64_t, int32_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeInt64), CastCPUKernel,
|
||||
uint64_t, int64_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt8), CastCPUKernel,
|
||||
uint64_t, uint8_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt16),
|
||||
CastCPUKernel, uint64_t, uint16_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt32),
|
||||
CastCPUKernel, uint64_t, uint32_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
|
||||
CastCPUKernel, uint64_t, uint64_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeBool), CastCPUKernel,
|
||||
uint64_t, bool);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -39,36 +39,16 @@ class ConcatCPUKernel : public CPUKernel {
|
|||
CNodeWeakPtr node_wpt_;
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
Concat, KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
ConcatCPUKernel, float);
|
||||
MS_REG_CPU_KERNEL_T(Concat,
|
||||
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
|
||||
ConcatCPUKernel, int8_t)
|
||||
MS_REG_CPU_KERNEL_T(Concat,
|
||||
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
|
||||
ConcatCPUKernel, int16_t)
|
||||
MS_REG_CPU_KERNEL_T(Concat,
|
||||
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
ConcatCPUKernel, int)
|
||||
MS_REG_CPU_KERNEL_T(Concat,
|
||||
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
||||
ConcatCPUKernel, int64_t)
|
||||
MS_REG_CPU_KERNEL_T(Concat,
|
||||
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
|
||||
ConcatCPUKernel, uint8_t)
|
||||
MS_REG_CPU_KERNEL_T(Concat,
|
||||
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16),
|
||||
ConcatCPUKernel, uint16_t)
|
||||
MS_REG_CPU_KERNEL_T(Concat,
|
||||
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
|
||||
ConcatCPUKernel, uint32_t)
|
||||
MS_REG_CPU_KERNEL_T(Concat,
|
||||
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
|
||||
ConcatCPUKernel, uint64_t)
|
||||
MS_REG_CPU_KERNEL_T(Concat,
|
||||
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
|
||||
ConcatCPUKernel, bool)
|
||||
MS_REG_CPU_KERNEL_T(Concat, KernelAttr(), ConcatCPUKernel, float);
|
||||
MS_REG_CPU_KERNEL_T(Concat, KernelAttr(), ConcatCPUKernel, int8_t)
|
||||
MS_REG_CPU_KERNEL_T(Concat, KernelAttr(), ConcatCPUKernel, int16_t)
|
||||
MS_REG_CPU_KERNEL_T(Concat, KernelAttr(), ConcatCPUKernel, int32_t)
|
||||
MS_REG_CPU_KERNEL_T(Concat, KernelAttr(), ConcatCPUKernel, int64_t)
|
||||
MS_REG_CPU_KERNEL_T(Concat, KernelAttr(), ConcatCPUKernel, uint8_t)
|
||||
MS_REG_CPU_KERNEL_T(Concat, KernelAttr(), ConcatCPUKernel, uint16_t)
|
||||
MS_REG_CPU_KERNEL_T(Concat, KernelAttr(), ConcatCPUKernel, uint32_t)
|
||||
MS_REG_CPU_KERNEL_T(Concat, KernelAttr(), ConcatCPUKernel, uint64_t)
|
||||
MS_REG_CPU_KERNEL_T(Concat, KernelAttr(), ConcatCPUKernel, bool)
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -23,6 +23,7 @@
|
|||
#include <vector>
|
||||
#include "backend/kernel_compiler/kernel.h"
|
||||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
#include "backend/kernel_compiler/common_utils.h"
|
||||
#include "ir/anf.h"
|
||||
|
||||
using mindspore::kernel::Address;
|
||||
|
|
|
@ -17,12 +17,14 @@
|
|||
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
|
||||
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include <string>
|
||||
|
||||
#include "runtime/device/kernel_info.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
const std::set<std::string> same_op_name = {"Concat", "Pack", "Stack", "Split", "Transpose", "Unpack", "AddN"};
|
||||
CPUKernelFactory &CPUKernelFactory::GetInstance() {
|
||||
static CPUKernelFactory instance;
|
||||
return instance;
|
||||
|
@ -48,6 +50,58 @@ std::shared_ptr<CPUKernel> CPUKernelFactory::Create(const std::string &kernel_na
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
void CPUKernelFactory::SetKernelAttrs(const std::shared_ptr<kernel::OpInfo> op_info,
|
||||
std::vector<KernelAttr> *kernel_attrs) {
|
||||
auto inputs_ptr = op_info->inputs_ptr();
|
||||
auto outputs_ptr = op_info->outputs_ptr();
|
||||
auto first_input_dtypes = inputs_ptr[0]->dtypes();
|
||||
auto input_formats = inputs_ptr[0]->formats();
|
||||
|
||||
for (size_t i = 0; i < first_input_dtypes.size(); i++) {
|
||||
KernelAttr kernel_attr;
|
||||
kernel_attr.AddInputAttr(kernel::DtypeToTypeId(first_input_dtypes[i]), input_formats[i]);
|
||||
for (size_t j = 1; j < inputs_ptr.size(); j++) {
|
||||
auto input_dtypes = inputs_ptr[j]->dtypes();
|
||||
input_formats = inputs_ptr[j]->formats();
|
||||
kernel_attr.AddInputAttr(kernel::DtypeToTypeId(input_dtypes[i]), input_formats[i]);
|
||||
}
|
||||
for (size_t j = 0; j < outputs_ptr.size(); j++) {
|
||||
auto output_dtypes = outputs_ptr[j]->dtypes();
|
||||
auto output_formats = outputs_ptr[j]->formats();
|
||||
kernel_attr.AddOutputAttr(kernel::DtypeToTypeId(output_dtypes[i]), output_formats[i]);
|
||||
}
|
||||
if (same_op_name.count(op_info->op_name()) != 0) {
|
||||
kernel_attr.SetAllSameAttr(true);
|
||||
}
|
||||
kernel_attrs->emplace_back(kernel_attr);
|
||||
}
|
||||
}
|
||||
|
||||
void CPUKernelFactory::UpdateKernelAttrs(const std::string &kernel_name, const std::vector<KernelAttr> &kernel_attrs) {
|
||||
size_t attr_size = kernel_attrs.size();
|
||||
std::vector<std::pair<KernelAttr, CPUKernelCreator>> attr_creators(attr_size);
|
||||
auto iter = name_to_attr_creator_.find(kernel_name);
|
||||
if (iter == name_to_attr_creator_.end()) {
|
||||
MS_LOG(ERROR) << "CPUKernelFactory has not registered operator: " << kernel_name;
|
||||
return;
|
||||
}
|
||||
|
||||
if (attr_size <= iter->second.size()) {
|
||||
for (size_t i = 0; i < attr_size; i++) {
|
||||
auto creator = name_to_attr_creator_.find(kernel_name)->second[i].second;
|
||||
attr_creators[i] = std::make_pair(kernel_attrs[i], creator);
|
||||
}
|
||||
} else {
|
||||
MS_LOG(INFO) << "attr size is not equal creators size " << kernel_name << " attr_size = " << attr_size
|
||||
<< " creator_size = " << iter->second.size();
|
||||
auto single_creator = name_to_attr_creator_.find(kernel_name)->second[0].second;
|
||||
for (size_t i = 0; i < attr_size; i++) {
|
||||
attr_creators[i] = std::make_pair(kernel_attrs[i], single_creator);
|
||||
}
|
||||
}
|
||||
name_to_attr_creator_[kernel_name] = attr_creators;
|
||||
}
|
||||
|
||||
std::pair<bool, size_t> CPUKernelFactory::CPUKernelAttrCheck(const std::string &kernel_name,
|
||||
const KernelBuildInfo &kernel_info) {
|
||||
auto iter = name_to_attr_creator_.find(kernel_name);
|
||||
|
@ -55,10 +109,18 @@ std::pair<bool, size_t> CPUKernelFactory::CPUKernelAttrCheck(const std::string &
|
|||
MS_LOG(INFO) << "Not registered CPU kernel: op[" << kernel_name << "]!";
|
||||
return std::make_pair(false, 0);
|
||||
}
|
||||
auto creators = iter->second;
|
||||
for (size_t index = 0; index < creators.size(); ++index) {
|
||||
auto attr_creator = creators[index];
|
||||
if (CPUKernelSingleAttrCheck(attr_creator.first, kernel_info)) {
|
||||
auto kernel_attrs = GetSupportedKernelAttrList(kernel_name);
|
||||
if (kernel_attrs[0].GetInputSize() == 0 && kernel_attrs[0].GetOutputSize() == 0) {
|
||||
auto op_info_ptr = mindspore::kernel::OpLib::FindOp(kernel_name, kernel::OpImplyType::kCPU);
|
||||
if (op_info_ptr == nullptr) {
|
||||
MS_LOG(ERROR) << "Not find op[" << kernel_name << "] in cpu";
|
||||
}
|
||||
kernel_attrs.clear();
|
||||
SetKernelAttrs(op_info_ptr, &kernel_attrs);
|
||||
kernel::CPUKernelFactory::GetInstance().UpdateKernelAttrs(kernel_name, kernel_attrs);
|
||||
}
|
||||
for (size_t index = 0; index < kernel_attrs.size(); ++index) {
|
||||
if (CPUKernelSingleAttrCheck(kernel_attrs[index], kernel_info)) {
|
||||
return std::make_pair(true, index);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -25,6 +25,7 @@
|
|||
|
||||
#include "utils/ms_utils.h"
|
||||
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
|
||||
#include "backend/kernel_compiler/oplib/oplib.h"
|
||||
#include "runtime/device/cpu/kernel_select_cpu.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -36,6 +37,8 @@ class CPUKernelFactory {
|
|||
static CPUKernelFactory &GetInstance();
|
||||
void Register(const std::string &kernel_name, const KernelAttr &kernel_attr, CPUKernelCreator &&kernel_creator);
|
||||
std::shared_ptr<CPUKernel> Create(const std::string &kernel_name, const CNodePtr &apply_kernel);
|
||||
void SetKernelAttrs(const std::shared_ptr<kernel::OpInfo> op_info, std::vector<KernelAttr> *kernel_attrs);
|
||||
void UpdateKernelAttrs(const std::string &kernel_name, const std::vector<KernelAttr> &kernel_attrs);
|
||||
std::vector<KernelAttr> GetSupportedKernelAttrList(const std::string &kernel_name);
|
||||
|
||||
private:
|
||||
|
|
|
@ -46,14 +46,7 @@ class DropoutCPUKernel : public CPUKernel {
|
|||
uint64_t tensor_size_ = 1;
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL(
|
||||
Dropout,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
||||
DropoutCPUKernel);
|
||||
MS_REG_CPU_KERNEL(
|
||||
Dropout,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
DropoutCPUKernel);
|
||||
MS_REG_CPU_KERNEL(Dropout, KernelAttr(), DropoutCPUKernel);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_DROPOUT_CPU_KERNEL_H_
|
||||
|
|
|
@ -43,14 +43,7 @@ class DropoutGradCpuBwdKernel : public CPUKernel {
|
|||
size_t num_count, float keep_prob);
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL(
|
||||
DropoutGrad,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
DropoutGradCpuBwdKernel);
|
||||
MS_REG_CPU_KERNEL(
|
||||
DropoutGrad,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
||||
DropoutGradCpuBwdKernel);
|
||||
MS_REG_CPU_KERNEL(DropoutGrad, KernelAttr(), DropoutGradCpuBwdKernel);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -33,10 +33,7 @@ class EqualCountCPUKernel : public CPUKernel {
|
|||
const std::vector<AddressPtr> &outputs) override;
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL(
|
||||
EqualCount,
|
||||
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
EqualCountCPUKernel);
|
||||
MS_REG_CPU_KERNEL(EqualCount, KernelAttr(), EqualCountCPUKernel);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -44,41 +44,17 @@ class GatherV2CPUKernel : public CPUKernel {
|
|||
int64_t axis_{0};
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
Gather, KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
|
||||
GatherV2CPUKernel, bool);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
Gather,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
|
||||
GatherV2CPUKernel, float);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
Gather,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64),
|
||||
GatherV2CPUKernel, double);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
Gather, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt8),
|
||||
GatherV2CPUKernel, int8_t);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
Gather, KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt16),
|
||||
GatherV2CPUKernel, int16_t);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
Gather, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
GatherV2CPUKernel, int32_t);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
Gather, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64),
|
||||
GatherV2CPUKernel, int64_t);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
Gather, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt8),
|
||||
GatherV2CPUKernel, uint8_t);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
Gather, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt16),
|
||||
GatherV2CPUKernel, uint16_t);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
Gather, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt32),
|
||||
GatherV2CPUKernel, uint32_t);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
Gather, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt64),
|
||||
GatherV2CPUKernel, uint64_t);
|
||||
MS_REG_CPU_KERNEL_T(Gather, KernelAttr(), GatherV2CPUKernel, uint8_t);
|
||||
MS_REG_CPU_KERNEL_T(Gather, KernelAttr(), GatherV2CPUKernel, uint16_t);
|
||||
MS_REG_CPU_KERNEL_T(Gather, KernelAttr(), GatherV2CPUKernel, uint32_t);
|
||||
MS_REG_CPU_KERNEL_T(Gather, KernelAttr(), GatherV2CPUKernel, uint64_t);
|
||||
MS_REG_CPU_KERNEL_T(Gather, KernelAttr(), GatherV2CPUKernel, int8_t);
|
||||
MS_REG_CPU_KERNEL_T(Gather, KernelAttr(), GatherV2CPUKernel, int16_t);
|
||||
MS_REG_CPU_KERNEL_T(Gather, KernelAttr(), GatherV2CPUKernel, int32_t);
|
||||
MS_REG_CPU_KERNEL_T(Gather, KernelAttr(), GatherV2CPUKernel, int64_t);
|
||||
MS_REG_CPU_KERNEL_T(Gather, KernelAttr(), GatherV2CPUKernel, float);
|
||||
MS_REG_CPU_KERNEL_T(Gather, KernelAttr(), GatherV2CPUKernel, double);
|
||||
MS_REG_CPU_KERNEL_T(Gather, KernelAttr(), GatherV2CPUKernel, bool);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -39,76 +39,16 @@ class GatherDCPUKernel : public CPUKernel {
|
|||
std::vector<size_t> output_shape_;
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL_T_S(GatherD,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
GatherDCPUKernel, float, int32_t);
|
||||
MS_REG_CPU_KERNEL_T_S(GatherD,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
GatherDCPUKernel, float, int64_t);
|
||||
MS_REG_CPU_KERNEL_T_S(GatherD,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeFloat16),
|
||||
GatherDCPUKernel, float16, int32_t);
|
||||
MS_REG_CPU_KERNEL_T_S(GatherD,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeFloat16),
|
||||
GatherDCPUKernel, float16, int64_t);
|
||||
MS_REG_CPU_KERNEL_T_S(GatherD,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeInt32),
|
||||
GatherDCPUKernel, int32_t, int32_t);
|
||||
MS_REG_CPU_KERNEL_T_S(GatherD,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeInt32),
|
||||
GatherDCPUKernel, int32_t, int64_t);
|
||||
MS_REG_CPU_KERNEL_T_S(GatherD,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeInt64),
|
||||
GatherDCPUKernel, int64_t, int32_t);
|
||||
MS_REG_CPU_KERNEL_T_S(GatherD,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeInt64),
|
||||
GatherDCPUKernel, int64_t, int64_t);
|
||||
MS_REG_CPU_KERNEL_T_S(GatherD,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeBool)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeBool),
|
||||
GatherDCPUKernel, bool, int32_t);
|
||||
MS_REG_CPU_KERNEL_T_S(GatherD,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeBool)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeBool),
|
||||
GatherDCPUKernel, bool, int64_t);
|
||||
MS_REG_CPU_KERNEL_T_S(GatherD, KernelAttr(), GatherDCPUKernel, float, int32_t);
|
||||
MS_REG_CPU_KERNEL_T_S(GatherD, KernelAttr(), GatherDCPUKernel, float, int64_t);
|
||||
MS_REG_CPU_KERNEL_T_S(GatherD, KernelAttr(), GatherDCPUKernel, float16, int32_t);
|
||||
MS_REG_CPU_KERNEL_T_S(GatherD, KernelAttr(), GatherDCPUKernel, float16, int64_t);
|
||||
MS_REG_CPU_KERNEL_T_S(GatherD, KernelAttr(), GatherDCPUKernel, int32_t, int32_t);
|
||||
MS_REG_CPU_KERNEL_T_S(GatherD, KernelAttr(), GatherDCPUKernel, int32_t, int64_t);
|
||||
MS_REG_CPU_KERNEL_T_S(GatherD, KernelAttr(), GatherDCPUKernel, int64_t, int32_t);
|
||||
MS_REG_CPU_KERNEL_T_S(GatherD, KernelAttr(), GatherDCPUKernel, int64_t, int64_t);
|
||||
MS_REG_CPU_KERNEL_T_S(GatherD, KernelAttr(), GatherDCPUKernel, bool, int32_t);
|
||||
MS_REG_CPU_KERNEL_T_S(GatherD, KernelAttr(), GatherDCPUKernel, bool, int64_t);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -40,44 +40,16 @@ class GatherDGradCPUKernel : public CPUKernel {
|
|||
int32_t axis_;
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL_T_S(
|
||||
GatherDGrad,
|
||||
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
GatherDGradCPUKernel, int32_t, int32_t);
|
||||
MS_REG_CPU_KERNEL_T_S(
|
||||
GatherDGrad,
|
||||
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
||||
GatherDGradCPUKernel, int32_t, int64_t);
|
||||
MS_REG_CPU_KERNEL_T_S(
|
||||
GatherDGrad,
|
||||
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
GatherDGradCPUKernel, int32_t, float);
|
||||
MS_REG_CPU_KERNEL_T_S(
|
||||
GatherDGrad,
|
||||
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
||||
GatherDGradCPUKernel, int32_t, float16);
|
||||
MS_REG_CPU_KERNEL_T_S(
|
||||
GatherDGrad, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
|
||||
GatherDGradCPUKernel, int32_t, bool);
|
||||
MS_REG_CPU_KERNEL_T_S(
|
||||
GatherDGrad,
|
||||
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
GatherDGradCPUKernel, int64_t, int32_t);
|
||||
MS_REG_CPU_KERNEL_T_S(
|
||||
GatherDGrad,
|
||||
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
||||
GatherDGradCPUKernel, int64_t, int64_t);
|
||||
MS_REG_CPU_KERNEL_T_S(
|
||||
GatherDGrad,
|
||||
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
GatherDGradCPUKernel, int64_t, float);
|
||||
MS_REG_CPU_KERNEL_T_S(
|
||||
GatherDGrad,
|
||||
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
||||
GatherDGradCPUKernel, int64_t, float16);
|
||||
MS_REG_CPU_KERNEL_T_S(
|
||||
GatherDGrad, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
|
||||
GatherDGradCPUKernel, int64_t, bool);
|
||||
MS_REG_CPU_KERNEL_T_S(GatherDGrad, KernelAttr(), GatherDGradCPUKernel, int32_t, int32_t);
|
||||
MS_REG_CPU_KERNEL_T_S(GatherDGrad, KernelAttr(), GatherDGradCPUKernel, int32_t, int64_t);
|
||||
MS_REG_CPU_KERNEL_T_S(GatherDGrad, KernelAttr(), GatherDGradCPUKernel, int32_t, float);
|
||||
MS_REG_CPU_KERNEL_T_S(GatherDGrad, KernelAttr(), GatherDGradCPUKernel, int32_t, float16);
|
||||
MS_REG_CPU_KERNEL_T_S(GatherDGrad, KernelAttr(), GatherDGradCPUKernel, int32_t, bool);
|
||||
MS_REG_CPU_KERNEL_T_S(GatherDGrad, KernelAttr(), GatherDGradCPUKernel, int64_t, int32_t);
|
||||
MS_REG_CPU_KERNEL_T_S(GatherDGrad, KernelAttr(), GatherDGradCPUKernel, int64_t, int64_t);
|
||||
MS_REG_CPU_KERNEL_T_S(GatherDGrad, KernelAttr(), GatherDGradCPUKernel, int64_t, float);
|
||||
MS_REG_CPU_KERNEL_T_S(GatherDGrad, KernelAttr(), GatherDGradCPUKernel, int64_t, float16);
|
||||
MS_REG_CPU_KERNEL_T_S(GatherDGrad, KernelAttr(), GatherDGradCPUKernel, int64_t, bool);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -47,20 +47,7 @@ class GatherNdCPUKernel : public CPUKernel {
|
|||
TypeId dtype_{kTypeUnknown};
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL(
|
||||
GatherNd, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
GatherNdCPUKernel);
|
||||
MS_REG_CPU_KERNEL(
|
||||
GatherNd, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64),
|
||||
GatherNdCPUKernel);
|
||||
MS_REG_CPU_KERNEL(
|
||||
GatherNd,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
|
||||
GatherNdCPUKernel);
|
||||
MS_REG_CPU_KERNEL(
|
||||
GatherNd,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64),
|
||||
GatherNdCPUKernel);
|
||||
MS_REG_CPU_KERNEL(GatherNd, KernelAttr(), GatherNdCPUKernel);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -41,20 +41,15 @@ class HSigmoidCPUKernel : public CPUKernel {
|
|||
uint64_t tensor_size_ = 1;
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL_T(HSigmoid, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
|
||||
HSigmoidCPUKernel, int8_t);
|
||||
MS_REG_CPU_KERNEL_T(HSigmoid, KernelAttr(), HSigmoidCPUKernel, int8_t);
|
||||
|
||||
MS_REG_CPU_KERNEL_T(HSigmoid, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
|
||||
HSigmoidCPUKernel, int16_t);
|
||||
MS_REG_CPU_KERNEL_T(HSigmoid, KernelAttr(), HSigmoidCPUKernel, int16_t);
|
||||
|
||||
MS_REG_CPU_KERNEL_T(HSigmoid, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
HSigmoidCPUKernel, int);
|
||||
MS_REG_CPU_KERNEL_T(HSigmoid, KernelAttr(), HSigmoidCPUKernel, int32_t);
|
||||
|
||||
MS_REG_CPU_KERNEL_T(HSigmoid, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
||||
HSigmoidCPUKernel, int64_t);
|
||||
MS_REG_CPU_KERNEL_T(HSigmoid, KernelAttr(), HSigmoidCPUKernel, int64_t);
|
||||
|
||||
MS_REG_CPU_KERNEL_T(HSigmoid, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
HSigmoidCPUKernel, float);
|
||||
MS_REG_CPU_KERNEL_T(HSigmoid, KernelAttr(), HSigmoidCPUKernel, float);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_TILE_CPU_KERNEL_H_
|
||||
|
|
|
@ -41,29 +41,11 @@ class HSigmoidGradCPUKernel : public CPUKernel {
|
|||
uint64_t tensor_size_ = 1;
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
HSigmoidGrad, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
|
||||
HSigmoidGradCPUKernel, int8_t);
|
||||
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
HSigmoidGrad,
|
||||
KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
|
||||
HSigmoidGradCPUKernel, int16_t);
|
||||
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
HSigmoidGrad,
|
||||
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
HSigmoidGradCPUKernel, int);
|
||||
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
HSigmoidGrad,
|
||||
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
||||
HSigmoidGradCPUKernel, int64_t);
|
||||
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
HSigmoidGrad,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
HSigmoidGradCPUKernel, float);
|
||||
MS_REG_CPU_KERNEL_T(HSigmoidGrad, KernelAttr(), HSigmoidGradCPUKernel, int8_t);
|
||||
MS_REG_CPU_KERNEL_T(HSigmoidGrad, KernelAttr(), HSigmoidGradCPUKernel, int16_t);
|
||||
MS_REG_CPU_KERNEL_T(HSigmoidGrad, KernelAttr(), HSigmoidGradCPUKernel, int32_t);
|
||||
MS_REG_CPU_KERNEL_T(HSigmoidGrad, KernelAttr(), HSigmoidGradCPUKernel, int64_t);
|
||||
MS_REG_CPU_KERNEL_T(HSigmoidGrad, KernelAttr(), HSigmoidGradCPUKernel, float);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_TILE_CPU_KERNEL_H_
|
||||
|
|
|
@ -41,20 +41,11 @@ class HSwishCPUKernel : public CPUKernel {
|
|||
uint64_t tensor_size_ = 1;
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL_T(HSwish, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8), HSwishCPUKernel,
|
||||
int8_t);
|
||||
|
||||
MS_REG_CPU_KERNEL_T(HSwish, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
|
||||
HSwishCPUKernel, int16_t);
|
||||
|
||||
MS_REG_CPU_KERNEL_T(HSwish, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
HSwishCPUKernel, int);
|
||||
|
||||
MS_REG_CPU_KERNEL_T(HSwish, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
||||
HSwishCPUKernel, int64_t);
|
||||
|
||||
MS_REG_CPU_KERNEL_T(HSwish, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
HSwishCPUKernel, float);
|
||||
MS_REG_CPU_KERNEL_T(HSwish, KernelAttr(), HSwishCPUKernel, int8_t);
|
||||
MS_REG_CPU_KERNEL_T(HSwish, KernelAttr(), HSwishCPUKernel, int16_t);
|
||||
MS_REG_CPU_KERNEL_T(HSwish, KernelAttr(), HSwishCPUKernel, int32_t);
|
||||
MS_REG_CPU_KERNEL_T(HSwish, KernelAttr(), HSwishCPUKernel, int64_t);
|
||||
MS_REG_CPU_KERNEL_T(HSwish, KernelAttr(), HSwishCPUKernel, float);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_TILE_CPU_KERNEL_H_
|
||||
|
|
|
@ -41,29 +41,11 @@ class HSwishGradCPUKernel : public CPUKernel {
|
|||
uint64_t tensor_size_ = 1;
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
HSwishGrad, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
|
||||
HSwishGradCPUKernel, int8_t);
|
||||
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
HSwishGrad,
|
||||
KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
|
||||
HSwishGradCPUKernel, int16_t);
|
||||
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
HSwishGrad,
|
||||
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
HSwishGradCPUKernel, int);
|
||||
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
HSwishGrad,
|
||||
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
||||
HSwishGradCPUKernel, int64_t);
|
||||
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
HSwishGrad,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
HSwishGradCPUKernel, float);
|
||||
MS_REG_CPU_KERNEL_T(HSwishGrad, KernelAttr(), HSwishGradCPUKernel, int8_t);
|
||||
MS_REG_CPU_KERNEL_T(HSwishGrad, KernelAttr(), HSwishGradCPUKernel, int16_t);
|
||||
MS_REG_CPU_KERNEL_T(HSwishGrad, KernelAttr(), HSwishGradCPUKernel, int32_t);
|
||||
MS_REG_CPU_KERNEL_T(HSwishGrad, KernelAttr(), HSwishGradCPUKernel, int64_t);
|
||||
MS_REG_CPU_KERNEL_T(HSwishGrad, KernelAttr(), HSwishGradCPUKernel, float);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_TILE_CPU_KERNEL_H_
|
||||
|
|
|
@ -51,41 +51,7 @@ class IsFiniteCPUKernel : public CPUKernel {
|
|||
TypeId input_dtype_{kTypeUnknown};
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL(IsFinite, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
|
||||
IsFiniteCPUKernel);
|
||||
|
||||
MS_REG_CPU_KERNEL(IsFinite, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeBool),
|
||||
IsFiniteCPUKernel);
|
||||
|
||||
MS_REG_CPU_KERNEL(IsFinite, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeBool),
|
||||
IsFiniteCPUKernel);
|
||||
|
||||
MS_REG_CPU_KERNEL(IsFinite, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
|
||||
IsFiniteCPUKernel);
|
||||
|
||||
MS_REG_CPU_KERNEL(IsFinite, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool),
|
||||
IsFiniteCPUKernel);
|
||||
|
||||
MS_REG_CPU_KERNEL(IsFinite, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool),
|
||||
IsFiniteCPUKernel);
|
||||
|
||||
MS_REG_CPU_KERNEL(IsFinite, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool),
|
||||
IsFiniteCPUKernel);
|
||||
|
||||
MS_REG_CPU_KERNEL(IsFinite, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeBool),
|
||||
IsFiniteCPUKernel);
|
||||
|
||||
MS_REG_CPU_KERNEL(IsFinite, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeBool),
|
||||
IsFiniteCPUKernel);
|
||||
|
||||
MS_REG_CPU_KERNEL(IsFinite, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeBool),
|
||||
IsFiniteCPUKernel);
|
||||
|
||||
MS_REG_CPU_KERNEL(IsFinite, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeBool),
|
||||
IsFiniteCPUKernel);
|
||||
|
||||
MS_REG_CPU_KERNEL(IsFinite, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeBool),
|
||||
IsFiniteCPUKernel);
|
||||
MS_REG_CPU_KERNEL(IsFinite, KernelAttr(), IsFiniteCPUKernel);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -46,25 +46,7 @@ class LayerNormCPUKernel : public CPUKernel {
|
|||
size_t param_num_{1};
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL(LayerNorm,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat16),
|
||||
LayerNormCPUKernel);
|
||||
|
||||
MS_REG_CPU_KERNEL(LayerNorm,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
LayerNormCPUKernel);
|
||||
MS_REG_CPU_KERNEL(LayerNorm, KernelAttr(), LayerNormCPUKernel);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_LAYER_NORM_CPU_KERNEL_H_
|
||||
|
|
|
@ -48,29 +48,7 @@ class LayerNormGradCPUKernel : public CPUKernel {
|
|||
size_t param_size_{1};
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL(LayerNormGrad,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat16),
|
||||
LayerNormGradCPUKernel);
|
||||
|
||||
MS_REG_CPU_KERNEL(LayerNormGrad,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
LayerNormGradCPUKernel);
|
||||
MS_REG_CPU_KERNEL(LayerNormGrad, KernelAttr(), LayerNormGradCPUKernel);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_LAYER_NORM_GRAD_CPU_KERNEL_H_
|
||||
|
|
|
@ -75,33 +75,12 @@ class MaximumCPUKernel : public CPUKernel {
|
|||
const size_t max_dims{7};
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
Maximum, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
MaximumCPUKernel, int32_t);
|
||||
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
Maximum,
|
||||
KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
|
||||
MaximumCPUKernel, uint32_t);
|
||||
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
Maximum,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
MaximumCPUKernel, float);
|
||||
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
Maximum, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
||||
MaximumCPUKernel, int64_t);
|
||||
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
Maximum,
|
||||
KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
|
||||
MaximumCPUKernel, uint64_t);
|
||||
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
Maximum,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
MaximumCPUKernel, double);
|
||||
MS_REG_CPU_KERNEL_T(Maximum, KernelAttr(), MaximumCPUKernel, int32_t);
|
||||
MS_REG_CPU_KERNEL_T(Maximum, KernelAttr(), MaximumCPUKernel, int64_t);
|
||||
MS_REG_CPU_KERNEL_T(Maximum, KernelAttr(), MaximumCPUKernel, uint32_t);
|
||||
MS_REG_CPU_KERNEL_T(Maximum, KernelAttr(), MaximumCPUKernel, uint64_t);
|
||||
MS_REG_CPU_KERNEL_T(Maximum, KernelAttr(), MaximumCPUKernel, float);
|
||||
MS_REG_CPU_KERNEL_T(Maximum, KernelAttr(), MaximumCPUKernel, double);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -47,59 +47,7 @@ class MaximumGradCPUKernel : public CPUKernel {
|
|||
TypeId dtype_{kTypeUnknown};
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL(MaximumGrad,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeInt32),
|
||||
MaximumGradCPUKernel);
|
||||
|
||||
MS_REG_CPU_KERNEL(MaximumGrad,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeUInt32)
|
||||
.AddInputAttr(kNumberTypeUInt32)
|
||||
.AddInputAttr(kNumberTypeUInt32)
|
||||
.AddOutputAttr(kNumberTypeUInt32)
|
||||
.AddOutputAttr(kNumberTypeUInt32),
|
||||
MaximumGradCPUKernel);
|
||||
|
||||
MS_REG_CPU_KERNEL(MaximumGrad,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
MaximumGradCPUKernel);
|
||||
|
||||
MS_REG_CPU_KERNEL(MaximumGrad,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeInt64),
|
||||
MaximumGradCPUKernel);
|
||||
|
||||
MS_REG_CPU_KERNEL(MaximumGrad,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeUInt64)
|
||||
.AddInputAttr(kNumberTypeUInt64)
|
||||
.AddInputAttr(kNumberTypeUInt64)
|
||||
.AddOutputAttr(kNumberTypeUInt64)
|
||||
.AddOutputAttr(kNumberTypeUInt64),
|
||||
MaximumGradCPUKernel);
|
||||
|
||||
MS_REG_CPU_KERNEL(MaximumGrad,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddOutputAttr(kNumberTypeFloat64)
|
||||
.AddOutputAttr(kNumberTypeFloat64),
|
||||
MaximumGradCPUKernel);
|
||||
MS_REG_CPU_KERNEL(MaximumGrad, KernelAttr(), MaximumGradCPUKernel);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MaximumGrad_CPU_KERNEL_H_
|
||||
|
|
|
@ -75,33 +75,12 @@ class MinimumCPUKernel : public CPUKernel {
|
|||
const size_t max_dims{7};
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
Minimum, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
MinimumCPUKernel, int32_t);
|
||||
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
Minimum,
|
||||
KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
|
||||
MinimumCPUKernel, uint32_t);
|
||||
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
Minimum,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
MinimumCPUKernel, float);
|
||||
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
Minimum, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
||||
MinimumCPUKernel, int64_t);
|
||||
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
Minimum,
|
||||
KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
|
||||
MinimumCPUKernel, uint64_t);
|
||||
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
Minimum,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
MinimumCPUKernel, double);
|
||||
MS_REG_CPU_KERNEL_T(Minimum, KernelAttr(), MinimumCPUKernel, int32_t);
|
||||
MS_REG_CPU_KERNEL_T(Minimum, KernelAttr(), MinimumCPUKernel, uint32_t);
|
||||
MS_REG_CPU_KERNEL_T(Minimum, KernelAttr(), MinimumCPUKernel, float);
|
||||
MS_REG_CPU_KERNEL_T(Minimum, KernelAttr(), MinimumCPUKernel, int64_t);
|
||||
MS_REG_CPU_KERNEL_T(Minimum, KernelAttr(), MinimumCPUKernel, uint64_t);
|
||||
MS_REG_CPU_KERNEL_T(Minimum, KernelAttr(), MinimumCPUKernel, double);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -47,59 +47,7 @@ class MinimumGradCPUKernel : public CPUKernel {
|
|||
TypeId dtype_{kTypeUnknown};
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL(MinimumGrad,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeInt32),
|
||||
MinimumGradCPUKernel);
|
||||
|
||||
MS_REG_CPU_KERNEL(MinimumGrad,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeUInt32)
|
||||
.AddInputAttr(kNumberTypeUInt32)
|
||||
.AddInputAttr(kNumberTypeUInt32)
|
||||
.AddOutputAttr(kNumberTypeUInt32)
|
||||
.AddOutputAttr(kNumberTypeUInt32),
|
||||
MinimumGradCPUKernel);
|
||||
|
||||
MS_REG_CPU_KERNEL(MinimumGrad,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
MinimumGradCPUKernel);
|
||||
|
||||
MS_REG_CPU_KERNEL(MinimumGrad,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeInt64),
|
||||
MinimumGradCPUKernel);
|
||||
|
||||
MS_REG_CPU_KERNEL(MinimumGrad,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeUInt64)
|
||||
.AddInputAttr(kNumberTypeUInt64)
|
||||
.AddInputAttr(kNumberTypeUInt64)
|
||||
.AddOutputAttr(kNumberTypeUInt64)
|
||||
.AddOutputAttr(kNumberTypeUInt64),
|
||||
MinimumGradCPUKernel);
|
||||
|
||||
MS_REG_CPU_KERNEL(MinimumGrad,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddOutputAttr(kNumberTypeFloat64)
|
||||
.AddOutputAttr(kNumberTypeFloat64),
|
||||
MinimumGradCPUKernel);
|
||||
MS_REG_CPU_KERNEL(MinimumGrad, KernelAttr(), MinimumGradCPUKernel);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MinimumGrad_CPU_KERNEL_H_
|
||||
|
|
|
@ -64,19 +64,7 @@ class MirrorPadCPUKernel : public CPUKernel {
|
|||
int num_paddings_;
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL(
|
||||
MirrorPad,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16),
|
||||
MirrorPadCPUKernel);
|
||||
|
||||
MS_REG_CPU_KERNEL(
|
||||
MirrorPad,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32),
|
||||
MirrorPadCPUKernel);
|
||||
|
||||
MS_REG_CPU_KERNEL(
|
||||
MirrorPad, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32),
|
||||
MirrorPadCPUKernel);
|
||||
MS_REG_CPU_KERNEL(MirrorPad, KernelAttr(), MirrorPadCPUKernel);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MIRROR_PAD_CPU_KERNEL_H_
|
||||
|
|
|
@ -80,20 +80,7 @@ class MirrorPadGradCPUKernel : public CPUKernel {
|
|||
int num_paddings_;
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL(
|
||||
MirrorPadGrad,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16),
|
||||
MirrorPadGradCPUKernel);
|
||||
|
||||
MS_REG_CPU_KERNEL(
|
||||
MirrorPadGrad,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32),
|
||||
MirrorPadGradCPUKernel);
|
||||
|
||||
MS_REG_CPU_KERNEL(
|
||||
MirrorPadGrad,
|
||||
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32),
|
||||
MirrorPadGradCPUKernel);
|
||||
MS_REG_CPU_KERNEL(MirrorPadGrad, KernelAttr(), MirrorPadGradCPUKernel);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MIRROR_PAD_CPU_KERNEL_H_
|
||||
|
|
|
@ -33,10 +33,7 @@ class ConvCPUKernel : public MKLCPUKernel {
|
|||
const std::vector<AddressPtr> &outputs) override;
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL(
|
||||
Conv2D,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
ConvCPUKernel);
|
||||
MS_REG_CPU_KERNEL(Conv2D, KernelAttr(), ConvCPUKernel);
|
||||
MS_REG_CPU_KERNEL(
|
||||
Conv3D,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "backend/kernel_compiler/cpu/arithmetic_cpu_kernel.h"
|
||||
#include "backend/kernel_compiler/cpu/mkldnn/mkl_cpu_kernel.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -36,9 +37,8 @@ class MulCPUKernel : public MKLCPUKernel {
|
|||
bool need_swap_{false};
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL(
|
||||
Mul, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
MulCPUKernel);
|
||||
MS_REG_CPU_KERNEL(Mul, KernelAttr(), MulCPUKernel);
|
||||
MS_REG_CPU_KERNEL_T(Mul, KernelAttr(), ArithmeticCPUKernel, int32_t);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -38,13 +38,7 @@ class OneHotCPUKernel : public CPUKernel {
|
|||
size_t axis_;
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL(OneHot,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
OneHotCPUKernel);
|
||||
MS_REG_CPU_KERNEL(OneHot, KernelAttr(), OneHotCPUKernel);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -44,39 +44,17 @@ class PackCpuFwdKernel : public CPUKernel {
|
|||
std::unique_ptr<T *[]> inputs_host_;
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL_T(Stack,
|
||||
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
|
||||
PackCpuFwdKernel, int8_t)
|
||||
MS_REG_CPU_KERNEL_T(Stack,
|
||||
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
|
||||
PackCpuFwdKernel, int16_t)
|
||||
MS_REG_CPU_KERNEL_T(Stack,
|
||||
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
PackCpuFwdKernel, int32_t)
|
||||
MS_REG_CPU_KERNEL_T(Stack,
|
||||
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
||||
PackCpuFwdKernel, int64_t)
|
||||
MS_REG_CPU_KERNEL_T(Stack,
|
||||
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
|
||||
PackCpuFwdKernel, uint8_t)
|
||||
MS_REG_CPU_KERNEL_T(Stack,
|
||||
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
|
||||
PackCpuFwdKernel, bool)
|
||||
MS_REG_CPU_KERNEL_T(Stack,
|
||||
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16),
|
||||
PackCpuFwdKernel, uint16_t)
|
||||
MS_REG_CPU_KERNEL_T(Stack,
|
||||
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
|
||||
PackCpuFwdKernel, uint32_t)
|
||||
MS_REG_CPU_KERNEL_T(Stack,
|
||||
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
|
||||
PackCpuFwdKernel, uint64_t)
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
Stack, KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
||||
PackCpuFwdKernel, float16)
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
Stack, KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
PackCpuFwdKernel, float)
|
||||
MS_REG_CPU_KERNEL_T(Stack, KernelAttr(), PackCpuFwdKernel, int8_t)
|
||||
MS_REG_CPU_KERNEL_T(Stack, KernelAttr(), PackCpuFwdKernel, int16_t)
|
||||
MS_REG_CPU_KERNEL_T(Stack, KernelAttr(), PackCpuFwdKernel, int32_t)
|
||||
MS_REG_CPU_KERNEL_T(Stack, KernelAttr(), PackCpuFwdKernel, int64_t)
|
||||
MS_REG_CPU_KERNEL_T(Stack, KernelAttr(), PackCpuFwdKernel, uint8_t)
|
||||
MS_REG_CPU_KERNEL_T(Stack, KernelAttr(), PackCpuFwdKernel, uint16_t)
|
||||
MS_REG_CPU_KERNEL_T(Stack, KernelAttr(), PackCpuFwdKernel, uint32_t)
|
||||
MS_REG_CPU_KERNEL_T(Stack, KernelAttr(), PackCpuFwdKernel, uint64_t)
|
||||
MS_REG_CPU_KERNEL_T(Stack, KernelAttr(), PackCpuFwdKernel, float16)
|
||||
MS_REG_CPU_KERNEL_T(Stack, KernelAttr(), PackCpuFwdKernel, float)
|
||||
MS_REG_CPU_KERNEL_T(Stack, KernelAttr(), PackCpuFwdKernel, bool)
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_PACK_CPU_KERNEL_H
|
||||
|
|
|
@ -48,11 +48,7 @@ class PadCPUKernel : public CPUKernel {
|
|||
std::vector<size_t> output_shape_;
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL(Pad, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), PadCPUKernel);
|
||||
|
||||
MS_REG_CPU_KERNEL(Pad, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), PadCPUKernel);
|
||||
|
||||
MS_REG_CPU_KERNEL(Pad, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), PadCPUKernel);
|
||||
MS_REG_CPU_KERNEL(Pad, KernelAttr(), PadCPUKernel);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_PAD_CPU_KERNEL_H_
|
||||
|
|
|
@ -41,12 +41,7 @@ class RangeCPUKernel : public CPUKernel {
|
|||
int64_t delta_;
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL(Range, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), RangeCPUKernel);
|
||||
MS_REG_CPU_KERNEL(Range, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), RangeCPUKernel);
|
||||
MS_REG_CPU_KERNEL(Range, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
RangeCPUKernel);
|
||||
MS_REG_CPU_KERNEL(Range, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
RangeCPUKernel);
|
||||
MS_REG_CPU_KERNEL(Range, KernelAttr(), RangeCPUKernel);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -41,47 +41,29 @@ class ReduceCPUKernel : public CPUKernel {
|
|||
std::function<void(const T *, size_t, T *)> reduce_func_;
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL_T(ReduceMean, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
ReduceCPUKernel, float);
|
||||
MS_REG_CPU_KERNEL_T(ReduceMean, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
ReduceCPUKernel, double);
|
||||
MS_REG_CPU_KERNEL_T(ReduceMean, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
ReduceCPUKernel, int32_t);
|
||||
MS_REG_CPU_KERNEL_T(ReduceMean, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
||||
ReduceCPUKernel, int64_t);
|
||||
MS_REG_CPU_KERNEL_T(ReduceMean, KernelAttr(), ReduceCPUKernel, float);
|
||||
MS_REG_CPU_KERNEL_T(ReduceMean, KernelAttr(), ReduceCPUKernel, double);
|
||||
MS_REG_CPU_KERNEL_T(ReduceMean, KernelAttr(), ReduceCPUKernel, int32_t);
|
||||
MS_REG_CPU_KERNEL_T(ReduceMean, KernelAttr(), ReduceCPUKernel, int64_t);
|
||||
|
||||
MS_REG_CPU_KERNEL_T(ReduceMax, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
ReduceCPUKernel, float);
|
||||
MS_REG_CPU_KERNEL_T(ReduceMax, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
ReduceCPUKernel, double);
|
||||
MS_REG_CPU_KERNEL_T(ReduceMax, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
ReduceCPUKernel, int32_t);
|
||||
MS_REG_CPU_KERNEL_T(ReduceMax, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
||||
ReduceCPUKernel, int64_t);
|
||||
MS_REG_CPU_KERNEL_T(ReduceMax, KernelAttr(), ReduceCPUKernel, float);
|
||||
MS_REG_CPU_KERNEL_T(ReduceMax, KernelAttr(), ReduceCPUKernel, double);
|
||||
MS_REG_CPU_KERNEL_T(ReduceMax, KernelAttr(), ReduceCPUKernel, int32_t);
|
||||
MS_REG_CPU_KERNEL_T(ReduceMax, KernelAttr(), ReduceCPUKernel, int64_t);
|
||||
|
||||
MS_REG_CPU_KERNEL_T(ReduceSum, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
ReduceCPUKernel, float);
|
||||
MS_REG_CPU_KERNEL_T(ReduceSum, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
ReduceCPUKernel, double);
|
||||
MS_REG_CPU_KERNEL_T(ReduceSum, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
ReduceCPUKernel, int32_t);
|
||||
MS_REG_CPU_KERNEL_T(ReduceSum, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
||||
ReduceCPUKernel, int64_t);
|
||||
MS_REG_CPU_KERNEL_T(ReduceSum, KernelAttr(), ReduceCPUKernel, float);
|
||||
MS_REG_CPU_KERNEL_T(ReduceSum, KernelAttr(), ReduceCPUKernel, double);
|
||||
MS_REG_CPU_KERNEL_T(ReduceSum, KernelAttr(), ReduceCPUKernel, int32_t);
|
||||
MS_REG_CPU_KERNEL_T(ReduceSum, KernelAttr(), ReduceCPUKernel, int64_t);
|
||||
|
||||
MS_REG_CPU_KERNEL_T(ReduceMin, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
ReduceCPUKernel, float);
|
||||
MS_REG_CPU_KERNEL_T(ReduceMin, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
ReduceCPUKernel, double);
|
||||
MS_REG_CPU_KERNEL_T(ReduceMin, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
ReduceCPUKernel, int32_t);
|
||||
MS_REG_CPU_KERNEL_T(ReduceMin, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
||||
ReduceCPUKernel, int64_t);
|
||||
MS_REG_CPU_KERNEL_T(ReduceMin, KernelAttr(), ReduceCPUKernel, float);
|
||||
MS_REG_CPU_KERNEL_T(ReduceMin, KernelAttr(), ReduceCPUKernel, double);
|
||||
MS_REG_CPU_KERNEL_T(ReduceMin, KernelAttr(), ReduceCPUKernel, int32_t);
|
||||
MS_REG_CPU_KERNEL_T(ReduceMin, KernelAttr(), ReduceCPUKernel, int64_t);
|
||||
|
||||
MS_REG_CPU_KERNEL_T(ReduceAll, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
|
||||
ReduceCPUKernel, bool);
|
||||
MS_REG_CPU_KERNEL_T(ReduceAll, KernelAttr(), ReduceCPUKernel, bool);
|
||||
|
||||
MS_REG_CPU_KERNEL_T(ReduceAny, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
|
||||
ReduceCPUKernel, bool);
|
||||
MS_REG_CPU_KERNEL_T(ReduceAny, KernelAttr(), ReduceCPUKernel, bool);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_REDUCE_CPU_KERNEL_H_
|
||||
|
|
|
@ -57,24 +57,12 @@ class SplitCPUKernel : public CPUKernel {
|
|||
TypeId dtype_{kTypeUnknown};
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
Split, KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
SplitCPUKernel, float);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
Split, KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
||||
SplitCPUKernel, float16);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
Split, KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
SplitCPUKernel, double);
|
||||
MS_REG_CPU_KERNEL_T(Split,
|
||||
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
SplitCPUKernel, int32_t);
|
||||
MS_REG_CPU_KERNEL_T(Split,
|
||||
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
|
||||
SplitCPUKernel, uint32_t);
|
||||
MS_REG_CPU_KERNEL_T(Split,
|
||||
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
||||
SplitCPUKernel, int64_t);
|
||||
MS_REG_CPU_KERNEL_T(Split, KernelAttr(), SplitCPUKernel, float);
|
||||
MS_REG_CPU_KERNEL_T(Split, KernelAttr(), SplitCPUKernel, float16);
|
||||
MS_REG_CPU_KERNEL_T(Split, KernelAttr(), SplitCPUKernel, double);
|
||||
MS_REG_CPU_KERNEL_T(Split, KernelAttr(), SplitCPUKernel, int32_t);
|
||||
MS_REG_CPU_KERNEL_T(Split, KernelAttr(), SplitCPUKernel, uint32_t);
|
||||
MS_REG_CPU_KERNEL_T(Split, KernelAttr(), SplitCPUKernel, int64_t);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -40,15 +40,9 @@ class TensorAddCPUKernel : public CPUKernel {
|
|||
std::vector<size_t> output_shape_;
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
Add, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
TensorAddCPUKernel, float);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
Add, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
TensorAddCPUKernel, int);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
Add, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
|
||||
TensorAddCPUKernel, uint32_t);
|
||||
MS_REG_CPU_KERNEL_T(Add, KernelAttr(), TensorAddCPUKernel, int32_t);
|
||||
MS_REG_CPU_KERNEL_T(Add, KernelAttr(), TensorAddCPUKernel, uint32_t);
|
||||
MS_REG_CPU_KERNEL_T(Add, KernelAttr(), TensorAddCPUKernel, float);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -57,25 +57,7 @@ class TileCPUKernel : public CPUKernel {
|
|||
size_t input_size_;
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL(Tile, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8), TileCPUKernel);
|
||||
|
||||
MS_REG_CPU_KERNEL(Tile, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16), TileCPUKernel);
|
||||
|
||||
MS_REG_CPU_KERNEL(Tile, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), TileCPUKernel);
|
||||
|
||||
MS_REG_CPU_KERNEL(Tile, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), TileCPUKernel);
|
||||
|
||||
MS_REG_CPU_KERNEL(Tile, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), TileCPUKernel);
|
||||
|
||||
MS_REG_CPU_KERNEL(Tile, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8), TileCPUKernel);
|
||||
|
||||
MS_REG_CPU_KERNEL(Tile, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16), TileCPUKernel);
|
||||
|
||||
MS_REG_CPU_KERNEL(Tile, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32), TileCPUKernel);
|
||||
|
||||
MS_REG_CPU_KERNEL(Tile, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64), TileCPUKernel);
|
||||
|
||||
MS_REG_CPU_KERNEL(Tile, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), TileCPUKernel);
|
||||
MS_REG_CPU_KERNEL(Tile, KernelAttr(), TileCPUKernel);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_TILE_CPU_KERNEL_H_
|
||||
|
|
|
@ -40,20 +40,7 @@ class TopKCPUKernel : public CPUKernel {
|
|||
TypeId dtype_{kTypeUnknown};
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL(TopK,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeInt32),
|
||||
TopKCPUKernel)
|
||||
MS_REG_CPU_KERNEL(TopK,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeInt32),
|
||||
TopKCPUKernel)
|
||||
MS_REG_CPU_KERNEL(TopK, KernelAttr(), TopKCPUKernel)
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_TOPK_CPU_KERNEL_H_
|
||||
|
|
|
@ -52,36 +52,7 @@ class TransposeCPUFwdKernel : public CPUKernel {
|
|||
std::unordered_map<TypeId, TypeKernel> launch_map_;
|
||||
TypeKernel launch_func_;
|
||||
};
|
||||
MS_REG_CPU_KERNEL(Transpose,
|
||||
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
TransposeCPUFwdKernel);
|
||||
MS_REG_CPU_KERNEL(Transpose,
|
||||
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
|
||||
TransposeCPUFwdKernel);
|
||||
MS_REG_CPU_KERNEL(Transpose,
|
||||
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
|
||||
TransposeCPUFwdKernel);
|
||||
MS_REG_CPU_KERNEL(Transpose,
|
||||
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
TransposeCPUFwdKernel);
|
||||
MS_REG_CPU_KERNEL(Transpose,
|
||||
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
||||
TransposeCPUFwdKernel);
|
||||
MS_REG_CPU_KERNEL(Transpose,
|
||||
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
|
||||
TransposeCPUFwdKernel);
|
||||
MS_REG_CPU_KERNEL(Transpose,
|
||||
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16),
|
||||
TransposeCPUFwdKernel);
|
||||
MS_REG_CPU_KERNEL(Transpose,
|
||||
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
|
||||
TransposeCPUFwdKernel);
|
||||
MS_REG_CPU_KERNEL(Transpose,
|
||||
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
|
||||
TransposeCPUFwdKernel);
|
||||
MS_REG_CPU_KERNEL(Transpose,
|
||||
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
|
||||
TransposeCPUFwdKernel);
|
||||
MS_REG_CPU_KERNEL(Transpose, KernelAttr(), TransposeCPUFwdKernel);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_TRANSPOSE_CPU_KERNEL_H_
|
||||
|
|
|
@ -35,7 +35,8 @@ enum KernelType : int {
|
|||
RT_KERNEL,
|
||||
HCCL_KERNEL,
|
||||
TBE_KERNEL,
|
||||
HOST_KERNEL
|
||||
HOST_KERNEL,
|
||||
CPU_KERNEL,
|
||||
};
|
||||
|
||||
namespace kernel {
|
||||
|
|
|
@ -25,7 +25,7 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
enum OpImplyType { kAKG = 0, kTBE = 1, kAICPU };
|
||||
enum OpImplyType { kAKG = 0, kTBE = 1, kAICPU = 2, kCPU };
|
||||
enum OpIOType { kInput = 0, kOutput };
|
||||
|
||||
class OpAttr {
|
||||
|
|
|
@ -50,6 +50,7 @@ constexpr auto kAiCore = "AiCore";
|
|||
constexpr auto kCUDA = "CUDA";
|
||||
constexpr auto kTbe = "TBE";
|
||||
constexpr auto kAkg = "AKG";
|
||||
constexpr auto kCpu = "CPU";
|
||||
constexpr auto kName = "name";
|
||||
constexpr auto kParamType = "param_type";
|
||||
constexpr auto kDtype = "dtype";
|
||||
|
@ -90,6 +91,9 @@ bool OpLib::RegOp(const std::string &json_string, const std::string &impl_path)
|
|||
} else if (imply_type_string == kAiCPU) {
|
||||
OpImplyType imply_type = kAICPU;
|
||||
ret = DecodeOpInfo(op_json, imply_type, impl_path);
|
||||
} else if (imply_type_string == kCpu) {
|
||||
OpImplyType imply_type = kCPU;
|
||||
ret = DecodeOpInfo(op_json, imply_type, impl_path);
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Not support imply_type";
|
||||
}
|
||||
|
|
|
@ -18,7 +18,11 @@
|
|||
#include <string>
|
||||
#include <memory>
|
||||
#include <algorithm>
|
||||
#include "backend/kernel_compiler/common_utils.h"
|
||||
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
|
||||
#include "backend/kernel_compiler/kernel_build_info.h"
|
||||
#include "backend/kernel_compiler/oplib/opinfo.h"
|
||||
#include "backend/kernel_compiler/oplib/oplib.h"
|
||||
#include "utils/trace_base.h"
|
||||
namespace mindspore {
|
||||
namespace device {
|
||||
|
@ -35,19 +39,17 @@ bool IsInputNotCNode(const CNodePtr &kernel_node, size_t input_index) {
|
|||
return false;
|
||||
}
|
||||
|
||||
void GetOutputInferFormatsAndDtypes(const CNodePtr &kernel_node, std::vector<std::string> *output_formats,
|
||||
std::vector<TypeId> *output_types) {
|
||||
void GetOutputDtypes(const CNodePtr &kernel_node, std::vector<TypeId> *output_types) {
|
||||
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
|
||||
for (size_t output_index = 0; output_index < output_num; ++output_index) {
|
||||
TypeId dtype = kTypeUnknown;
|
||||
dtype = AnfAlgo::GetOutputInferDataType(kernel_node, output_index);
|
||||
output_formats->emplace_back(kOpFormat_DEFAULT);
|
||||
output_types->emplace_back(dtype);
|
||||
}
|
||||
}
|
||||
|
||||
void GetInputFormatsAndDtypes(const CNodePtr &kernel_node, std::vector<std::string> *input_formats,
|
||||
std::vector<TypeId> *input_types, std::vector<size_t> *input_no_cnode_indexes) {
|
||||
void GetInputDtypes(const CNodePtr &kernel_node, std::vector<TypeId> *input_types,
|
||||
std::vector<size_t> *input_no_cnode_indexes) {
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
for (size_t input_index = 0; input_index < input_num; ++input_index) {
|
||||
TypeId dtype = kTypeUnknown;
|
||||
|
@ -57,7 +59,6 @@ void GetInputFormatsAndDtypes(const CNodePtr &kernel_node, std::vector<std::stri
|
|||
} else {
|
||||
dtype = AnfAlgo::GetPrevNodeOutputDeviceDataType(kernel_node, input_index);
|
||||
}
|
||||
input_formats->emplace_back(kOpFormat_DEFAULT);
|
||||
input_types->emplace_back(dtype);
|
||||
}
|
||||
}
|
||||
|
@ -86,16 +87,13 @@ bool InputDtypeMatch(TypeId InputAttr, TypeId input_type, bool strict) {
|
|||
return false;
|
||||
}
|
||||
|
||||
std::pair<int, int> GetOutputDtypeFormatMatchedNum(const KernelAttr &kernel_attr,
|
||||
const std::vector<std::string> &output_formats,
|
||||
const std::vector<TypeId> &output_types) {
|
||||
int GetOutputDtypeMatchedNum(const KernelAttr &kernel_attr, const std::vector<TypeId> &output_types) {
|
||||
if (kernel_attr.GetOutputSize() != output_types.size()) {
|
||||
MS_LOG(DEBUG) << "required output num:" << kernel_attr.GetInputSize()
|
||||
<< ", actual output num:" << output_types.size();
|
||||
return std::make_pair(0, 0);
|
||||
return 0;
|
||||
}
|
||||
int data_type_matched_num = 0;
|
||||
int format_matched_num = 0;
|
||||
auto output_num = output_types.size();
|
||||
for (size_t i = 0; i < output_num; ++i) {
|
||||
if (kernel_attr.GetOutputAttr(i).first != output_types[i]) {
|
||||
|
@ -104,27 +102,17 @@ std::pair<int, int> GetOutputDtypeFormatMatchedNum(const KernelAttr &kernel_attr
|
|||
} else {
|
||||
data_type_matched_num++;
|
||||
}
|
||||
|
||||
if (kernel_attr.GetOutputAttr(i).second != output_formats[i]) {
|
||||
MS_LOG(DEBUG) << "required format:" << kernel_attr.GetOutputAttr(i).second
|
||||
<< ", actual output format:" << output_formats[i];
|
||||
} else {
|
||||
format_matched_num++;
|
||||
}
|
||||
}
|
||||
return std::make_pair(data_type_matched_num, format_matched_num);
|
||||
return data_type_matched_num;
|
||||
}
|
||||
|
||||
std::pair<int, int> GetInputDtypeFormatMatchedNum(const KernelAttr &kernel_attr,
|
||||
const std::vector<std::string> &input_formats,
|
||||
const std::vector<TypeId> &input_types,
|
||||
const std::vector<size_t> &input_not_cnode_indexes, bool strict) {
|
||||
int GetInputDtypeFormatMatchedNum(const KernelAttr &kernel_attr, const std::vector<TypeId> &input_types,
|
||||
const std::vector<size_t> &input_not_cnode_indexes, bool strict) {
|
||||
if (kernel_attr.GetInputSize() != input_types.size()) {
|
||||
MS_LOG(DEBUG) << "required input num:" << kernel_attr.GetInputSize() << ", actual input num:" << input_types.size();
|
||||
return std::make_pair(0, 0);
|
||||
return 0;
|
||||
}
|
||||
int data_type_matched_num = 0;
|
||||
int format_matched_num = 0;
|
||||
auto input_num = input_types.size();
|
||||
for (size_t i = 0; i < input_num; ++i) {
|
||||
if (!InputDtypeMatch(kernel_attr.GetInputAttr(i).first, input_types[i], strict)) {
|
||||
|
@ -132,10 +120,9 @@ std::pair<int, int> GetInputDtypeFormatMatchedNum(const KernelAttr &kernel_attr,
|
|||
<< ", actual input dtype:" << input_types[i];
|
||||
} else {
|
||||
data_type_matched_num++;
|
||||
format_matched_num++;
|
||||
}
|
||||
}
|
||||
return std::make_pair(data_type_matched_num, format_matched_num);
|
||||
return data_type_matched_num;
|
||||
}
|
||||
|
||||
void ExpandKernelAttr(const CNodePtr &kernel_node, KernelAttr *kernel_attr) {
|
||||
|
@ -197,12 +184,9 @@ void KernelNotSupportException(const AnfNodePtr &kernel_node, const std::vector<
|
|||
}
|
||||
} // namespace
|
||||
bool SelectKernel(const CNodePtr &kernel_node, KernelAttr *selected_kernel_attr,
|
||||
const std::vector<KernelAttr> &kernel_attrs, const std::vector<std::string> &input_formats,
|
||||
const std::vector<TypeId> &input_types, const std::vector<size_t> &input_not_cnode_indexes,
|
||||
const std::vector<std::string> &infer_output_formats, const std::vector<TypeId> &infer_output_types,
|
||||
const std::vector<KernelAttr> &kernel_attrs, const std::vector<TypeId> &input_types,
|
||||
const std::vector<size_t> &input_not_cnode_indexes, const std::vector<TypeId> &output_types,
|
||||
std::pair<bool, bool> *matched, bool strict) {
|
||||
int max_type_matched_num = -1;
|
||||
int max_format_matched_num = -1;
|
||||
for (auto kernel_attr : kernel_attrs) {
|
||||
if (kernel_attr.GetAllSame()) {
|
||||
ExpandKernelAttr(kernel_node, &kernel_attr);
|
||||
|
@ -212,32 +196,14 @@ bool SelectKernel(const CNodePtr &kernel_node, KernelAttr *selected_kernel_attr,
|
|||
MS_LOG(DEBUG) << "Output num is not equal!";
|
||||
continue;
|
||||
}
|
||||
std::pair<int, int> input_type_format_matched_num =
|
||||
GetInputDtypeFormatMatchedNum(kernel_attr, input_formats, input_types, input_not_cnode_indexes, strict);
|
||||
std::pair<int, int> output_type_format_matched_num =
|
||||
GetOutputDtypeFormatMatchedNum(kernel_attr, infer_output_formats, infer_output_types);
|
||||
// Data type first
|
||||
if (input_type_format_matched_num.first > max_type_matched_num) {
|
||||
max_type_matched_num = input_type_format_matched_num.first;
|
||||
max_format_matched_num = input_type_format_matched_num.second;
|
||||
*selected_kernel_attr = kernel_attr;
|
||||
} else if (input_type_format_matched_num.first == max_type_matched_num &&
|
||||
input_type_format_matched_num.second > max_format_matched_num) {
|
||||
max_format_matched_num = input_type_format_matched_num.second;
|
||||
*selected_kernel_attr = kernel_attr;
|
||||
} else if (input_type_format_matched_num.first == max_type_matched_num &&
|
||||
input_type_format_matched_num.second == max_format_matched_num) {
|
||||
if (output_type_format_matched_num.first == SizeToInt(infer_output_types.size()) &&
|
||||
output_type_format_matched_num.second == SizeToInt(infer_output_types.size())) {
|
||||
*selected_kernel_attr = kernel_attr;
|
||||
}
|
||||
}
|
||||
int input_dtype_matched_num =
|
||||
GetInputDtypeFormatMatchedNum(kernel_attr, input_types, input_not_cnode_indexes, strict);
|
||||
int output_dtype_matched_num = GetOutputDtypeMatchedNum(kernel_attr, output_types);
|
||||
// All formats and data types matched
|
||||
if (input_type_format_matched_num.first == SizeToInt(input_types.size()) &&
|
||||
input_type_format_matched_num.second == SizeToInt(input_types.size())) {
|
||||
if (input_dtype_matched_num == SizeToInt(input_types.size())) {
|
||||
*selected_kernel_attr = kernel_attr;
|
||||
matched->first = true;
|
||||
if (output_type_format_matched_num.first == SizeToInt(infer_output_types.size()) &&
|
||||
output_type_format_matched_num.second == SizeToInt(infer_output_types.size())) {
|
||||
if (output_dtype_matched_num == SizeToInt(output_types.size())) {
|
||||
matched->second = true;
|
||||
return true;
|
||||
}
|
||||
|
@ -249,43 +215,50 @@ void SetKernelInfo(const CNodePtr &kernel_node) {
|
|||
std::vector<std::string> input_formats;
|
||||
std::vector<TypeId> input_types;
|
||||
std::vector<size_t> input_not_cnode_indexes;
|
||||
std::vector<std::string> output_formats;
|
||||
std::vector<std::string> selected_output_formats;
|
||||
std::vector<TypeId> output_types;
|
||||
std::vector<std::string> infer_output_formats;
|
||||
std::vector<TypeId> infer_output_types;
|
||||
std::vector<TypeId> selected_output_types;
|
||||
MS_LOG(INFO) << "SetKernelInfo, CNode Name: " << AnfAlgo::GetCNodeName(kernel_node);
|
||||
auto kernel_attrs =
|
||||
kernel::CPUKernelFactory::GetInstance().GetSupportedKernelAttrList(AnfAlgo::GetCNodeName(kernel_node));
|
||||
if (kernel_attrs.empty()) {
|
||||
MS_LOG(EXCEPTION) << "Operator[" << AnfAlgo::GetCNodeName(kernel_node)
|
||||
<< "] is not support. Trace: " << trace::DumpSourceLines(kernel_node);
|
||||
if (kernel_attrs.empty() || (kernel_attrs[0].GetInputSize() == 0 && kernel_attrs[0].GetOutputSize() == 0)) {
|
||||
MS_LOG(DEBUG) << "Operator[" << AnfAlgo::GetCNodeName(kernel_node) << "] will get ops attr info.";
|
||||
std::string op_name = AnfAlgo::GetCNodeName(kernel_node);
|
||||
auto op_info_ptr = mindspore::kernel::OpLib::FindOp(op_name, kernel::OpImplyType::kCPU);
|
||||
if (op_info_ptr == nullptr) {
|
||||
MS_LOG(ERROR) << "Not find op[" << op_name << "] in cpu";
|
||||
}
|
||||
kernel_attrs.clear();
|
||||
kernel::CPUKernelFactory::GetInstance().SetKernelAttrs(op_info_ptr, &kernel_attrs);
|
||||
kernel::CPUKernelFactory::GetInstance().UpdateKernelAttrs(op_name, kernel_attrs);
|
||||
}
|
||||
GetInputFormatsAndDtypes(kernel_node, &input_formats, &input_types, &input_not_cnode_indexes);
|
||||
GetOutputInferFormatsAndDtypes(kernel_node, &infer_output_formats, &infer_output_types);
|
||||
GetInputDtypes(kernel_node, &input_types, &input_not_cnode_indexes);
|
||||
GetOutputDtypes(kernel_node, &output_types);
|
||||
KernelAttr selected_kernel_attr;
|
||||
std::pair<bool, bool> matched = std::make_pair(false, false);
|
||||
if (!SelectKernel(kernel_node, &selected_kernel_attr, kernel_attrs, input_formats, input_types,
|
||||
input_not_cnode_indexes, infer_output_formats, infer_output_types, &matched, true)) {
|
||||
if (!SelectKernel(kernel_node, &selected_kernel_attr, kernel_attrs, input_types, input_not_cnode_indexes,
|
||||
output_types, &matched, true)) {
|
||||
if (AnfAlgo::GetCNodeName(kernel_node) == "Cast") {
|
||||
KernelNotSupportException(kernel_node, input_types, infer_output_types);
|
||||
KernelNotSupportException(kernel_node, input_types, output_types);
|
||||
}
|
||||
matched = std::make_pair(false, false);
|
||||
SelectKernel(kernel_node, &selected_kernel_attr, kernel_attrs, input_formats, input_types, input_not_cnode_indexes,
|
||||
infer_output_formats, infer_output_types, &matched, false);
|
||||
SelectKernel(kernel_node, &selected_kernel_attr, kernel_attrs, input_types, input_not_cnode_indexes, output_types,
|
||||
&matched, false);
|
||||
if (!matched.first) {
|
||||
KernelNotSupportException(kernel_node, input_types, infer_output_types);
|
||||
KernelNotSupportException(kernel_node, input_types, output_types);
|
||||
}
|
||||
}
|
||||
|
||||
if (selected_kernel_attr.GetInputSize() > 0 &&
|
||||
(matched.first || input_types.size() == input_not_cnode_indexes.size())) {
|
||||
MS_LOG(INFO) << "Input format and dtype is matched";
|
||||
GetOutputFormatsAndDtypes(kernel_node, selected_kernel_attr, &output_formats, &output_types);
|
||||
for (size_t i = 0; i < selected_kernel_attr.GetInputSize(); ++i) {
|
||||
input_types[SizeToInt(i)] = selected_kernel_attr.GetInputAttr(i).first;
|
||||
GetOutputFormatsAndDtypes(kernel_node, selected_kernel_attr, &selected_output_formats, &selected_output_types);
|
||||
for (size_t index = 0; index < selected_kernel_attr.GetInputSize(); index++) {
|
||||
input_types[index] = selected_kernel_attr.GetInputAttr(index).first;
|
||||
input_formats.emplace_back(selected_kernel_attr.GetInputAttr(index).second);
|
||||
}
|
||||
}
|
||||
SetKernelBuildInfo(input_formats, input_types, output_formats, output_types, kernel_node.get());
|
||||
SetKernelBuildInfo(input_formats, input_types, selected_output_formats, selected_output_types, kernel_node.get());
|
||||
}
|
||||
} // namespace cpu
|
||||
} // namespace device
|
||||
|
|
|
@ -23,7 +23,7 @@ Examples:
|
|||
|
||||
from .primitive import Primitive, PrimitiveWithInfer, PrimitiveWithCheck, prim_attr_register
|
||||
from .vm_impl_registry import get_vm_impl_fn, vm_impl_registry
|
||||
from .op_info_register import op_info_register, AkgGpuRegOp, AkgAscendRegOp, AiCPURegOp, TBERegOp, DataType
|
||||
from .op_info_register import op_info_register, AkgGpuRegOp, AkgAscendRegOp, AiCPURegOp, TBERegOp, CpuRegOp, DataType
|
||||
from .primitive import constexpr
|
||||
from . import composite, operations, functional
|
||||
from . import signature
|
||||
|
@ -36,7 +36,7 @@ __primitive__ = [
|
|||
]
|
||||
|
||||
__all__ = ["get_vm_impl_fn", "vm_impl_registry",
|
||||
"op_info_register", "AkgGpuRegOp", "AkgAscendRegOp", "AiCPURegOp", "TBERegOp", "DataType",
|
||||
"op_info_register", "AkgGpuRegOp", "AkgAscendRegOp", "AiCPURegOp", "TBERegOp", "CpuRegOp", "DataType",
|
||||
"constexpr"]
|
||||
__all__.extend(__primitive__)
|
||||
__all__.extend(composite.__all__)
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
import platform
|
||||
from .aicpu import *
|
||||
from .cpu import *
|
||||
if "Windows" not in platform.system():
|
||||
from .akg import *
|
||||
from .tbe import *
|
||||
|
|
|
@ -0,0 +1,63 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""cpu ops"""
|
||||
from .cast import _cast_cpu
|
||||
from .mul import _mul_cpu
|
||||
from .sub import _sub_cpu
|
||||
from .pow import _pow_cpu
|
||||
from .real_div import _real_div_cpu
|
||||
from .div import _div_cpu
|
||||
from .concat import _concat_cpu
|
||||
from .split import _split_cpu
|
||||
from .adam import _adam_cpu
|
||||
from .arg_max import _arg_max_cpu
|
||||
from .arg_min_with_value import _arg_min_with_value_cpu
|
||||
from .bias_add import _bias_add_cpu
|
||||
from .bias_add_grad import _bias_add_grad_cpu
|
||||
from .dropout import _dropout_cpu
|
||||
from .dropout_grad import _dropout_grad_cpu
|
||||
from .gather_d import _gather_cpu
|
||||
from .gather_d_grad import _gather_d_grad_cpu
|
||||
from .gather_v2 import _gather_v2_cpu
|
||||
from .gather_nd import _gather_nd_cpu
|
||||
from .maximum import _maximum_cpu
|
||||
from .maximum_grad import _maximum_grad_cpu
|
||||
from .conv2d import _conv2d_cpu
|
||||
from .hsigmoid import _hsigmoid_cpu
|
||||
from .hsigmoid_grad import _hsigmoid_grad_cpu
|
||||
from .hswish import _hswish_cpu
|
||||
from .hswish_grad import _hswish_grad_cpu
|
||||
from .is_finite import _is_finite_cpu
|
||||
from .layer_norm import _layer_norm_cpu
|
||||
from .layer_norm_grad import _layer_norm_grad_cpu
|
||||
from .minimum import _minimum_cpu
|
||||
from .minimum_grad import _minimum_grad_cpu
|
||||
from .equal_count import _equal_count_cpu
|
||||
from .mirror_pad import _mirror_pad_cpu
|
||||
from .mirror_pad_grad import _mirror_pad_grad_cpu
|
||||
from .stack import _stack_cpu
|
||||
from .reduce_mean import _reduce_mean_cpu
|
||||
from .reduce_max import _reduce_max_cpu
|
||||
from .reduce_sum import _reduce_sum_cpu
|
||||
from .reduce_min import _reduce_min_cpu
|
||||
from .reduce_all import _reduce_all_cpu
|
||||
from .reduce_any import _reduce_any_cpu
|
||||
from .transpose import _transpose_cpu
|
||||
from .tile import _tile_cpu
|
||||
from .top_k import _top_k_cpu
|
||||
from .add import _add_cpu
|
||||
from .one_hot import _one_hot_cpu
|
||||
from .pad import _pad_cpu
|
||||
from .range import _range_cpu
|
|
@ -0,0 +1,44 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Concat op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
|
||||
|
||||
adam_op_info = CpuRegOp("Adam") \
|
||||
.input(0, "var", "required") \
|
||||
.input(1, "m", "required") \
|
||||
.input(2, "v", "required") \
|
||||
.input(3, "beta1_power", "required") \
|
||||
.input(4, "beta2_power", "required") \
|
||||
.input(5, "lr", "required") \
|
||||
.input(6, "beta1", "required") \
|
||||
.input(7, "beta2", "required") \
|
||||
.input(8, "epsilon", "required") \
|
||||
.input(9, "gradient", "required") \
|
||||
.output(0, "output0", "required") \
|
||||
.output(1, "output0", "required") \
|
||||
.output(2, "output0", "required") \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
|
||||
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
|
||||
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
|
||||
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
|
||||
DataType.F32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(adam_op_info)
|
||||
def _adam_cpu():
|
||||
"""Adam cpu register"""
|
||||
return
|
|
@ -0,0 +1,32 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Add op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
|
||||
|
||||
add_op_info = CpuRegOp("Add") \
|
||||
.input(0, "x1", "required") \
|
||||
.input(1, "x2", "required") \
|
||||
.output(0, "y", "required") \
|
||||
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.U32_Default, DataType.U32_Default, DataType.U32_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(add_op_info)
|
||||
def _add_cpu():
|
||||
"""Add cpu register"""
|
||||
return
|
|
@ -0,0 +1,30 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Argmax op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
|
||||
|
||||
arg_max_op_info = CpuRegOp("Argmax") \
|
||||
.input(0, "x", "required") \
|
||||
.output(0, "y", "required") \
|
||||
.dtype_format(DataType.F32_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.F16_Default, DataType.I32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(arg_max_op_info)
|
||||
def _arg_max_cpu():
|
||||
"""Argmax cpu register"""
|
||||
return
|
|
@ -0,0 +1,31 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""ArgMinWithValue op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
|
||||
|
||||
arg_min_with_value_op_info = CpuRegOp("ArgMinWithValue") \
|
||||
.input(0, "x", "required") \
|
||||
.output(0, "indice", "required") \
|
||||
.output(1, "values", "required") \
|
||||
.dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(arg_min_with_value_op_info)
|
||||
def _arg_min_with_value_cpu():
|
||||
"""ArgMinWithValue cpu register"""
|
||||
return
|
|
@ -0,0 +1,30 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""BiasAdd op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
|
||||
|
||||
bias_add_op_info = CpuRegOp("BiasAdd") \
|
||||
.input(0, "x", "required") \
|
||||
.input(1, "bias", "required") \
|
||||
.output(0, "y", "required") \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(bias_add_op_info)
|
||||
def _bias_add_cpu():
|
||||
"""BiasAdd cpu register"""
|
||||
return
|
|
@ -0,0 +1,29 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""BiasAddGrad op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
|
||||
|
||||
bias_add_grad_op_info = CpuRegOp("BiasAddGrad") \
|
||||
.input(0, "x", "required") \
|
||||
.output(0, "y", "required") \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(bias_add_grad_op_info)
|
||||
def _bias_add_grad_cpu():
|
||||
"""BiasAddGrad cpu register"""
|
||||
return
|
|
@ -0,0 +1,171 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Cast op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
|
||||
|
||||
cast_op_info = CpuRegOp("Cast") \
|
||||
.input(0, "x", "required") \
|
||||
.output(0, "y", "required") \
|
||||
.dtype_format(DataType.U8_Default, DataType.U8_Default) \
|
||||
.dtype_format(DataType.U8_Default, DataType.U16_Default) \
|
||||
.dtype_format(DataType.U8_Default, DataType.U32_Default) \
|
||||
.dtype_format(DataType.U8_Default, DataType.U64_Default) \
|
||||
.dtype_format(DataType.U8_Default, DataType.I8_Default) \
|
||||
.dtype_format(DataType.U8_Default, DataType.I16_Default) \
|
||||
.dtype_format(DataType.U8_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.U8_Default, DataType.I64_Default) \
|
||||
.dtype_format(DataType.U8_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.U8_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.U8_Default, DataType.F64_Default) \
|
||||
.dtype_format(DataType.U8_Default, DataType.BOOL_Default) \
|
||||
.dtype_format(DataType.U16_Default, DataType.U8_Default) \
|
||||
.dtype_format(DataType.U16_Default, DataType.U16_Default) \
|
||||
.dtype_format(DataType.U16_Default, DataType.U32_Default) \
|
||||
.dtype_format(DataType.U16_Default, DataType.U64_Default) \
|
||||
.dtype_format(DataType.U16_Default, DataType.I8_Default) \
|
||||
.dtype_format(DataType.U16_Default, DataType.I16_Default) \
|
||||
.dtype_format(DataType.U16_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.U16_Default, DataType.I64_Default) \
|
||||
.dtype_format(DataType.U16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.U16_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.U16_Default, DataType.F64_Default) \
|
||||
.dtype_format(DataType.U16_Default, DataType.BOOL_Default) \
|
||||
.dtype_format(DataType.U32_Default, DataType.U8_Default) \
|
||||
.dtype_format(DataType.U32_Default, DataType.U16_Default) \
|
||||
.dtype_format(DataType.U32_Default, DataType.U32_Default) \
|
||||
.dtype_format(DataType.U32_Default, DataType.U64_Default) \
|
||||
.dtype_format(DataType.U32_Default, DataType.I8_Default) \
|
||||
.dtype_format(DataType.U32_Default, DataType.I16_Default) \
|
||||
.dtype_format(DataType.U32_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.U32_Default, DataType.I64_Default) \
|
||||
.dtype_format(DataType.U32_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.U32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.U32_Default, DataType.F64_Default) \
|
||||
.dtype_format(DataType.U32_Default, DataType.BOOL_Default) \
|
||||
.dtype_format(DataType.U64_Default, DataType.U8_Default) \
|
||||
.dtype_format(DataType.U64_Default, DataType.U16_Default) \
|
||||
.dtype_format(DataType.U64_Default, DataType.U32_Default) \
|
||||
.dtype_format(DataType.U64_Default, DataType.U64_Default) \
|
||||
.dtype_format(DataType.U64_Default, DataType.I8_Default) \
|
||||
.dtype_format(DataType.U64_Default, DataType.I16_Default) \
|
||||
.dtype_format(DataType.U64_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.U64_Default, DataType.I64_Default) \
|
||||
.dtype_format(DataType.U64_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.U64_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.U64_Default, DataType.F64_Default) \
|
||||
.dtype_format(DataType.U64_Default, DataType.BOOL_Default) \
|
||||
.dtype_format(DataType.I8_Default, DataType.U8_Default) \
|
||||
.dtype_format(DataType.I8_Default, DataType.U16_Default) \
|
||||
.dtype_format(DataType.I8_Default, DataType.U32_Default) \
|
||||
.dtype_format(DataType.I8_Default, DataType.U64_Default) \
|
||||
.dtype_format(DataType.I8_Default, DataType.I8_Default) \
|
||||
.dtype_format(DataType.I8_Default, DataType.I16_Default) \
|
||||
.dtype_format(DataType.I8_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.I8_Default, DataType.I64_Default) \
|
||||
.dtype_format(DataType.I8_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.I8_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.I8_Default, DataType.F64_Default) \
|
||||
.dtype_format(DataType.I8_Default, DataType.BOOL_Default) \
|
||||
.dtype_format(DataType.I16_Default, DataType.U8_Default) \
|
||||
.dtype_format(DataType.I16_Default, DataType.U16_Default) \
|
||||
.dtype_format(DataType.I16_Default, DataType.U32_Default) \
|
||||
.dtype_format(DataType.I16_Default, DataType.U64_Default) \
|
||||
.dtype_format(DataType.I16_Default, DataType.I8_Default) \
|
||||
.dtype_format(DataType.I16_Default, DataType.I16_Default) \
|
||||
.dtype_format(DataType.I16_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.I16_Default, DataType.I64_Default) \
|
||||
.dtype_format(DataType.I16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.I16_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.I16_Default, DataType.F64_Default) \
|
||||
.dtype_format(DataType.I16_Default, DataType.BOOL_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.U8_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.U16_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.U32_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.U64_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.I8_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.I16_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.I64_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.F64_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.BOOL_Default) \
|
||||
.dtype_format(DataType.I64_Default, DataType.U8_Default) \
|
||||
.dtype_format(DataType.I64_Default, DataType.U16_Default) \
|
||||
.dtype_format(DataType.I64_Default, DataType.U32_Default) \
|
||||
.dtype_format(DataType.I64_Default, DataType.U64_Default) \
|
||||
.dtype_format(DataType.I64_Default, DataType.I8_Default) \
|
||||
.dtype_format(DataType.I64_Default, DataType.I16_Default) \
|
||||
.dtype_format(DataType.I64_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.I64_Default, DataType.I64_Default) \
|
||||
.dtype_format(DataType.I64_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.I64_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.I64_Default, DataType.F64_Default) \
|
||||
.dtype_format(DataType.I64_Default, DataType.BOOL_Default) \
|
||||
.dtype_format(DataType.F16_Default, DataType.U8_Default) \
|
||||
.dtype_format(DataType.F16_Default, DataType.U16_Default) \
|
||||
.dtype_format(DataType.F16_Default, DataType.U32_Default) \
|
||||
.dtype_format(DataType.F16_Default, DataType.U64_Default) \
|
||||
.dtype_format(DataType.F16_Default, DataType.I8_Default) \
|
||||
.dtype_format(DataType.F16_Default, DataType.I16_Default) \
|
||||
.dtype_format(DataType.F16_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.F16_Default, DataType.I64_Default) \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F16_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.F16_Default, DataType.F64_Default) \
|
||||
.dtype_format(DataType.F16_Default, DataType.BOOL_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.U8_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.U16_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.U32_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.U64_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.I8_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.I16_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.I64_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F64_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.BOOL_Default) \
|
||||
.dtype_format(DataType.F64_Default, DataType.U8_Default) \
|
||||
.dtype_format(DataType.F64_Default, DataType.U16_Default) \
|
||||
.dtype_format(DataType.F64_Default, DataType.U32_Default) \
|
||||
.dtype_format(DataType.F64_Default, DataType.U64_Default) \
|
||||
.dtype_format(DataType.F64_Default, DataType.I8_Default) \
|
||||
.dtype_format(DataType.F64_Default, DataType.I16_Default) \
|
||||
.dtype_format(DataType.F64_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.F64_Default, DataType.I64_Default) \
|
||||
.dtype_format(DataType.F64_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F64_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.F64_Default, DataType.F64_Default) \
|
||||
.dtype_format(DataType.F64_Default, DataType.BOOL_Default) \
|
||||
.dtype_format(DataType.BOOL_Default, DataType.U8_Default) \
|
||||
.dtype_format(DataType.BOOL_Default, DataType.U16_Default) \
|
||||
.dtype_format(DataType.BOOL_Default, DataType.U32_Default) \
|
||||
.dtype_format(DataType.BOOL_Default, DataType.U64_Default) \
|
||||
.dtype_format(DataType.BOOL_Default, DataType.I8_Default) \
|
||||
.dtype_format(DataType.BOOL_Default, DataType.I16_Default) \
|
||||
.dtype_format(DataType.BOOL_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.BOOL_Default, DataType.I64_Default) \
|
||||
.dtype_format(DataType.BOOL_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.BOOL_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.BOOL_Default, DataType.F64_Default) \
|
||||
.dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \
|
||||
.get_op_info()
|
||||
|
||||
@op_info_register(cast_op_info)
|
||||
def _cast_cpu():
|
||||
"""Cast Cpu register"""
|
||||
return
|
|
@ -0,0 +1,38 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Concat op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
|
||||
|
||||
concat_op_info = CpuRegOp("Concat") \
|
||||
.input(0, "x", "required") \
|
||||
.output(0, "y", "required") \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.I8_Default, DataType.I8_Default) \
|
||||
.dtype_format(DataType.I16_Default, DataType.I16_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.I64_Default, DataType.I64_Default) \
|
||||
.dtype_format(DataType.U8_Default, DataType.U8_Default) \
|
||||
.dtype_format(DataType.U16_Default, DataType.U16_Default) \
|
||||
.dtype_format(DataType.U32_Default, DataType.U32_Default) \
|
||||
.dtype_format(DataType.U64_Default, DataType.U64_Default) \
|
||||
.dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(concat_op_info)
|
||||
def _concat_cpu():
|
||||
"""Concat cpu register"""
|
||||
return
|
|
@ -0,0 +1,30 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Conv2D op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
|
||||
|
||||
conv2d_op_info = CpuRegOp("Conv2D") \
|
||||
.input(0, "x", "required") \
|
||||
.input(1, "filter", "required") \
|
||||
.output(0, "y", "required") \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(conv2d_op_info)
|
||||
def _conv2d_cpu():
|
||||
"""Conv2D cpu register"""
|
||||
return
|
|
@ -0,0 +1,31 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Div op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
|
||||
|
||||
div_op_info = CpuRegOp("Div") \
|
||||
.input(0, "x", "required") \
|
||||
.input(1, "y", "required") \
|
||||
.output(0, "output", "required") \
|
||||
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(div_op_info)
|
||||
def _div_cpu():
|
||||
"""Div cpu register"""
|
||||
return
|
|
@ -0,0 +1,30 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Dropout op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
|
||||
|
||||
dropout_op_info = CpuRegOp("Dropout") \
|
||||
.input(0, "x", "required") \
|
||||
.output(0, "output0", "required") \
|
||||
.output(1, "output1", "required") \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(dropout_op_info)
|
||||
def _dropout_cpu():
|
||||
"""Dropout cpu register"""
|
||||
return
|
|
@ -0,0 +1,30 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""DropoutGrad op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
|
||||
|
||||
dropout_grad_op_info = CpuRegOp("DropoutGrad") \
|
||||
.input(0, "x", "required") \
|
||||
.input(1, "y", "required") \
|
||||
.output(0, "output", "required") \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(dropout_grad_op_info)
|
||||
def _dropout_grad_cpu():
|
||||
"""DropoutGrad cpu register"""
|
||||
return
|
|
@ -0,0 +1,30 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""EqualCount op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
|
||||
|
||||
equal_count_op_info = CpuRegOp("EqualCount") \
|
||||
.input(0, "x1", "required") \
|
||||
.input(1, "x2", "required") \
|
||||
.output(0, "y", "required") \
|
||||
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(equal_count_op_info)
|
||||
def _equal_count_cpu():
|
||||
"""EqualCount cpu register"""
|
||||
return
|
|
@ -0,0 +1,49 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""GatherD op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
|
||||
|
||||
gather_d_op_info = CpuRegOp("GatherD") \
|
||||
.input(0, "input", "required") \
|
||||
.input(1, "dim", "required") \
|
||||
.input(2, "index", "required") \
|
||||
.output(0, "output", "required") \
|
||||
.dtype_format(DataType.F32_Default, DataType.I32_Default, \
|
||||
DataType.I32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.I32_Default, \
|
||||
DataType.I64_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.F16_Default, DataType.I32_Default, \
|
||||
DataType.I32_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F16_Default, DataType.I32_Default, \
|
||||
DataType.I64_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.I32_Default, \
|
||||
DataType.I32_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.I32_Default, \
|
||||
DataType.I64_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.I64_Default, DataType.I32_Default, \
|
||||
DataType.I32_Default, DataType.I64_Default) \
|
||||
.dtype_format(DataType.I64_Default, DataType.I32_Default, \
|
||||
DataType.I64_Default, DataType.I64_Default) \
|
||||
.dtype_format(DataType.BOOL_Default, DataType.I32_Default, \
|
||||
DataType.I32_Default, DataType.BOOL_Default) \
|
||||
.dtype_format(DataType.BOOL_Default, DataType.I32_Default, \
|
||||
DataType.I64_Default, DataType.BOOL_Default) \
|
||||
.get_op_info()
|
||||
|
||||
@op_info_register(gather_d_op_info)
|
||||
def _gather_cpu():
|
||||
"""GatherD cpu register"""
|
||||
return
|
|
@ -0,0 +1,38 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""GatherDGrad op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
|
||||
|
||||
gather_d_grad_op_info = CpuRegOp("GatherDGrad") \
|
||||
.input(0, "index", "required") \
|
||||
.input(1, "src", "required") \
|
||||
.output(0, "output", "required") \
|
||||
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.I64_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.BOOL_Default, DataType.BOOL_Default) \
|
||||
.dtype_format(DataType.I64_Default, DataType.I32_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \
|
||||
.dtype_format(DataType.I64_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.I64_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.I64_Default, DataType.BOOL_Default, DataType.BOOL_Default) \
|
||||
.get_op_info()
|
||||
|
||||
@op_info_register(gather_d_grad_op_info)
|
||||
def _gather_d_grad_cpu():
|
||||
"""GatherDGrad cpu register"""
|
||||
return
|
|
@ -0,0 +1,33 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""GatherNd op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
|
||||
|
||||
gather_nd_op_info = CpuRegOp("GatherNd") \
|
||||
.input(0, "x1", "required") \
|
||||
.input(1, "x2", "required") \
|
||||
.output(0, "y", "required") \
|
||||
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.I64_Default, DataType.I32_Default, DataType.I64_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.F64_Default, DataType.I32_Default, DataType.F64_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(gather_nd_op_info)
|
||||
def _gather_nd_cpu():
|
||||
"""GatherNd cpu register"""
|
||||
return
|
|
@ -0,0 +1,40 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""GatherV2 op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
|
||||
|
||||
gather_v2_op_info = CpuRegOp("Gather") \
|
||||
.input(0, "x", "required") \
|
||||
.input(1, "indices", "required") \
|
||||
.output(0, "y", "required") \
|
||||
.dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default) \
|
||||
.dtype_format(DataType.U16_Default, DataType.I32_Default, DataType.U16_Default) \
|
||||
.dtype_format(DataType.U32_Default, DataType.I32_Default, DataType.U32_Default) \
|
||||
.dtype_format(DataType.U64_Default, DataType.I32_Default, DataType.U64_Default) \
|
||||
.dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I8_Default) \
|
||||
.dtype_format(DataType.I16_Default, DataType.I32_Default, DataType.I16_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.I64_Default, DataType.I32_Default, DataType.I64_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.F64_Default, DataType.I32_Default, DataType.F64_Default) \
|
||||
.dtype_format(DataType.BOOL_Default, DataType.I32_Default, DataType.BOOL_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(gather_v2_op_info)
|
||||
def _gather_v2_cpu():
|
||||
"""GatherV2 cpu register"""
|
||||
return
|
|
@ -0,0 +1,32 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""HSigmoid op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
|
||||
|
||||
hsigmoid_op_info = CpuRegOp("HSigmoid") \
|
||||
.input(0, "x") \
|
||||
.output(0, "output") \
|
||||
.dtype_format(DataType.I8_Default, DataType.I8_Default) \
|
||||
.dtype_format(DataType.I16_Default, DataType.I16_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.I64_Default, DataType.I64_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(hsigmoid_op_info)
|
||||
def _hsigmoid_cpu():
|
||||
"""HSigmoid cpu register"""
|
||||
return
|
|
@ -0,0 +1,33 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""HSigmoidGrad op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
|
||||
|
||||
hsigmoidgrad_op_info = CpuRegOp("HSigmoidGrad") \
|
||||
.input(0, "y_grad") \
|
||||
.input(1, "x") \
|
||||
.output(0, "output") \
|
||||
.dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.I8_Default) \
|
||||
.dtype_format(DataType.I16_Default, DataType.I16_Default, DataType.I16_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(hsigmoidgrad_op_info)
|
||||
def _hsigmoid_grad_cpu():
|
||||
"""HSigmoidGrad cpu register"""
|
||||
return
|
|
@ -0,0 +1,32 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""HSwish op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
|
||||
|
||||
hswish_op_info = CpuRegOp("HSwish") \
|
||||
.input(0, "x") \
|
||||
.output(0, "output") \
|
||||
.dtype_format(DataType.I8_Default, DataType.I8_Default) \
|
||||
.dtype_format(DataType.I16_Default, DataType.I16_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.I64_Default, DataType.I64_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(hswish_op_info)
|
||||
def _hswish_cpu():
|
||||
"""HSwish cpu register"""
|
||||
return
|
|
@ -0,0 +1,33 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""HSwishGrad op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
|
||||
|
||||
hswish_grad_op_info = CpuRegOp("HSwishGrad") \
|
||||
.input(0, "y_grad") \
|
||||
.input(1, "x") \
|
||||
.output(0, "output") \
|
||||
.dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.I8_Default) \
|
||||
.dtype_format(DataType.I16_Default, DataType.I16_Default, DataType.I16_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(hswish_grad_op_info)
|
||||
def _hswish_grad_cpu():
|
||||
"""HSwishGrad cpu register"""
|
||||
return
|
|
@ -0,0 +1,39 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""IsFinite op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
|
||||
|
||||
is_finite_op_info = CpuRegOp("IsFinite") \
|
||||
.input(0, "x", "required") \
|
||||
.output(0, "y", "required") \
|
||||
.dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \
|
||||
.dtype_format(DataType.I8_Default, DataType.BOOL_Default) \
|
||||
.dtype_format(DataType.I16_Default, DataType.BOOL_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.BOOL_Default) \
|
||||
.dtype_format(DataType.I64_Default, DataType.BOOL_Default) \
|
||||
.dtype_format(DataType.U8_Default, DataType.BOOL_Default) \
|
||||
.dtype_format(DataType.U16_Default, DataType.BOOL_Default) \
|
||||
.dtype_format(DataType.U32_Default, DataType.BOOL_Default) \
|
||||
.dtype_format(DataType.U64_Default, DataType.BOOL_Default) \
|
||||
.dtype_format(DataType.F16_Default, DataType.BOOL_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.BOOL_Default) \
|
||||
.dtype_format(DataType.F64_Default, DataType.BOOL_Default) \
|
||||
.get_op_info()
|
||||
|
||||
@op_info_register(is_finite_op_info)
|
||||
def _is_finite_cpu():
|
||||
"""IsFinite cpu register"""
|
||||
return
|
|
@ -0,0 +1,36 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""LayerNorm op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
|
||||
|
||||
layer_norm_op_info = CpuRegOp("LayerNorm") \
|
||||
.input(0, "x", "required") \
|
||||
.input(1, "gamma", "required") \
|
||||
.input(2, "beta", "required") \
|
||||
.output(0, "y", "required") \
|
||||
.output(1, "mean", "required") \
|
||||
.output(2, "variance", "required") \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, \
|
||||
DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, \
|
||||
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(layer_norm_op_info)
|
||||
def _layer_norm_cpu():
|
||||
"""LayerNorm cpu register"""
|
||||
return
|
|
@ -0,0 +1,38 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""LayerNormGrad op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
|
||||
|
||||
layer_norm_grad_op_info = CpuRegOp("LayerNormGrad") \
|
||||
.input(0, "dy", "required") \
|
||||
.input(1, "x", "required") \
|
||||
.input(2, "variance", "required") \
|
||||
.input(3, "mean", "required") \
|
||||
.input(4, "gamma", "required") \
|
||||
.output(0, "pd_x", "required") \
|
||||
.output(1, "pd_gamma", "required") \
|
||||
.output(2, "pd_beta", "required") \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
|
||||
DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
|
||||
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(layer_norm_grad_op_info)
|
||||
def _layer_norm_grad_cpu():
|
||||
"""LayerNormGrad TBE register"""
|
||||
return
|
|
@ -0,0 +1,35 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Maximum op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
|
||||
|
||||
maximum_op_info = CpuRegOp("Maximum") \
|
||||
.input(0, "x1", "required") \
|
||||
.input(1, "x2", "required") \
|
||||
.output(0, "y", "required") \
|
||||
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \
|
||||
.dtype_format(DataType.U32_Default, DataType.U32_Default, DataType.U32_Default) \
|
||||
.dtype_format(DataType.U64_Default, DataType.U64_Default, DataType.U64_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.F64_Default, DataType.F64_Default, DataType.F64_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(maximum_op_info)
|
||||
def _maximum_cpu():
|
||||
"""Maximum cpu register"""
|
||||
return
|
|
@ -0,0 +1,43 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""MaximumGrad op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
|
||||
|
||||
maximum_grad_op_info = CpuRegOp("MaximumGrad") \
|
||||
.input(0, "grads", "required") \
|
||||
.input(1, "x1", "required") \
|
||||
.input(2, "x2", "required") \
|
||||
.output(0, "y1", "required") \
|
||||
.output(1, "y2", "required") \
|
||||
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default,
|
||||
DataType.I32_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default,
|
||||
DataType.I64_Default, DataType.I64_Default) \
|
||||
.dtype_format(DataType.U32_Default, DataType.U32_Default, DataType.U32_Default,
|
||||
DataType.U32_Default, DataType.U32_Default) \
|
||||
.dtype_format(DataType.U64_Default, DataType.U64_Default, DataType.U64_Default,
|
||||
DataType.U64_Default, DataType.U64_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
|
||||
DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.F64_Default, DataType.F64_Default, DataType.F64_Default,
|
||||
DataType.F64_Default, DataType.F64_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(maximum_grad_op_info)
|
||||
def _maximum_grad_cpu():
|
||||
"""MaximumGrad cpu register"""
|
||||
return
|
|
@ -0,0 +1,36 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
|
||||
"""Minimum op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
|
||||
|
||||
minimum_op_info = CpuRegOp("Minimum") \
|
||||
.input(0, "x1", "required") \
|
||||
.input(1, "x2", "required") \
|
||||
.output(0, "y", "required") \
|
||||
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.U32_Default, DataType.U32_Default, DataType.U32_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \
|
||||
.dtype_format(DataType.U64_Default, DataType.U64_Default, DataType.U64_Default) \
|
||||
.dtype_format(DataType.F64_Default, DataType.F64_Default, DataType.F64_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(minimum_op_info)
|
||||
def _minimum_cpu():
|
||||
"""Minimum cpu register"""
|
||||
return
|
|
@ -0,0 +1,37 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""MinimumGrad op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
|
||||
|
||||
minimum_grad_op_info = CpuRegOp("MinimumGrad") \
|
||||
.input(0, "grads", "required") \
|
||||
.input(1, "x1", "required") \
|
||||
.input(2, "x2", "required") \
|
||||
.output(0, "y1", "required") \
|
||||
.output(1, "y2", "required") \
|
||||
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, \
|
||||
DataType.I32_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, \
|
||||
DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, \
|
||||
DataType.F32_Default, DataType.F32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(minimum_grad_op_info)
|
||||
def _minimum_grad_cpu():
|
||||
"""MinimumGrad cpu register"""
|
||||
return
|
|
@ -0,0 +1,31 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""MirrorPad op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
|
||||
mirror_pad_op_info = CpuRegOp("MirrorPad") \
|
||||
.input(0, "x", "required") \
|
||||
.input(1, "paddings", "required") \
|
||||
.output(0, "y", "required") \
|
||||
.dtype_format(DataType.F16_Default, DataType.I64_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.I64_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.I32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(mirror_pad_op_info)
|
||||
def _mirror_pad_cpu():
|
||||
"""MirrorPad cpu register"""
|
||||
return
|
|
@ -0,0 +1,31 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""MirrorPadGrad op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
|
||||
mirror_pad_grad_op_info = CpuRegOp("MirrorPadGrad") \
|
||||
.input(0, "x", "required") \
|
||||
.input(1, "paddings", "required") \
|
||||
.output(0, "y", "required") \
|
||||
.dtype_format(DataType.F16_Default, DataType.I64_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.I64_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.I32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(mirror_pad_grad_op_info)
|
||||
def _mirror_pad_grad_cpu():
|
||||
"""MirrorPadGrad cpu register"""
|
||||
return
|
|
@ -0,0 +1,30 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Mul op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
|
||||
|
||||
mul_op_info = CpuRegOp("Mul") \
|
||||
.input(0, "x", "required") \
|
||||
.input(1, "y", "required") \
|
||||
.output(0, "output", "required") \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(mul_op_info)
|
||||
def _mul_cpu():
|
||||
"""Mul cpu register"""
|
||||
return
|
|
@ -0,0 +1,31 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""OneHot op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
|
||||
|
||||
one_hot_op_info = CpuRegOp("OneHot") \
|
||||
.input(0, "x", "required") \
|
||||
.input(1, "on_value", "required") \
|
||||
.input(2, "off_value", "required") \
|
||||
.output(0, "y", "required") \
|
||||
.dtype_format(DataType.I32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(one_hot_op_info)
|
||||
def _one_hot_cpu():
|
||||
"""OneHot cpu register"""
|
||||
return
|
|
@ -0,0 +1,31 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Pad op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
|
||||
|
||||
pad_op_info = CpuRegOp("Pad") \
|
||||
.input(0, "x", "required") \
|
||||
.output(0, "y", "required") \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(pad_op_info)
|
||||
def _pad_cpu():
|
||||
"""Pad cpu register"""
|
||||
return
|
|
@ -0,0 +1,31 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Power op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
|
||||
|
||||
pow_op_info = CpuRegOp("Pow") \
|
||||
.input(0, "x", "required") \
|
||||
.input(1, "y", "required") \
|
||||
.output(0, "output", "required") \
|
||||
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(pow_op_info)
|
||||
def _pow_cpu():
|
||||
"""Pow cpu register"""
|
||||
return
|
|
@ -0,0 +1,32 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Range op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
|
||||
|
||||
range_op_info = CpuRegOp("Range") \
|
||||
.input(0, "x", "required") \
|
||||
.output(0, "y", "required") \
|
||||
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.I64_Default, DataType.I64_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.F64_Default, DataType.F64_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(range_op_info)
|
||||
def _range_cpu():
|
||||
"""Range cpu register"""
|
||||
return
|
|
@ -0,0 +1,31 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""RealDiv op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
|
||||
|
||||
real_div_op_info = CpuRegOp("RealDiv") \
|
||||
.input(0, "x", "required") \
|
||||
.input(1, "y", "required") \
|
||||
.output(0, "output", "required") \
|
||||
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(real_div_op_info)
|
||||
def _real_div_cpu():
|
||||
"""RealDiv cpu register"""
|
||||
return
|
|
@ -0,0 +1,29 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""ReduceAll op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
|
||||
|
||||
reduce_all_op_info = CpuRegOp("ReduceAll") \
|
||||
.input(0, "x", "required") \
|
||||
.output(0, "y", "required") \
|
||||
.dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(reduce_all_op_info)
|
||||
def _reduce_all_cpu():
|
||||
"""ReduceAll cpu register"""
|
||||
return
|
|
@ -0,0 +1,29 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""ReduceAny op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
|
||||
|
||||
reduce_any_op_info = CpuRegOp("ReduceAny") \
|
||||
.input(0, "x", "required") \
|
||||
.output(0, "y", "required") \
|
||||
.dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(reduce_any_op_info)
|
||||
def _reduce_any_cpu():
|
||||
"""ReduceAny cpu register"""
|
||||
return
|
|
@ -0,0 +1,32 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""ReduceMax op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
|
||||
|
||||
reduce_max_op_info = CpuRegOp("ReduceMax") \
|
||||
.input(0, "x", "required") \
|
||||
.output(0, "y", "required") \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.F64_Default, DataType.F64_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.I64_Default, DataType.I64_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(reduce_max_op_info)
|
||||
def _reduce_max_cpu():
|
||||
"""ReduceMax cpu register"""
|
||||
return
|
|
@ -0,0 +1,32 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""ReduceMean op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
|
||||
|
||||
reduce_mean_op_info = CpuRegOp("ReduceMean") \
|
||||
.input(0, "x", "required") \
|
||||
.output(0, "y", "required") \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.F64_Default, DataType.F64_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.I64_Default, DataType.I64_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(reduce_mean_op_info)
|
||||
def _reduce_mean_cpu():
|
||||
"""ReduceMean cpu register"""
|
||||
return
|
|
@ -0,0 +1,32 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""ReduceMin op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
|
||||
|
||||
reduce_min_op_info = CpuRegOp("ReduceMin") \
|
||||
.input(0, "x", "required") \
|
||||
.output(0, "y", "required") \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.F64_Default, DataType.F64_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.I64_Default, DataType.I64_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(reduce_min_op_info)
|
||||
def _reduce_min_cpu():
|
||||
"""ReduceMin cpu register"""
|
||||
return
|
|
@ -0,0 +1,32 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""ReduceSum op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
|
||||
|
||||
reduce_sum_op_info = CpuRegOp("ReduceSum") \
|
||||
.input(0, "x", "required") \
|
||||
.output(0, "y", "required") \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.F64_Default, DataType.F64_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.I64_Default, DataType.I64_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(reduce_sum_op_info)
|
||||
def _reduce_sum_cpu():
|
||||
"""ReduceSum cpu register"""
|
||||
return
|
|
@ -0,0 +1,34 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Split op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
|
||||
|
||||
split_op_info = CpuRegOp("Split") \
|
||||
.input(0, "x", "required") \
|
||||
.output(0, "y", "required") \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F64_Default, DataType.F64_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.U32_Default, DataType.U32_Default) \
|
||||
.dtype_format(DataType.I64_Default, DataType.I64_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(split_op_info)
|
||||
def _split_cpu():
|
||||
"""Split cpu register"""
|
||||
return
|
|
@ -0,0 +1,38 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Stack op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
|
||||
|
||||
stack_op_info = CpuRegOp("Stack") \
|
||||
.input(0, "x", "required") \
|
||||
.output(0, "y", "required") \
|
||||
.dtype_format(DataType.I8_Default, DataType.I8_Default) \
|
||||
.dtype_format(DataType.I16_Default, DataType.I16_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.I64_Default, DataType.I64_Default) \
|
||||
.dtype_format(DataType.U8_Default, DataType.U8_Default) \
|
||||
.dtype_format(DataType.U16_Default, DataType.U16_Default) \
|
||||
.dtype_format(DataType.U32_Default, DataType.U32_Default) \
|
||||
.dtype_format(DataType.U64_Default, DataType.U64_Default) \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \
|
||||
.get_op_info()
|
||||
|
||||
@op_info_register(stack_op_info)
|
||||
def _stack_cpu():
|
||||
"""Stack cpu register"""
|
||||
return
|
|
@ -0,0 +1,31 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Sub op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
|
||||
|
||||
sub_op_info = CpuRegOp("Sub") \
|
||||
.input(0, "x", "required") \
|
||||
.input(1, "y", "required") \
|
||||
.output(0, "output", "required") \
|
||||
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(sub_op_info)
|
||||
def _sub_cpu():
|
||||
"""Sub cpu register"""
|
||||
return
|
|
@ -0,0 +1,38 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Tile op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
|
||||
|
||||
tile_op_info = CpuRegOp("Tile") \
|
||||
.input(0, "x1", "required") \
|
||||
.output(0, "y", "required") \
|
||||
.dtype_format(DataType.I8_Default, DataType.I8_Default) \
|
||||
.dtype_format(DataType.I16_Default, DataType.I16_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.I64_Default, DataType.I64_Default) \
|
||||
.dtype_format(DataType.U8_Default, DataType.U8_Default) \
|
||||
.dtype_format(DataType.U16_Default, DataType.U16_Default) \
|
||||
.dtype_format(DataType.U32_Default, DataType.U32_Default) \
|
||||
.dtype_format(DataType.U64_Default, DataType.U64_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(tile_op_info)
|
||||
def _tile_cpu():
|
||||
"""Tile cpu register"""
|
||||
return
|
|
@ -0,0 +1,31 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""TopK op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
|
||||
|
||||
top_k_op_info = CpuRegOp("TopK") \
|
||||
.input(0, "input", "required") \
|
||||
.input(1, "k", "required") \
|
||||
.output(0, "values", "required") \
|
||||
.output(1, "indices", "required") \
|
||||
.dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default, DataType.I32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
@op_info_register(top_k_op_info)
|
||||
def _top_k_cpu():
|
||||
"""TopK cpu register"""
|
||||
return
|
|
@ -0,0 +1,38 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Transpose op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
|
||||
|
||||
transpose_op_info = CpuRegOp("Transpose") \
|
||||
.input(0, "x", "required") \
|
||||
.output(0, "y", "required") \
|
||||
.dtype_format(DataType.I8_Default, DataType.I8_Default) \
|
||||
.dtype_format(DataType.I16_Default, DataType.I16_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.I64_Default, DataType.I64_Default) \
|
||||
.dtype_format(DataType.U8_Default, DataType.U8_Default) \
|
||||
.dtype_format(DataType.U16_Default, DataType.U16_Default) \
|
||||
.dtype_format(DataType.U32_Default, DataType.U32_Default) \
|
||||
.dtype_format(DataType.U64_Default, DataType.U64_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(transpose_op_info)
|
||||
def _transpose_cpu():
|
||||
"""TransposeD cpu register"""
|
||||
return
|
|
@ -213,6 +213,65 @@ class RegOp:
|
|||
return op_info
|
||||
|
||||
|
||||
class CpuRegOp(RegOp):
|
||||
"""Class for Cpu op info register"""
|
||||
|
||||
def __init__(self, op_name):
|
||||
super(CpuRegOp, self).__init__(op_name)
|
||||
self.imply_type = "CPU"
|
||||
|
||||
def input(self, index=None, name=None, param_type=None, **kwargs):
|
||||
"""
|
||||
Register Cpu op input information.
|
||||
|
||||
Args:
|
||||
index (int): Order of the input. Default: None.
|
||||
name (str): Name of the input. Default: None.
|
||||
param_type (str): Param type of the input. Default: None.
|
||||
kwargs (dict): Other information of the input.
|
||||
"""
|
||||
param_list = [index, name, param_type]
|
||||
key_list = ["index", "name", "param_type"]
|
||||
fn_list = [self._is_int, self._is_string, self._is_string]
|
||||
input_dict = self._check_param(param_list, key_list, fn_list, kwargs)
|
||||
self.inputs.append(input_dict)
|
||||
return self
|
||||
|
||||
def output(self, index=None, name=None, param_type=None, **kwargs):
|
||||
"""
|
||||
Register AiCPU op output information.
|
||||
|
||||
Args:
|
||||
index (int): Order of the output. Default: None.
|
||||
name (str): Name of the output. Default: None.
|
||||
param_type (str): Param type of the output. Default: None.
|
||||
kwargs (dict): Other information of the output.
|
||||
"""
|
||||
param_list = [index, name, param_type]
|
||||
key_list = ["index", "name", "param_type"]
|
||||
fn_list = [self._is_int, self._is_string, self._is_string]
|
||||
output_dict = self._check_param(param_list, key_list, fn_list, kwargs)
|
||||
self.outputs.append(output_dict)
|
||||
return self
|
||||
|
||||
def attr(self, name=None, value_type=None, value=None, **kwargs):
|
||||
"""
|
||||
Register AiCPU op attribute information.
|
||||
|
||||
Args:
|
||||
name (str): Name of the attribute. Default: None.
|
||||
value_type (str): Value type of the attribute. Default: None.
|
||||
value (str): Value of the attribute. Default: None.
|
||||
kwargs (dict): Other information of the attribute.
|
||||
"""
|
||||
param_list = [name, value_type, value]
|
||||
key_list = ["name", "type", "value"]
|
||||
fn_list = [self._is_string]
|
||||
attr_dict = self._check_param(param_list, key_list, fn_list, kwargs)
|
||||
self.attr_.append(attr_dict)
|
||||
return self
|
||||
|
||||
|
||||
class AkgRegOp(RegOp):
|
||||
"""Class for Akg op info register."""
|
||||
|
||||
|
|
Loading…
Reference in New Issue