forked from mindspore-Ecosystem/mindspore
!16988 upload unify cpu ops info library code
From: @chengxb7532 Reviewed-by: @linqingke,@guoqi1024 Signed-off-by: @guoqi1024
This commit is contained in:
commit
5ed742a893
|
@ -40,22 +40,7 @@ class AdamCPUKernel : public CPUKernel {
|
||||||
TypeId dtype_{kTypeUnknown};
|
TypeId dtype_{kTypeUnknown};
|
||||||
};
|
};
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL(Adam,
|
MS_REG_CPU_KERNEL(Adam, KernelAttr(), AdamCPUKernel)
|
||||||
KernelAttr()
|
|
||||||
.AddInputAttr(kNumberTypeFloat32)
|
|
||||||
.AddInputAttr(kNumberTypeFloat32)
|
|
||||||
.AddInputAttr(kNumberTypeFloat32)
|
|
||||||
.AddInputAttr(kNumberTypeFloat32)
|
|
||||||
.AddInputAttr(kNumberTypeFloat32)
|
|
||||||
.AddInputAttr(kNumberTypeFloat32)
|
|
||||||
.AddInputAttr(kNumberTypeFloat32)
|
|
||||||
.AddInputAttr(kNumberTypeFloat32)
|
|
||||||
.AddInputAttr(kNumberTypeFloat32)
|
|
||||||
.AddInputAttr(kNumberTypeFloat32)
|
|
||||||
.AddOutputAttr(kNumberTypeFloat32)
|
|
||||||
.AddOutputAttr(kNumberTypeFloat32)
|
|
||||||
.AddOutputAttr(kNumberTypeFloat32),
|
|
||||||
AdamCPUKernel)
|
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
|
|
|
@ -40,10 +40,8 @@ class ArgmaxCPUKernel : public CPUKernel {
|
||||||
size_t dim_axis_;
|
size_t dim_axis_;
|
||||||
};
|
};
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL_T(Argmax, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt32),
|
MS_REG_CPU_KERNEL_T(Argmax, KernelAttr(), ArgmaxCPUKernel, float);
|
||||||
ArgmaxCPUKernel, float);
|
MS_REG_CPU_KERNEL_T(Argmax, KernelAttr(), ArgmaxCPUKernel, float16);
|
||||||
MS_REG_CPU_KERNEL_T(Argmax, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt32),
|
|
||||||
ArgmaxCPUKernel, float16);
|
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
|
|
|
@ -42,14 +42,8 @@ class ArgMinWithValueCPUKernel : public CPUKernel {
|
||||||
size_t dim_axis_;
|
size_t dim_axis_;
|
||||||
};
|
};
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL_T(
|
MS_REG_CPU_KERNEL_T(ArgMinWithValue, KernelAttr(), ArgMinWithValueCPUKernel, float);
|
||||||
ArgMinWithValue,
|
MS_REG_CPU_KERNEL_T(ArgMinWithValue, KernelAttr(), ArgMinWithValueCPUKernel, float16);
|
||||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
|
|
||||||
ArgMinWithValueCPUKernel, float);
|
|
||||||
MS_REG_CPU_KERNEL_T(
|
|
||||||
ArgMinWithValue,
|
|
||||||
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
|
|
||||||
ArgMinWithValueCPUKernel, float16);
|
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
|
|
|
@ -63,43 +63,18 @@ class ArithmeticCPUKernel : public CPUKernel {
|
||||||
TypeId target_dtype_{kTypeUnknown};
|
TypeId target_dtype_{kTypeUnknown};
|
||||||
};
|
};
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL_T(
|
MS_REG_CPU_KERNEL_T(Sub, KernelAttr(), ArithmeticCPUKernel, int32_t);
|
||||||
Sub, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
MS_REG_CPU_KERNEL_T(Sub, KernelAttr(), ArithmeticCPUKernel, float);
|
||||||
ArithmeticCPUKernel, int);
|
MS_REG_CPU_KERNEL_T(Sub, KernelAttr(), ArithmeticCPUKernel, int64_t);
|
||||||
MS_REG_CPU_KERNEL_T(
|
MS_REG_CPU_KERNEL_T(Pow, KernelAttr(), ArithmeticCPUKernel, int32_t);
|
||||||
Sub, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
MS_REG_CPU_KERNEL_T(Pow, KernelAttr(), ArithmeticCPUKernel, float);
|
||||||
ArithmeticCPUKernel, float);
|
MS_REG_CPU_KERNEL_T(Pow, KernelAttr(), ArithmeticCPUKernel, int64_t);
|
||||||
MS_REG_CPU_KERNEL_T(
|
MS_REG_CPU_KERNEL_T(RealDiv, KernelAttr(), ArithmeticCPUKernel, int32_t);
|
||||||
Sub, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
MS_REG_CPU_KERNEL_T(RealDiv, KernelAttr(), ArithmeticCPUKernel, float);
|
||||||
ArithmeticCPUKernel, int64_t);
|
MS_REG_CPU_KERNEL_T(RealDiv, KernelAttr(), ArithmeticCPUKernel, int64_t);
|
||||||
MS_REG_CPU_KERNEL_T(
|
MS_REG_CPU_KERNEL_T(Div, KernelAttr(), ArithmeticCPUKernel, int32_t);
|
||||||
Pow, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
MS_REG_CPU_KERNEL_T(Div, KernelAttr(), ArithmeticCPUKernel, float);
|
||||||
ArithmeticCPUKernel, int);
|
MS_REG_CPU_KERNEL_T(Div, KernelAttr(), ArithmeticCPUKernel, int64_t);
|
||||||
MS_REG_CPU_KERNEL_T(
|
|
||||||
Pow, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
|
||||||
ArithmeticCPUKernel, float);
|
|
||||||
MS_REG_CPU_KERNEL_T(
|
|
||||||
Pow, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
|
||||||
ArithmeticCPUKernel, int64_t);
|
|
||||||
MS_REG_CPU_KERNEL_T(
|
|
||||||
RealDiv, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
|
||||||
ArithmeticCPUKernel, int);
|
|
||||||
MS_REG_CPU_KERNEL_T(
|
|
||||||
RealDiv,
|
|
||||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
|
||||||
ArithmeticCPUKernel, float);
|
|
||||||
MS_REG_CPU_KERNEL_T(
|
|
||||||
RealDiv, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
|
||||||
ArithmeticCPUKernel, int64_t);
|
|
||||||
MS_REG_CPU_KERNEL_T(
|
|
||||||
Div, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
|
||||||
ArithmeticCPUKernel, int64_t);
|
|
||||||
MS_REG_CPU_KERNEL_T(
|
|
||||||
Div, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
|
||||||
ArithmeticCPUKernel, int);
|
|
||||||
MS_REG_CPU_KERNEL_T(
|
|
||||||
Div, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
|
||||||
ArithmeticCPUKernel, float);
|
|
||||||
MS_REG_CPU_KERNEL_T(
|
MS_REG_CPU_KERNEL_T(
|
||||||
FloorDiv, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
FloorDiv, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
||||||
ArithmeticCPUKernel, int64_t);
|
ArithmeticCPUKernel, int64_t);
|
||||||
|
@ -139,9 +114,6 @@ MS_REG_CPU_KERNEL_T(
|
||||||
MS_REG_CPU_KERNEL_T(
|
MS_REG_CPU_KERNEL_T(
|
||||||
AssignAdd, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
AssignAdd, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
||||||
ArithmeticCPUKernel, int64_t);
|
ArithmeticCPUKernel, int64_t);
|
||||||
MS_REG_CPU_KERNEL_T(
|
|
||||||
Mul, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
|
||||||
ArithmeticCPUKernel, int);
|
|
||||||
MS_REG_CPU_KERNEL_T(
|
MS_REG_CPU_KERNEL_T(
|
||||||
SquaredDifference,
|
SquaredDifference,
|
||||||
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||||
|
|
|
@ -37,10 +37,7 @@ class BiasAddCPUKernel : public CPUKernel {
|
||||||
std::vector<size_t> input_shape_;
|
std::vector<size_t> input_shape_;
|
||||||
std::vector<size_t> bias_shape_;
|
std::vector<size_t> bias_shape_;
|
||||||
};
|
};
|
||||||
MS_REG_CPU_KERNEL(
|
MS_REG_CPU_KERNEL(BiasAdd, KernelAttr(), BiasAddCPUKernel);
|
||||||
BiasAdd,
|
|
||||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
|
||||||
BiasAddCPUKernel);
|
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_BIAS_ADD_CPU_KERNEL_H_
|
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_BIAS_ADD_CPU_KERNEL_H_
|
||||||
|
|
|
@ -36,8 +36,7 @@ class BiasAddGradCPUKernel : public CPUKernel {
|
||||||
private:
|
private:
|
||||||
std::vector<size_t> input_shape_;
|
std::vector<size_t> input_shape_;
|
||||||
};
|
};
|
||||||
MS_REG_CPU_KERNEL(BiasAddGrad, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
MS_REG_CPU_KERNEL(BiasAddGrad, KernelAttr(), BiasAddGradCPUKernel);
|
||||||
BiasAddGradCPUKernel);
|
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_BIASADDGRADCPUKERNEL_H_
|
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_BIASADDGRADCPUKERNEL_H_
|
||||||
|
|
|
@ -34,8 +34,8 @@ void Cast(const S *in, T *out, size_t size) {
|
||||||
template <typename S, typename T>
|
template <typename S, typename T>
|
||||||
void CastCPUKernel<S, T>::InitKernel(const CNodePtr &kernel_node) {
|
void CastCPUKernel<S, T>::InitKernel(const CNodePtr &kernel_node) {
|
||||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||||
source_dtype = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
|
source_dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
|
||||||
target_dtype = AnfAlgo::GetOutputDeviceDataType(kernel_node, 0);
|
target_dtype_ = AnfAlgo::GetOutputDeviceDataType(kernel_node, 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename S, typename T>
|
template <typename S, typename T>
|
||||||
|
|
|
@ -35,309 +35,166 @@ class CastCPUKernel : public CPUKernel {
|
||||||
const std::vector<AddressPtr> &outputs) override;
|
const std::vector<AddressPtr> &outputs) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
TypeId source_dtype{kTypeUnknown};
|
TypeId source_dtype_{kTypeUnknown};
|
||||||
TypeId target_dtype{kTypeUnknown};
|
TypeId target_dtype_{kTypeUnknown};
|
||||||
};
|
};
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeFloat16), CastCPUKernel,
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint8_t, uint8_t);
|
||||||
bool, float16);
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint8_t, uint16_t);
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeFloat32), CastCPUKernel,
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint8_t, uint32_t);
|
||||||
bool, float);
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint8_t, uint64_t);
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeFloat64), CastCPUKernel,
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint8_t, int8_t);
|
||||||
bool, double);
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint8_t, int16_t);
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt8), CastCPUKernel,
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint8_t, int32_t);
|
||||||
bool, int8_t);
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint8_t, int64_t);
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt16), CastCPUKernel,
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint8_t, float16);
|
||||||
bool, int16_t);
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint8_t, float);
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt32), CastCPUKernel,
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint8_t, double);
|
||||||
bool, int32_t);
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint8_t, bool);
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt64), CastCPUKernel,
|
|
||||||
bool, int64_t);
|
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeUInt8), CastCPUKernel,
|
|
||||||
bool, uint8_t);
|
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeUInt16), CastCPUKernel,
|
|
||||||
bool, uint16_t);
|
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeUInt32), CastCPUKernel,
|
|
||||||
bool, uint32_t);
|
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeUInt64), CastCPUKernel,
|
|
||||||
bool, uint64_t);
|
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), CastCPUKernel,
|
|
||||||
bool, bool);
|
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint16_t, uint8_t);
|
||||||
CastCPUKernel, float16, float16);
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint16_t, uint16_t);
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat32),
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint16_t, uint32_t);
|
||||||
CastCPUKernel, float16, float);
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint16_t, uint64_t);
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat64),
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint16_t, int8_t);
|
||||||
CastCPUKernel, float16, double);
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint16_t, int16_t);
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt8), CastCPUKernel,
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint16_t, int32_t);
|
||||||
float16, int8_t);
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint16_t, int64_t);
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt16),
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint16_t, float16);
|
||||||
CastCPUKernel, float16, int16_t);
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint16_t, float);
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt32),
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint16_t, double);
|
||||||
CastCPUKernel, float16, int32_t);
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint16_t, bool);
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt64),
|
|
||||||
CastCPUKernel, float16, int64_t);
|
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeUInt8),
|
|
||||||
CastCPUKernel, float16, uint8_t);
|
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeUInt16),
|
|
||||||
CastCPUKernel, float16, uint16_t);
|
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeUInt32),
|
|
||||||
CastCPUKernel, float16, uint32_t);
|
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeUInt64),
|
|
||||||
CastCPUKernel, float16, uint64_t);
|
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool), CastCPUKernel,
|
|
||||||
float16, bool);
|
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat16),
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint32_t, uint8_t);
|
||||||
CastCPUKernel, float, float16);
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint32_t, uint16_t);
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint32_t, uint32_t);
|
||||||
CastCPUKernel, float, float);
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint32_t, uint64_t);
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat64),
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint32_t, int8_t);
|
||||||
CastCPUKernel, float, double);
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint32_t, int16_t);
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt8), CastCPUKernel,
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint32_t, int32_t);
|
||||||
float, int8_t);
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint32_t, int64_t);
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt16),
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint32_t, float16);
|
||||||
CastCPUKernel, float, int16_t);
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint32_t, float);
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt32),
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint32_t, double);
|
||||||
CastCPUKernel, float, int32_t);
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint32_t, bool);
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt64),
|
|
||||||
CastCPUKernel, float, int64_t);
|
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeUInt8),
|
|
||||||
CastCPUKernel, float, uint8_t);
|
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeUInt16),
|
|
||||||
CastCPUKernel, float, uint16_t);
|
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeUInt32),
|
|
||||||
CastCPUKernel, float, uint32_t);
|
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeUInt64),
|
|
||||||
CastCPUKernel, float, uint64_t);
|
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool), CastCPUKernel,
|
|
||||||
float, bool);
|
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat16),
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint64_t, uint8_t);
|
||||||
CastCPUKernel, double, float16);
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint64_t, uint16_t);
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat32),
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint64_t, uint32_t);
|
||||||
CastCPUKernel, double, float);
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint64_t, uint64_t);
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint64_t, int8_t);
|
||||||
CastCPUKernel, double, double);
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint64_t, int16_t);
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeInt8), CastCPUKernel,
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint64_t, int32_t);
|
||||||
double, int8_t);
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint64_t, int64_t);
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeInt16),
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint64_t, float16);
|
||||||
CastCPUKernel, double, int16_t);
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint64_t, float);
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeInt32),
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint64_t, double);
|
||||||
CastCPUKernel, double, int32_t);
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, uint64_t, bool);
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeInt64),
|
|
||||||
CastCPUKernel, double, int64_t);
|
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeUInt8),
|
|
||||||
CastCPUKernel, double, uint8_t);
|
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeUInt16),
|
|
||||||
CastCPUKernel, double, uint16_t);
|
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeUInt32),
|
|
||||||
CastCPUKernel, double, uint32_t);
|
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeUInt64),
|
|
||||||
CastCPUKernel, double, uint64_t);
|
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeBool), CastCPUKernel,
|
|
||||||
double, bool);
|
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeFloat16), CastCPUKernel,
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int8_t, uint8_t);
|
||||||
int8_t, float16);
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int8_t, uint16_t);
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeFloat32), CastCPUKernel,
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int8_t, uint32_t);
|
||||||
int8_t, float);
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int8_t, uint64_t);
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeFloat64), CastCPUKernel,
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int8_t, int8_t);
|
||||||
int8_t, double);
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int8_t, int16_t);
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8), CastCPUKernel,
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int8_t, int32_t);
|
||||||
int8_t, int8_t);
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int8_t, int64_t);
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt16), CastCPUKernel,
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int8_t, float16);
|
||||||
int8_t, int16_t);
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int8_t, float);
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt32), CastCPUKernel,
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int8_t, double);
|
||||||
int8_t, int32_t);
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int8_t, bool);
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt64), CastCPUKernel,
|
|
||||||
int8_t, int64_t);
|
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeUInt8), CastCPUKernel,
|
|
||||||
int8_t, uint8_t);
|
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeUInt16), CastCPUKernel,
|
|
||||||
int8_t, uint16_t);
|
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeUInt32), CastCPUKernel,
|
|
||||||
int8_t, uint32_t);
|
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeUInt64), CastCPUKernel,
|
|
||||||
int8_t, uint64_t);
|
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeBool), CastCPUKernel,
|
|
||||||
int8_t, bool);
|
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeFloat16),
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int16_t, uint8_t);
|
||||||
CastCPUKernel, int16_t, float16);
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int16_t, uint16_t);
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeFloat32),
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int16_t, uint32_t);
|
||||||
CastCPUKernel, int16_t, float);
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int16_t, uint64_t);
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeFloat64),
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int16_t, int8_t);
|
||||||
CastCPUKernel, int16_t, double);
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int16_t, int16_t);
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt8), CastCPUKernel,
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int16_t, int32_t);
|
||||||
int16_t, int8_t);
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int16_t, int64_t);
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16), CastCPUKernel,
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int16_t, float16);
|
||||||
int16_t, int16_t);
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int16_t, float);
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt32), CastCPUKernel,
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int16_t, double);
|
||||||
int16_t, int32_t);
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int16_t, bool);
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt64), CastCPUKernel,
|
|
||||||
int16_t, int64_t);
|
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeUInt8), CastCPUKernel,
|
|
||||||
int16_t, uint8_t);
|
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeUInt16), CastCPUKernel,
|
|
||||||
int16_t, uint16_t);
|
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeUInt32), CastCPUKernel,
|
|
||||||
int16_t, uint32_t);
|
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeUInt64), CastCPUKernel,
|
|
||||||
int16_t, uint64_t);
|
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeBool), CastCPUKernel,
|
|
||||||
int16_t, bool);
|
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int32_t, uint8_t);
|
||||||
CastCPUKernel, int32_t, float16);
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int32_t, uint16_t);
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int32_t, uint32_t);
|
||||||
CastCPUKernel, int32_t, float);
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int32_t, uint64_t);
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64),
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int32_t, int8_t);
|
||||||
CastCPUKernel, int32_t, double);
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int32_t, int16_t);
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt8), CastCPUKernel,
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int32_t, int32_t);
|
||||||
int32_t, int8_t);
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int32_t, int64_t);
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt16), CastCPUKernel,
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int32_t, float16);
|
||||||
int32_t, int16_t);
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int32_t, float);
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), CastCPUKernel,
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int32_t, double);
|
||||||
int32_t, int32_t);
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int32_t, bool);
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64), CastCPUKernel,
|
|
||||||
int32_t, int64_t);
|
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt8), CastCPUKernel,
|
|
||||||
int32_t, uint8_t);
|
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt16), CastCPUKernel,
|
|
||||||
int32_t, uint16_t);
|
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt32), CastCPUKernel,
|
|
||||||
int32_t, uint32_t);
|
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt64), CastCPUKernel,
|
|
||||||
int32_t, uint64_t);
|
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool), CastCPUKernel,
|
|
||||||
int32_t, bool);
|
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16),
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int64_t, uint8_t);
|
||||||
CastCPUKernel, int64_t, float16);
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int64_t, uint16_t);
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32),
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int64_t, uint32_t);
|
||||||
CastCPUKernel, int64_t, float);
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int64_t, uint64_t);
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64),
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int64_t, int8_t);
|
||||||
CastCPUKernel, int64_t, double);
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int64_t, int16_t);
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt8), CastCPUKernel,
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int64_t, int32_t);
|
||||||
int64_t, int8_t);
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int64_t, int64_t);
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt16), CastCPUKernel,
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int64_t, float16);
|
||||||
int64_t, int16_t);
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int64_t, float);
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32), CastCPUKernel,
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int64_t, double);
|
||||||
int64_t, int32_t);
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, int64_t, bool);
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), CastCPUKernel,
|
|
||||||
int64_t, int64_t);
|
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt8), CastCPUKernel,
|
|
||||||
int64_t, uint8_t);
|
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt16), CastCPUKernel,
|
|
||||||
int64_t, uint16_t);
|
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt32), CastCPUKernel,
|
|
||||||
int64_t, uint32_t);
|
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt64), CastCPUKernel,
|
|
||||||
int64_t, uint64_t);
|
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool), CastCPUKernel,
|
|
||||||
int64_t, bool);
|
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeFloat16),
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, float16, uint8_t);
|
||||||
CastCPUKernel, uint8_t, float16);
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, float16, uint16_t);
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeFloat32),
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, float16, uint32_t);
|
||||||
CastCPUKernel, uint8_t, float);
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, float16, uint64_t);
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeFloat64),
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, float16, int8_t);
|
||||||
CastCPUKernel, uint8_t, double);
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, float16, int16_t);
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeInt8), CastCPUKernel,
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, float16, int32_t);
|
||||||
uint8_t, int8_t);
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, float16, int64_t);
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeInt16), CastCPUKernel,
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, float16, float16);
|
||||||
uint8_t, int16_t);
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, float16, float);
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeInt32), CastCPUKernel,
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, float16, double);
|
||||||
uint8_t, int32_t);
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, float16, bool);
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeInt64), CastCPUKernel,
|
|
||||||
uint8_t, int64_t);
|
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8), CastCPUKernel,
|
|
||||||
uint8_t, uint8_t);
|
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt16), CastCPUKernel,
|
|
||||||
uint8_t, uint16_t);
|
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt32), CastCPUKernel,
|
|
||||||
uint8_t, uint32_t);
|
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt64), CastCPUKernel,
|
|
||||||
uint8_t, uint64_t);
|
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeBool), CastCPUKernel,
|
|
||||||
uint8_t, bool);
|
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeFloat16),
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, float, uint8_t);
|
||||||
CastCPUKernel, uint16_t, float16);
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, float, uint16_t);
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeFloat32),
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, float, uint32_t);
|
||||||
CastCPUKernel, uint16_t, float);
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, float, uint64_t);
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeFloat64),
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, float, int8_t);
|
||||||
CastCPUKernel, uint16_t, double);
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, float, int16_t);
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeInt8), CastCPUKernel,
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, float, int32_t);
|
||||||
uint16_t, int8_t);
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, float, int64_t);
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeInt16), CastCPUKernel,
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, float, float16);
|
||||||
uint16_t, int16_t);
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, float, float);
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeInt32), CastCPUKernel,
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, float, double);
|
||||||
uint16_t, int32_t);
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, float, bool);
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeInt64), CastCPUKernel,
|
|
||||||
uint16_t, int64_t);
|
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt8), CastCPUKernel,
|
|
||||||
uint16_t, uint8_t);
|
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16),
|
|
||||||
CastCPUKernel, uint16_t, uint16_t);
|
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt32),
|
|
||||||
CastCPUKernel, uint16_t, uint32_t);
|
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt64),
|
|
||||||
CastCPUKernel, uint16_t, uint64_t);
|
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeBool), CastCPUKernel,
|
|
||||||
uint16_t, bool);
|
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeFloat16),
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, double, uint8_t);
|
||||||
CastCPUKernel, uint32_t, float16);
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, double, uint16_t);
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeFloat32),
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, double, uint32_t);
|
||||||
CastCPUKernel, uint32_t, float);
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, double, uint64_t);
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeFloat64),
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, double, int8_t);
|
||||||
CastCPUKernel, uint32_t, double);
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, double, int16_t);
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt8), CastCPUKernel,
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, double, int32_t);
|
||||||
uint32_t, int8_t);
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, double, int64_t);
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt16), CastCPUKernel,
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, double, float16);
|
||||||
uint32_t, int16_t);
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, double, float);
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt32), CastCPUKernel,
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, double, double);
|
||||||
uint32_t, int32_t);
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, double, bool);
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt64), CastCPUKernel,
|
|
||||||
uint32_t, int64_t);
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, bool, uint8_t);
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt8), CastCPUKernel,
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, bool, uint16_t);
|
||||||
uint32_t, uint8_t);
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, bool, uint32_t);
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt16),
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, bool, uint64_t);
|
||||||
CastCPUKernel, uint32_t, uint16_t);
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, bool, int8_t);
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, bool, int16_t);
|
||||||
CastCPUKernel, uint32_t, uint32_t);
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, bool, int32_t);
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt64),
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, bool, int64_t);
|
||||||
CastCPUKernel, uint32_t, uint64_t);
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, bool, float16);
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeBool), CastCPUKernel,
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, bool, float);
|
||||||
uint32_t, bool);
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, bool, double);
|
||||||
|
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCPUKernel, bool, bool);
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeFloat16),
|
|
||||||
CastCPUKernel, uint64_t, float16);
|
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeFloat32),
|
|
||||||
CastCPUKernel, uint64_t, float);
|
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeFloat64),
|
|
||||||
CastCPUKernel, uint64_t, double);
|
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeInt8), CastCPUKernel,
|
|
||||||
uint64_t, int8_t);
|
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeInt16), CastCPUKernel,
|
|
||||||
uint64_t, int16_t);
|
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeInt32), CastCPUKernel,
|
|
||||||
uint64_t, int32_t);
|
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeInt64), CastCPUKernel,
|
|
||||||
uint64_t, int64_t);
|
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt8), CastCPUKernel,
|
|
||||||
uint64_t, uint8_t);
|
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt16),
|
|
||||||
CastCPUKernel, uint64_t, uint16_t);
|
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt32),
|
|
||||||
CastCPUKernel, uint64_t, uint32_t);
|
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
|
|
||||||
CastCPUKernel, uint64_t, uint64_t);
|
|
||||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeBool), CastCPUKernel,
|
|
||||||
uint64_t, bool);
|
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
|
|
|
@ -39,36 +39,16 @@ class ConcatCPUKernel : public CPUKernel {
|
||||||
CNodeWeakPtr node_wpt_;
|
CNodeWeakPtr node_wpt_;
|
||||||
};
|
};
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL_T(
|
MS_REG_CPU_KERNEL_T(Concat, KernelAttr(), ConcatCPUKernel, float);
|
||||||
Concat, KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
MS_REG_CPU_KERNEL_T(Concat, KernelAttr(), ConcatCPUKernel, int8_t)
|
||||||
ConcatCPUKernel, float);
|
MS_REG_CPU_KERNEL_T(Concat, KernelAttr(), ConcatCPUKernel, int16_t)
|
||||||
MS_REG_CPU_KERNEL_T(Concat,
|
MS_REG_CPU_KERNEL_T(Concat, KernelAttr(), ConcatCPUKernel, int32_t)
|
||||||
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
|
MS_REG_CPU_KERNEL_T(Concat, KernelAttr(), ConcatCPUKernel, int64_t)
|
||||||
ConcatCPUKernel, int8_t)
|
MS_REG_CPU_KERNEL_T(Concat, KernelAttr(), ConcatCPUKernel, uint8_t)
|
||||||
MS_REG_CPU_KERNEL_T(Concat,
|
MS_REG_CPU_KERNEL_T(Concat, KernelAttr(), ConcatCPUKernel, uint16_t)
|
||||||
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
|
MS_REG_CPU_KERNEL_T(Concat, KernelAttr(), ConcatCPUKernel, uint32_t)
|
||||||
ConcatCPUKernel, int16_t)
|
MS_REG_CPU_KERNEL_T(Concat, KernelAttr(), ConcatCPUKernel, uint64_t)
|
||||||
MS_REG_CPU_KERNEL_T(Concat,
|
MS_REG_CPU_KERNEL_T(Concat, KernelAttr(), ConcatCPUKernel, bool)
|
||||||
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
|
||||||
ConcatCPUKernel, int)
|
|
||||||
MS_REG_CPU_KERNEL_T(Concat,
|
|
||||||
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
|
||||||
ConcatCPUKernel, int64_t)
|
|
||||||
MS_REG_CPU_KERNEL_T(Concat,
|
|
||||||
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
|
|
||||||
ConcatCPUKernel, uint8_t)
|
|
||||||
MS_REG_CPU_KERNEL_T(Concat,
|
|
||||||
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16),
|
|
||||||
ConcatCPUKernel, uint16_t)
|
|
||||||
MS_REG_CPU_KERNEL_T(Concat,
|
|
||||||
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
|
|
||||||
ConcatCPUKernel, uint32_t)
|
|
||||||
MS_REG_CPU_KERNEL_T(Concat,
|
|
||||||
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
|
|
||||||
ConcatCPUKernel, uint64_t)
|
|
||||||
MS_REG_CPU_KERNEL_T(Concat,
|
|
||||||
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
|
|
||||||
ConcatCPUKernel, bool)
|
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
|
|
|
@ -23,6 +23,7 @@
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include "backend/kernel_compiler/kernel.h"
|
#include "backend/kernel_compiler/kernel.h"
|
||||||
#include "backend/session/anf_runtime_algorithm.h"
|
#include "backend/session/anf_runtime_algorithm.h"
|
||||||
|
#include "backend/kernel_compiler/common_utils.h"
|
||||||
#include "ir/anf.h"
|
#include "ir/anf.h"
|
||||||
|
|
||||||
using mindspore::kernel::Address;
|
using mindspore::kernel::Address;
|
||||||
|
|
|
@ -17,12 +17,14 @@
|
||||||
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
|
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
|
||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
#include <set>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
#include "runtime/device/kernel_info.h"
|
#include "runtime/device/kernel_info.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace kernel {
|
namespace kernel {
|
||||||
|
const std::set<std::string> same_op_name = {"Concat", "Pack", "Stack", "Split", "Transpose", "Unpack", "AddN"};
|
||||||
CPUKernelFactory &CPUKernelFactory::GetInstance() {
|
CPUKernelFactory &CPUKernelFactory::GetInstance() {
|
||||||
static CPUKernelFactory instance;
|
static CPUKernelFactory instance;
|
||||||
return instance;
|
return instance;
|
||||||
|
@ -48,6 +50,58 @@ std::shared_ptr<CPUKernel> CPUKernelFactory::Create(const std::string &kernel_na
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void CPUKernelFactory::SetKernelAttrs(const std::shared_ptr<kernel::OpInfo> op_info,
|
||||||
|
std::vector<KernelAttr> *kernel_attrs) {
|
||||||
|
auto inputs_ptr = op_info->inputs_ptr();
|
||||||
|
auto outputs_ptr = op_info->outputs_ptr();
|
||||||
|
auto first_input_dtypes = inputs_ptr[0]->dtypes();
|
||||||
|
auto input_formats = inputs_ptr[0]->formats();
|
||||||
|
|
||||||
|
for (size_t i = 0; i < first_input_dtypes.size(); i++) {
|
||||||
|
KernelAttr kernel_attr;
|
||||||
|
kernel_attr.AddInputAttr(kernel::DtypeToTypeId(first_input_dtypes[i]), input_formats[i]);
|
||||||
|
for (size_t j = 1; j < inputs_ptr.size(); j++) {
|
||||||
|
auto input_dtypes = inputs_ptr[j]->dtypes();
|
||||||
|
input_formats = inputs_ptr[j]->formats();
|
||||||
|
kernel_attr.AddInputAttr(kernel::DtypeToTypeId(input_dtypes[i]), input_formats[i]);
|
||||||
|
}
|
||||||
|
for (size_t j = 0; j < outputs_ptr.size(); j++) {
|
||||||
|
auto output_dtypes = outputs_ptr[j]->dtypes();
|
||||||
|
auto output_formats = outputs_ptr[j]->formats();
|
||||||
|
kernel_attr.AddOutputAttr(kernel::DtypeToTypeId(output_dtypes[i]), output_formats[i]);
|
||||||
|
}
|
||||||
|
if (same_op_name.count(op_info->op_name()) != 0) {
|
||||||
|
kernel_attr.SetAllSameAttr(true);
|
||||||
|
}
|
||||||
|
kernel_attrs->emplace_back(kernel_attr);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void CPUKernelFactory::UpdateKernelAttrs(const std::string &kernel_name, const std::vector<KernelAttr> &kernel_attrs) {
|
||||||
|
size_t attr_size = kernel_attrs.size();
|
||||||
|
std::vector<std::pair<KernelAttr, CPUKernelCreator>> attr_creators(attr_size);
|
||||||
|
auto iter = name_to_attr_creator_.find(kernel_name);
|
||||||
|
if (iter == name_to_attr_creator_.end()) {
|
||||||
|
MS_LOG(ERROR) << "CPUKernelFactory has not registered operator: " << kernel_name;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (attr_size <= iter->second.size()) {
|
||||||
|
for (size_t i = 0; i < attr_size; i++) {
|
||||||
|
auto creator = name_to_attr_creator_.find(kernel_name)->second[i].second;
|
||||||
|
attr_creators[i] = std::make_pair(kernel_attrs[i], creator);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
MS_LOG(INFO) << "attr size is not equal creators size " << kernel_name << " attr_size = " << attr_size
|
||||||
|
<< " creator_size = " << iter->second.size();
|
||||||
|
auto single_creator = name_to_attr_creator_.find(kernel_name)->second[0].second;
|
||||||
|
for (size_t i = 0; i < attr_size; i++) {
|
||||||
|
attr_creators[i] = std::make_pair(kernel_attrs[i], single_creator);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
name_to_attr_creator_[kernel_name] = attr_creators;
|
||||||
|
}
|
||||||
|
|
||||||
std::pair<bool, size_t> CPUKernelFactory::CPUKernelAttrCheck(const std::string &kernel_name,
|
std::pair<bool, size_t> CPUKernelFactory::CPUKernelAttrCheck(const std::string &kernel_name,
|
||||||
const KernelBuildInfo &kernel_info) {
|
const KernelBuildInfo &kernel_info) {
|
||||||
auto iter = name_to_attr_creator_.find(kernel_name);
|
auto iter = name_to_attr_creator_.find(kernel_name);
|
||||||
|
@ -55,10 +109,18 @@ std::pair<bool, size_t> CPUKernelFactory::CPUKernelAttrCheck(const std::string &
|
||||||
MS_LOG(INFO) << "Not registered CPU kernel: op[" << kernel_name << "]!";
|
MS_LOG(INFO) << "Not registered CPU kernel: op[" << kernel_name << "]!";
|
||||||
return std::make_pair(false, 0);
|
return std::make_pair(false, 0);
|
||||||
}
|
}
|
||||||
auto creators = iter->second;
|
auto kernel_attrs = GetSupportedKernelAttrList(kernel_name);
|
||||||
for (size_t index = 0; index < creators.size(); ++index) {
|
if (kernel_attrs[0].GetInputSize() == 0 && kernel_attrs[0].GetOutputSize() == 0) {
|
||||||
auto attr_creator = creators[index];
|
auto op_info_ptr = mindspore::kernel::OpLib::FindOp(kernel_name, kernel::OpImplyType::kCPU);
|
||||||
if (CPUKernelSingleAttrCheck(attr_creator.first, kernel_info)) {
|
if (op_info_ptr == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Not find op[" << kernel_name << "] in cpu";
|
||||||
|
}
|
||||||
|
kernel_attrs.clear();
|
||||||
|
SetKernelAttrs(op_info_ptr, &kernel_attrs);
|
||||||
|
kernel::CPUKernelFactory::GetInstance().UpdateKernelAttrs(kernel_name, kernel_attrs);
|
||||||
|
}
|
||||||
|
for (size_t index = 0; index < kernel_attrs.size(); ++index) {
|
||||||
|
if (CPUKernelSingleAttrCheck(kernel_attrs[index], kernel_info)) {
|
||||||
return std::make_pair(true, index);
|
return std::make_pair(true, index);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -25,6 +25,7 @@
|
||||||
|
|
||||||
#include "utils/ms_utils.h"
|
#include "utils/ms_utils.h"
|
||||||
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
|
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
|
||||||
|
#include "backend/kernel_compiler/oplib/oplib.h"
|
||||||
#include "runtime/device/cpu/kernel_select_cpu.h"
|
#include "runtime/device/cpu/kernel_select_cpu.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
|
@ -36,6 +37,8 @@ class CPUKernelFactory {
|
||||||
static CPUKernelFactory &GetInstance();
|
static CPUKernelFactory &GetInstance();
|
||||||
void Register(const std::string &kernel_name, const KernelAttr &kernel_attr, CPUKernelCreator &&kernel_creator);
|
void Register(const std::string &kernel_name, const KernelAttr &kernel_attr, CPUKernelCreator &&kernel_creator);
|
||||||
std::shared_ptr<CPUKernel> Create(const std::string &kernel_name, const CNodePtr &apply_kernel);
|
std::shared_ptr<CPUKernel> Create(const std::string &kernel_name, const CNodePtr &apply_kernel);
|
||||||
|
void SetKernelAttrs(const std::shared_ptr<kernel::OpInfo> op_info, std::vector<KernelAttr> *kernel_attrs);
|
||||||
|
void UpdateKernelAttrs(const std::string &kernel_name, const std::vector<KernelAttr> &kernel_attrs);
|
||||||
std::vector<KernelAttr> GetSupportedKernelAttrList(const std::string &kernel_name);
|
std::vector<KernelAttr> GetSupportedKernelAttrList(const std::string &kernel_name);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|
|
@ -46,14 +46,7 @@ class DropoutCPUKernel : public CPUKernel {
|
||||||
uint64_t tensor_size_ = 1;
|
uint64_t tensor_size_ = 1;
|
||||||
};
|
};
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL(
|
MS_REG_CPU_KERNEL(Dropout, KernelAttr(), DropoutCPUKernel);
|
||||||
Dropout,
|
|
||||||
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
|
||||||
DropoutCPUKernel);
|
|
||||||
MS_REG_CPU_KERNEL(
|
|
||||||
Dropout,
|
|
||||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
|
||||||
DropoutCPUKernel);
|
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_DROPOUT_CPU_KERNEL_H_
|
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_DROPOUT_CPU_KERNEL_H_
|
||||||
|
|
|
@ -43,14 +43,7 @@ class DropoutGradCpuBwdKernel : public CPUKernel {
|
||||||
size_t num_count, float keep_prob);
|
size_t num_count, float keep_prob);
|
||||||
};
|
};
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL(
|
MS_REG_CPU_KERNEL(DropoutGrad, KernelAttr(), DropoutGradCpuBwdKernel);
|
||||||
DropoutGrad,
|
|
||||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
|
||||||
DropoutGradCpuBwdKernel);
|
|
||||||
MS_REG_CPU_KERNEL(
|
|
||||||
DropoutGrad,
|
|
||||||
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
|
||||||
DropoutGradCpuBwdKernel);
|
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
|
|
|
@ -33,10 +33,7 @@ class EqualCountCPUKernel : public CPUKernel {
|
||||||
const std::vector<AddressPtr> &outputs) override;
|
const std::vector<AddressPtr> &outputs) override;
|
||||||
};
|
};
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL(
|
MS_REG_CPU_KERNEL(EqualCount, KernelAttr(), EqualCountCPUKernel);
|
||||||
EqualCount,
|
|
||||||
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
|
||||||
EqualCountCPUKernel);
|
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
|
|
|
@ -44,41 +44,17 @@ class GatherV2CPUKernel : public CPUKernel {
|
||||||
int64_t axis_{0};
|
int64_t axis_{0};
|
||||||
};
|
};
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL_T(
|
MS_REG_CPU_KERNEL_T(Gather, KernelAttr(), GatherV2CPUKernel, uint8_t);
|
||||||
Gather, KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
|
MS_REG_CPU_KERNEL_T(Gather, KernelAttr(), GatherV2CPUKernel, uint16_t);
|
||||||
GatherV2CPUKernel, bool);
|
MS_REG_CPU_KERNEL_T(Gather, KernelAttr(), GatherV2CPUKernel, uint32_t);
|
||||||
MS_REG_CPU_KERNEL_T(
|
MS_REG_CPU_KERNEL_T(Gather, KernelAttr(), GatherV2CPUKernel, uint64_t);
|
||||||
Gather,
|
MS_REG_CPU_KERNEL_T(Gather, KernelAttr(), GatherV2CPUKernel, int8_t);
|
||||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
|
MS_REG_CPU_KERNEL_T(Gather, KernelAttr(), GatherV2CPUKernel, int16_t);
|
||||||
GatherV2CPUKernel, float);
|
MS_REG_CPU_KERNEL_T(Gather, KernelAttr(), GatherV2CPUKernel, int32_t);
|
||||||
MS_REG_CPU_KERNEL_T(
|
MS_REG_CPU_KERNEL_T(Gather, KernelAttr(), GatherV2CPUKernel, int64_t);
|
||||||
Gather,
|
MS_REG_CPU_KERNEL_T(Gather, KernelAttr(), GatherV2CPUKernel, float);
|
||||||
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64),
|
MS_REG_CPU_KERNEL_T(Gather, KernelAttr(), GatherV2CPUKernel, double);
|
||||||
GatherV2CPUKernel, double);
|
MS_REG_CPU_KERNEL_T(Gather, KernelAttr(), GatherV2CPUKernel, bool);
|
||||||
MS_REG_CPU_KERNEL_T(
|
|
||||||
Gather, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt8),
|
|
||||||
GatherV2CPUKernel, int8_t);
|
|
||||||
MS_REG_CPU_KERNEL_T(
|
|
||||||
Gather, KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt16),
|
|
||||||
GatherV2CPUKernel, int16_t);
|
|
||||||
MS_REG_CPU_KERNEL_T(
|
|
||||||
Gather, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
|
||||||
GatherV2CPUKernel, int32_t);
|
|
||||||
MS_REG_CPU_KERNEL_T(
|
|
||||||
Gather, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64),
|
|
||||||
GatherV2CPUKernel, int64_t);
|
|
||||||
MS_REG_CPU_KERNEL_T(
|
|
||||||
Gather, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt8),
|
|
||||||
GatherV2CPUKernel, uint8_t);
|
|
||||||
MS_REG_CPU_KERNEL_T(
|
|
||||||
Gather, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt16),
|
|
||||||
GatherV2CPUKernel, uint16_t);
|
|
||||||
MS_REG_CPU_KERNEL_T(
|
|
||||||
Gather, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt32),
|
|
||||||
GatherV2CPUKernel, uint32_t);
|
|
||||||
MS_REG_CPU_KERNEL_T(
|
|
||||||
Gather, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt64),
|
|
||||||
GatherV2CPUKernel, uint64_t);
|
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
|
|
|
@ -39,76 +39,16 @@ class GatherDCPUKernel : public CPUKernel {
|
||||||
std::vector<size_t> output_shape_;
|
std::vector<size_t> output_shape_;
|
||||||
};
|
};
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL_T_S(GatherD,
|
MS_REG_CPU_KERNEL_T_S(GatherD, KernelAttr(), GatherDCPUKernel, float, int32_t);
|
||||||
KernelAttr()
|
MS_REG_CPU_KERNEL_T_S(GatherD, KernelAttr(), GatherDCPUKernel, float, int64_t);
|
||||||
.AddInputAttr(kNumberTypeFloat32)
|
MS_REG_CPU_KERNEL_T_S(GatherD, KernelAttr(), GatherDCPUKernel, float16, int32_t);
|
||||||
.AddInputAttr(kNumberTypeInt32)
|
MS_REG_CPU_KERNEL_T_S(GatherD, KernelAttr(), GatherDCPUKernel, float16, int64_t);
|
||||||
.AddInputAttr(kNumberTypeInt32)
|
MS_REG_CPU_KERNEL_T_S(GatherD, KernelAttr(), GatherDCPUKernel, int32_t, int32_t);
|
||||||
.AddOutputAttr(kNumberTypeFloat32),
|
MS_REG_CPU_KERNEL_T_S(GatherD, KernelAttr(), GatherDCPUKernel, int32_t, int64_t);
|
||||||
GatherDCPUKernel, float, int32_t);
|
MS_REG_CPU_KERNEL_T_S(GatherD, KernelAttr(), GatherDCPUKernel, int64_t, int32_t);
|
||||||
MS_REG_CPU_KERNEL_T_S(GatherD,
|
MS_REG_CPU_KERNEL_T_S(GatherD, KernelAttr(), GatherDCPUKernel, int64_t, int64_t);
|
||||||
KernelAttr()
|
MS_REG_CPU_KERNEL_T_S(GatherD, KernelAttr(), GatherDCPUKernel, bool, int32_t);
|
||||||
.AddInputAttr(kNumberTypeFloat32)
|
MS_REG_CPU_KERNEL_T_S(GatherD, KernelAttr(), GatherDCPUKernel, bool, int64_t);
|
||||||
.AddInputAttr(kNumberTypeInt32)
|
|
||||||
.AddInputAttr(kNumberTypeInt64)
|
|
||||||
.AddOutputAttr(kNumberTypeFloat32),
|
|
||||||
GatherDCPUKernel, float, int64_t);
|
|
||||||
MS_REG_CPU_KERNEL_T_S(GatherD,
|
|
||||||
KernelAttr()
|
|
||||||
.AddInputAttr(kNumberTypeFloat16)
|
|
||||||
.AddInputAttr(kNumberTypeInt32)
|
|
||||||
.AddInputAttr(kNumberTypeInt32)
|
|
||||||
.AddOutputAttr(kNumberTypeFloat16),
|
|
||||||
GatherDCPUKernel, float16, int32_t);
|
|
||||||
MS_REG_CPU_KERNEL_T_S(GatherD,
|
|
||||||
KernelAttr()
|
|
||||||
.AddInputAttr(kNumberTypeFloat16)
|
|
||||||
.AddInputAttr(kNumberTypeInt32)
|
|
||||||
.AddInputAttr(kNumberTypeInt64)
|
|
||||||
.AddOutputAttr(kNumberTypeFloat16),
|
|
||||||
GatherDCPUKernel, float16, int64_t);
|
|
||||||
MS_REG_CPU_KERNEL_T_S(GatherD,
|
|
||||||
KernelAttr()
|
|
||||||
.AddInputAttr(kNumberTypeInt32)
|
|
||||||
.AddInputAttr(kNumberTypeInt32)
|
|
||||||
.AddInputAttr(kNumberTypeInt32)
|
|
||||||
.AddOutputAttr(kNumberTypeInt32),
|
|
||||||
GatherDCPUKernel, int32_t, int32_t);
|
|
||||||
MS_REG_CPU_KERNEL_T_S(GatherD,
|
|
||||||
KernelAttr()
|
|
||||||
.AddInputAttr(kNumberTypeInt32)
|
|
||||||
.AddInputAttr(kNumberTypeInt32)
|
|
||||||
.AddInputAttr(kNumberTypeInt64)
|
|
||||||
.AddOutputAttr(kNumberTypeInt32),
|
|
||||||
GatherDCPUKernel, int32_t, int64_t);
|
|
||||||
MS_REG_CPU_KERNEL_T_S(GatherD,
|
|
||||||
KernelAttr()
|
|
||||||
.AddInputAttr(kNumberTypeInt64)
|
|
||||||
.AddInputAttr(kNumberTypeInt32)
|
|
||||||
.AddInputAttr(kNumberTypeInt32)
|
|
||||||
.AddOutputAttr(kNumberTypeInt64),
|
|
||||||
GatherDCPUKernel, int64_t, int32_t);
|
|
||||||
MS_REG_CPU_KERNEL_T_S(GatherD,
|
|
||||||
KernelAttr()
|
|
||||||
.AddInputAttr(kNumberTypeInt64)
|
|
||||||
.AddInputAttr(kNumberTypeInt32)
|
|
||||||
.AddInputAttr(kNumberTypeInt64)
|
|
||||||
.AddOutputAttr(kNumberTypeInt64),
|
|
||||||
GatherDCPUKernel, int64_t, int64_t);
|
|
||||||
MS_REG_CPU_KERNEL_T_S(GatherD,
|
|
||||||
KernelAttr()
|
|
||||||
.AddInputAttr(kNumberTypeBool)
|
|
||||||
.AddInputAttr(kNumberTypeInt32)
|
|
||||||
.AddInputAttr(kNumberTypeInt32)
|
|
||||||
.AddOutputAttr(kNumberTypeBool),
|
|
||||||
GatherDCPUKernel, bool, int32_t);
|
|
||||||
MS_REG_CPU_KERNEL_T_S(GatherD,
|
|
||||||
KernelAttr()
|
|
||||||
.AddInputAttr(kNumberTypeBool)
|
|
||||||
.AddInputAttr(kNumberTypeInt32)
|
|
||||||
.AddInputAttr(kNumberTypeInt64)
|
|
||||||
.AddOutputAttr(kNumberTypeBool),
|
|
||||||
GatherDCPUKernel, bool, int64_t);
|
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
|
|
|
@ -40,44 +40,16 @@ class GatherDGradCPUKernel : public CPUKernel {
|
||||||
int32_t axis_;
|
int32_t axis_;
|
||||||
};
|
};
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL_T_S(
|
MS_REG_CPU_KERNEL_T_S(GatherDGrad, KernelAttr(), GatherDGradCPUKernel, int32_t, int32_t);
|
||||||
GatherDGrad,
|
MS_REG_CPU_KERNEL_T_S(GatherDGrad, KernelAttr(), GatherDGradCPUKernel, int32_t, int64_t);
|
||||||
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
MS_REG_CPU_KERNEL_T_S(GatherDGrad, KernelAttr(), GatherDGradCPUKernel, int32_t, float);
|
||||||
GatherDGradCPUKernel, int32_t, int32_t);
|
MS_REG_CPU_KERNEL_T_S(GatherDGrad, KernelAttr(), GatherDGradCPUKernel, int32_t, float16);
|
||||||
MS_REG_CPU_KERNEL_T_S(
|
MS_REG_CPU_KERNEL_T_S(GatherDGrad, KernelAttr(), GatherDGradCPUKernel, int32_t, bool);
|
||||||
GatherDGrad,
|
MS_REG_CPU_KERNEL_T_S(GatherDGrad, KernelAttr(), GatherDGradCPUKernel, int64_t, int32_t);
|
||||||
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
MS_REG_CPU_KERNEL_T_S(GatherDGrad, KernelAttr(), GatherDGradCPUKernel, int64_t, int64_t);
|
||||||
GatherDGradCPUKernel, int32_t, int64_t);
|
MS_REG_CPU_KERNEL_T_S(GatherDGrad, KernelAttr(), GatherDGradCPUKernel, int64_t, float);
|
||||||
MS_REG_CPU_KERNEL_T_S(
|
MS_REG_CPU_KERNEL_T_S(GatherDGrad, KernelAttr(), GatherDGradCPUKernel, int64_t, float16);
|
||||||
GatherDGrad,
|
MS_REG_CPU_KERNEL_T_S(GatherDGrad, KernelAttr(), GatherDGradCPUKernel, int64_t, bool);
|
||||||
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
|
||||||
GatherDGradCPUKernel, int32_t, float);
|
|
||||||
MS_REG_CPU_KERNEL_T_S(
|
|
||||||
GatherDGrad,
|
|
||||||
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
|
||||||
GatherDGradCPUKernel, int32_t, float16);
|
|
||||||
MS_REG_CPU_KERNEL_T_S(
|
|
||||||
GatherDGrad, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
|
|
||||||
GatherDGradCPUKernel, int32_t, bool);
|
|
||||||
MS_REG_CPU_KERNEL_T_S(
|
|
||||||
GatherDGrad,
|
|
||||||
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
|
||||||
GatherDGradCPUKernel, int64_t, int32_t);
|
|
||||||
MS_REG_CPU_KERNEL_T_S(
|
|
||||||
GatherDGrad,
|
|
||||||
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
|
||||||
GatherDGradCPUKernel, int64_t, int64_t);
|
|
||||||
MS_REG_CPU_KERNEL_T_S(
|
|
||||||
GatherDGrad,
|
|
||||||
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
|
||||||
GatherDGradCPUKernel, int64_t, float);
|
|
||||||
MS_REG_CPU_KERNEL_T_S(
|
|
||||||
GatherDGrad,
|
|
||||||
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
|
||||||
GatherDGradCPUKernel, int64_t, float16);
|
|
||||||
MS_REG_CPU_KERNEL_T_S(
|
|
||||||
GatherDGrad, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
|
|
||||||
GatherDGradCPUKernel, int64_t, bool);
|
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
|
|
|
@ -47,20 +47,7 @@ class GatherNdCPUKernel : public CPUKernel {
|
||||||
TypeId dtype_{kTypeUnknown};
|
TypeId dtype_{kTypeUnknown};
|
||||||
};
|
};
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL(
|
MS_REG_CPU_KERNEL(GatherNd, KernelAttr(), GatherNdCPUKernel);
|
||||||
GatherNd, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
|
||||||
GatherNdCPUKernel);
|
|
||||||
MS_REG_CPU_KERNEL(
|
|
||||||
GatherNd, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64),
|
|
||||||
GatherNdCPUKernel);
|
|
||||||
MS_REG_CPU_KERNEL(
|
|
||||||
GatherNd,
|
|
||||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
|
|
||||||
GatherNdCPUKernel);
|
|
||||||
MS_REG_CPU_KERNEL(
|
|
||||||
GatherNd,
|
|
||||||
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64),
|
|
||||||
GatherNdCPUKernel);
|
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
|
|
|
@ -41,20 +41,15 @@ class HSigmoidCPUKernel : public CPUKernel {
|
||||||
uint64_t tensor_size_ = 1;
|
uint64_t tensor_size_ = 1;
|
||||||
};
|
};
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL_T(HSigmoid, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
|
MS_REG_CPU_KERNEL_T(HSigmoid, KernelAttr(), HSigmoidCPUKernel, int8_t);
|
||||||
HSigmoidCPUKernel, int8_t);
|
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL_T(HSigmoid, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
|
MS_REG_CPU_KERNEL_T(HSigmoid, KernelAttr(), HSigmoidCPUKernel, int16_t);
|
||||||
HSigmoidCPUKernel, int16_t);
|
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL_T(HSigmoid, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
MS_REG_CPU_KERNEL_T(HSigmoid, KernelAttr(), HSigmoidCPUKernel, int32_t);
|
||||||
HSigmoidCPUKernel, int);
|
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL_T(HSigmoid, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
MS_REG_CPU_KERNEL_T(HSigmoid, KernelAttr(), HSigmoidCPUKernel, int64_t);
|
||||||
HSigmoidCPUKernel, int64_t);
|
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL_T(HSigmoid, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
MS_REG_CPU_KERNEL_T(HSigmoid, KernelAttr(), HSigmoidCPUKernel, float);
|
||||||
HSigmoidCPUKernel, float);
|
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_TILE_CPU_KERNEL_H_
|
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_TILE_CPU_KERNEL_H_
|
||||||
|
|
|
@ -41,29 +41,11 @@ class HSigmoidGradCPUKernel : public CPUKernel {
|
||||||
uint64_t tensor_size_ = 1;
|
uint64_t tensor_size_ = 1;
|
||||||
};
|
};
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL_T(
|
MS_REG_CPU_KERNEL_T(HSigmoidGrad, KernelAttr(), HSigmoidGradCPUKernel, int8_t);
|
||||||
HSigmoidGrad, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
|
MS_REG_CPU_KERNEL_T(HSigmoidGrad, KernelAttr(), HSigmoidGradCPUKernel, int16_t);
|
||||||
HSigmoidGradCPUKernel, int8_t);
|
MS_REG_CPU_KERNEL_T(HSigmoidGrad, KernelAttr(), HSigmoidGradCPUKernel, int32_t);
|
||||||
|
MS_REG_CPU_KERNEL_T(HSigmoidGrad, KernelAttr(), HSigmoidGradCPUKernel, int64_t);
|
||||||
MS_REG_CPU_KERNEL_T(
|
MS_REG_CPU_KERNEL_T(HSigmoidGrad, KernelAttr(), HSigmoidGradCPUKernel, float);
|
||||||
HSigmoidGrad,
|
|
||||||
KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
|
|
||||||
HSigmoidGradCPUKernel, int16_t);
|
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL_T(
|
|
||||||
HSigmoidGrad,
|
|
||||||
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
|
||||||
HSigmoidGradCPUKernel, int);
|
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL_T(
|
|
||||||
HSigmoidGrad,
|
|
||||||
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
|
||||||
HSigmoidGradCPUKernel, int64_t);
|
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL_T(
|
|
||||||
HSigmoidGrad,
|
|
||||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
|
||||||
HSigmoidGradCPUKernel, float);
|
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_TILE_CPU_KERNEL_H_
|
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_TILE_CPU_KERNEL_H_
|
||||||
|
|
|
@ -41,20 +41,11 @@ class HSwishCPUKernel : public CPUKernel {
|
||||||
uint64_t tensor_size_ = 1;
|
uint64_t tensor_size_ = 1;
|
||||||
};
|
};
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL_T(HSwish, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8), HSwishCPUKernel,
|
MS_REG_CPU_KERNEL_T(HSwish, KernelAttr(), HSwishCPUKernel, int8_t);
|
||||||
int8_t);
|
MS_REG_CPU_KERNEL_T(HSwish, KernelAttr(), HSwishCPUKernel, int16_t);
|
||||||
|
MS_REG_CPU_KERNEL_T(HSwish, KernelAttr(), HSwishCPUKernel, int32_t);
|
||||||
MS_REG_CPU_KERNEL_T(HSwish, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
|
MS_REG_CPU_KERNEL_T(HSwish, KernelAttr(), HSwishCPUKernel, int64_t);
|
||||||
HSwishCPUKernel, int16_t);
|
MS_REG_CPU_KERNEL_T(HSwish, KernelAttr(), HSwishCPUKernel, float);
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL_T(HSwish, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
|
||||||
HSwishCPUKernel, int);
|
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL_T(HSwish, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
|
||||||
HSwishCPUKernel, int64_t);
|
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL_T(HSwish, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
|
||||||
HSwishCPUKernel, float);
|
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_TILE_CPU_KERNEL_H_
|
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_TILE_CPU_KERNEL_H_
|
||||||
|
|
|
@ -41,29 +41,11 @@ class HSwishGradCPUKernel : public CPUKernel {
|
||||||
uint64_t tensor_size_ = 1;
|
uint64_t tensor_size_ = 1;
|
||||||
};
|
};
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL_T(
|
MS_REG_CPU_KERNEL_T(HSwishGrad, KernelAttr(), HSwishGradCPUKernel, int8_t);
|
||||||
HSwishGrad, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
|
MS_REG_CPU_KERNEL_T(HSwishGrad, KernelAttr(), HSwishGradCPUKernel, int16_t);
|
||||||
HSwishGradCPUKernel, int8_t);
|
MS_REG_CPU_KERNEL_T(HSwishGrad, KernelAttr(), HSwishGradCPUKernel, int32_t);
|
||||||
|
MS_REG_CPU_KERNEL_T(HSwishGrad, KernelAttr(), HSwishGradCPUKernel, int64_t);
|
||||||
MS_REG_CPU_KERNEL_T(
|
MS_REG_CPU_KERNEL_T(HSwishGrad, KernelAttr(), HSwishGradCPUKernel, float);
|
||||||
HSwishGrad,
|
|
||||||
KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
|
|
||||||
HSwishGradCPUKernel, int16_t);
|
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL_T(
|
|
||||||
HSwishGrad,
|
|
||||||
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
|
||||||
HSwishGradCPUKernel, int);
|
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL_T(
|
|
||||||
HSwishGrad,
|
|
||||||
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
|
||||||
HSwishGradCPUKernel, int64_t);
|
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL_T(
|
|
||||||
HSwishGrad,
|
|
||||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
|
||||||
HSwishGradCPUKernel, float);
|
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_TILE_CPU_KERNEL_H_
|
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_TILE_CPU_KERNEL_H_
|
||||||
|
|
|
@ -51,41 +51,7 @@ class IsFiniteCPUKernel : public CPUKernel {
|
||||||
TypeId input_dtype_{kTypeUnknown};
|
TypeId input_dtype_{kTypeUnknown};
|
||||||
};
|
};
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL(IsFinite, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
|
MS_REG_CPU_KERNEL(IsFinite, KernelAttr(), IsFiniteCPUKernel);
|
||||||
IsFiniteCPUKernel);
|
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL(IsFinite, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeBool),
|
|
||||||
IsFiniteCPUKernel);
|
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL(IsFinite, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeBool),
|
|
||||||
IsFiniteCPUKernel);
|
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL(IsFinite, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
|
|
||||||
IsFiniteCPUKernel);
|
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL(IsFinite, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool),
|
|
||||||
IsFiniteCPUKernel);
|
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL(IsFinite, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool),
|
|
||||||
IsFiniteCPUKernel);
|
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL(IsFinite, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool),
|
|
||||||
IsFiniteCPUKernel);
|
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL(IsFinite, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeBool),
|
|
||||||
IsFiniteCPUKernel);
|
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL(IsFinite, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeBool),
|
|
||||||
IsFiniteCPUKernel);
|
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL(IsFinite, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeBool),
|
|
||||||
IsFiniteCPUKernel);
|
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL(IsFinite, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeBool),
|
|
||||||
IsFiniteCPUKernel);
|
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL(IsFinite, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeBool),
|
|
||||||
IsFiniteCPUKernel);
|
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
|
|
|
@ -46,25 +46,7 @@ class LayerNormCPUKernel : public CPUKernel {
|
||||||
size_t param_num_{1};
|
size_t param_num_{1};
|
||||||
};
|
};
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL(LayerNorm,
|
MS_REG_CPU_KERNEL(LayerNorm, KernelAttr(), LayerNormCPUKernel);
|
||||||
KernelAttr()
|
|
||||||
.AddInputAttr(kNumberTypeFloat16)
|
|
||||||
.AddInputAttr(kNumberTypeFloat16)
|
|
||||||
.AddInputAttr(kNumberTypeFloat16)
|
|
||||||
.AddOutputAttr(kNumberTypeFloat16)
|
|
||||||
.AddOutputAttr(kNumberTypeFloat16)
|
|
||||||
.AddOutputAttr(kNumberTypeFloat16),
|
|
||||||
LayerNormCPUKernel);
|
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL(LayerNorm,
|
|
||||||
KernelAttr()
|
|
||||||
.AddInputAttr(kNumberTypeFloat32)
|
|
||||||
.AddInputAttr(kNumberTypeFloat32)
|
|
||||||
.AddInputAttr(kNumberTypeFloat32)
|
|
||||||
.AddOutputAttr(kNumberTypeFloat32)
|
|
||||||
.AddOutputAttr(kNumberTypeFloat32)
|
|
||||||
.AddOutputAttr(kNumberTypeFloat32),
|
|
||||||
LayerNormCPUKernel);
|
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_LAYER_NORM_CPU_KERNEL_H_
|
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_LAYER_NORM_CPU_KERNEL_H_
|
||||||
|
|
|
@ -48,29 +48,7 @@ class LayerNormGradCPUKernel : public CPUKernel {
|
||||||
size_t param_size_{1};
|
size_t param_size_{1};
|
||||||
};
|
};
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL(LayerNormGrad,
|
MS_REG_CPU_KERNEL(LayerNormGrad, KernelAttr(), LayerNormGradCPUKernel);
|
||||||
KernelAttr()
|
|
||||||
.AddInputAttr(kNumberTypeFloat16)
|
|
||||||
.AddInputAttr(kNumberTypeFloat16)
|
|
||||||
.AddInputAttr(kNumberTypeFloat16)
|
|
||||||
.AddInputAttr(kNumberTypeFloat16)
|
|
||||||
.AddInputAttr(kNumberTypeFloat16)
|
|
||||||
.AddOutputAttr(kNumberTypeFloat16)
|
|
||||||
.AddOutputAttr(kNumberTypeFloat16)
|
|
||||||
.AddOutputAttr(kNumberTypeFloat16),
|
|
||||||
LayerNormGradCPUKernel);
|
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL(LayerNormGrad,
|
|
||||||
KernelAttr()
|
|
||||||
.AddInputAttr(kNumberTypeFloat32)
|
|
||||||
.AddInputAttr(kNumberTypeFloat32)
|
|
||||||
.AddInputAttr(kNumberTypeFloat32)
|
|
||||||
.AddInputAttr(kNumberTypeFloat32)
|
|
||||||
.AddInputAttr(kNumberTypeFloat32)
|
|
||||||
.AddOutputAttr(kNumberTypeFloat32)
|
|
||||||
.AddOutputAttr(kNumberTypeFloat32)
|
|
||||||
.AddOutputAttr(kNumberTypeFloat32),
|
|
||||||
LayerNormGradCPUKernel);
|
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_LAYER_NORM_GRAD_CPU_KERNEL_H_
|
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_LAYER_NORM_GRAD_CPU_KERNEL_H_
|
||||||
|
|
|
@ -75,33 +75,12 @@ class MaximumCPUKernel : public CPUKernel {
|
||||||
const size_t max_dims{7};
|
const size_t max_dims{7};
|
||||||
};
|
};
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL_T(
|
MS_REG_CPU_KERNEL_T(Maximum, KernelAttr(), MaximumCPUKernel, int32_t);
|
||||||
Maximum, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
MS_REG_CPU_KERNEL_T(Maximum, KernelAttr(), MaximumCPUKernel, int64_t);
|
||||||
MaximumCPUKernel, int32_t);
|
MS_REG_CPU_KERNEL_T(Maximum, KernelAttr(), MaximumCPUKernel, uint32_t);
|
||||||
|
MS_REG_CPU_KERNEL_T(Maximum, KernelAttr(), MaximumCPUKernel, uint64_t);
|
||||||
MS_REG_CPU_KERNEL_T(
|
MS_REG_CPU_KERNEL_T(Maximum, KernelAttr(), MaximumCPUKernel, float);
|
||||||
Maximum,
|
MS_REG_CPU_KERNEL_T(Maximum, KernelAttr(), MaximumCPUKernel, double);
|
||||||
KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
|
|
||||||
MaximumCPUKernel, uint32_t);
|
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL_T(
|
|
||||||
Maximum,
|
|
||||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
|
||||||
MaximumCPUKernel, float);
|
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL_T(
|
|
||||||
Maximum, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
|
||||||
MaximumCPUKernel, int64_t);
|
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL_T(
|
|
||||||
Maximum,
|
|
||||||
KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
|
|
||||||
MaximumCPUKernel, uint64_t);
|
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL_T(
|
|
||||||
Maximum,
|
|
||||||
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
|
||||||
MaximumCPUKernel, double);
|
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
|
|
|
@ -47,59 +47,7 @@ class MaximumGradCPUKernel : public CPUKernel {
|
||||||
TypeId dtype_{kTypeUnknown};
|
TypeId dtype_{kTypeUnknown};
|
||||||
};
|
};
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL(MaximumGrad,
|
MS_REG_CPU_KERNEL(MaximumGrad, KernelAttr(), MaximumGradCPUKernel);
|
||||||
KernelAttr()
|
|
||||||
.AddInputAttr(kNumberTypeInt32)
|
|
||||||
.AddInputAttr(kNumberTypeInt32)
|
|
||||||
.AddInputAttr(kNumberTypeInt32)
|
|
||||||
.AddOutputAttr(kNumberTypeInt32)
|
|
||||||
.AddOutputAttr(kNumberTypeInt32),
|
|
||||||
MaximumGradCPUKernel);
|
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL(MaximumGrad,
|
|
||||||
KernelAttr()
|
|
||||||
.AddInputAttr(kNumberTypeUInt32)
|
|
||||||
.AddInputAttr(kNumberTypeUInt32)
|
|
||||||
.AddInputAttr(kNumberTypeUInt32)
|
|
||||||
.AddOutputAttr(kNumberTypeUInt32)
|
|
||||||
.AddOutputAttr(kNumberTypeUInt32),
|
|
||||||
MaximumGradCPUKernel);
|
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL(MaximumGrad,
|
|
||||||
KernelAttr()
|
|
||||||
.AddInputAttr(kNumberTypeFloat32)
|
|
||||||
.AddInputAttr(kNumberTypeFloat32)
|
|
||||||
.AddInputAttr(kNumberTypeFloat32)
|
|
||||||
.AddOutputAttr(kNumberTypeFloat32)
|
|
||||||
.AddOutputAttr(kNumberTypeFloat32),
|
|
||||||
MaximumGradCPUKernel);
|
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL(MaximumGrad,
|
|
||||||
KernelAttr()
|
|
||||||
.AddInputAttr(kNumberTypeInt64)
|
|
||||||
.AddInputAttr(kNumberTypeInt64)
|
|
||||||
.AddInputAttr(kNumberTypeInt64)
|
|
||||||
.AddOutputAttr(kNumberTypeInt64)
|
|
||||||
.AddOutputAttr(kNumberTypeInt64),
|
|
||||||
MaximumGradCPUKernel);
|
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL(MaximumGrad,
|
|
||||||
KernelAttr()
|
|
||||||
.AddInputAttr(kNumberTypeUInt64)
|
|
||||||
.AddInputAttr(kNumberTypeUInt64)
|
|
||||||
.AddInputAttr(kNumberTypeUInt64)
|
|
||||||
.AddOutputAttr(kNumberTypeUInt64)
|
|
||||||
.AddOutputAttr(kNumberTypeUInt64),
|
|
||||||
MaximumGradCPUKernel);
|
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL(MaximumGrad,
|
|
||||||
KernelAttr()
|
|
||||||
.AddInputAttr(kNumberTypeFloat64)
|
|
||||||
.AddInputAttr(kNumberTypeFloat64)
|
|
||||||
.AddInputAttr(kNumberTypeFloat64)
|
|
||||||
.AddOutputAttr(kNumberTypeFloat64)
|
|
||||||
.AddOutputAttr(kNumberTypeFloat64),
|
|
||||||
MaximumGradCPUKernel);
|
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MaximumGrad_CPU_KERNEL_H_
|
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MaximumGrad_CPU_KERNEL_H_
|
||||||
|
|
|
@ -75,33 +75,12 @@ class MinimumCPUKernel : public CPUKernel {
|
||||||
const size_t max_dims{7};
|
const size_t max_dims{7};
|
||||||
};
|
};
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL_T(
|
MS_REG_CPU_KERNEL_T(Minimum, KernelAttr(), MinimumCPUKernel, int32_t);
|
||||||
Minimum, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
MS_REG_CPU_KERNEL_T(Minimum, KernelAttr(), MinimumCPUKernel, uint32_t);
|
||||||
MinimumCPUKernel, int32_t);
|
MS_REG_CPU_KERNEL_T(Minimum, KernelAttr(), MinimumCPUKernel, float);
|
||||||
|
MS_REG_CPU_KERNEL_T(Minimum, KernelAttr(), MinimumCPUKernel, int64_t);
|
||||||
MS_REG_CPU_KERNEL_T(
|
MS_REG_CPU_KERNEL_T(Minimum, KernelAttr(), MinimumCPUKernel, uint64_t);
|
||||||
Minimum,
|
MS_REG_CPU_KERNEL_T(Minimum, KernelAttr(), MinimumCPUKernel, double);
|
||||||
KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
|
|
||||||
MinimumCPUKernel, uint32_t);
|
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL_T(
|
|
||||||
Minimum,
|
|
||||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
|
||||||
MinimumCPUKernel, float);
|
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL_T(
|
|
||||||
Minimum, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
|
||||||
MinimumCPUKernel, int64_t);
|
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL_T(
|
|
||||||
Minimum,
|
|
||||||
KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
|
|
||||||
MinimumCPUKernel, uint64_t);
|
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL_T(
|
|
||||||
Minimum,
|
|
||||||
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
|
||||||
MinimumCPUKernel, double);
|
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
|
|
|
@ -47,59 +47,7 @@ class MinimumGradCPUKernel : public CPUKernel {
|
||||||
TypeId dtype_{kTypeUnknown};
|
TypeId dtype_{kTypeUnknown};
|
||||||
};
|
};
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL(MinimumGrad,
|
MS_REG_CPU_KERNEL(MinimumGrad, KernelAttr(), MinimumGradCPUKernel);
|
||||||
KernelAttr()
|
|
||||||
.AddInputAttr(kNumberTypeInt32)
|
|
||||||
.AddInputAttr(kNumberTypeInt32)
|
|
||||||
.AddInputAttr(kNumberTypeInt32)
|
|
||||||
.AddOutputAttr(kNumberTypeInt32)
|
|
||||||
.AddOutputAttr(kNumberTypeInt32),
|
|
||||||
MinimumGradCPUKernel);
|
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL(MinimumGrad,
|
|
||||||
KernelAttr()
|
|
||||||
.AddInputAttr(kNumberTypeUInt32)
|
|
||||||
.AddInputAttr(kNumberTypeUInt32)
|
|
||||||
.AddInputAttr(kNumberTypeUInt32)
|
|
||||||
.AddOutputAttr(kNumberTypeUInt32)
|
|
||||||
.AddOutputAttr(kNumberTypeUInt32),
|
|
||||||
MinimumGradCPUKernel);
|
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL(MinimumGrad,
|
|
||||||
KernelAttr()
|
|
||||||
.AddInputAttr(kNumberTypeFloat32)
|
|
||||||
.AddInputAttr(kNumberTypeFloat32)
|
|
||||||
.AddInputAttr(kNumberTypeFloat32)
|
|
||||||
.AddOutputAttr(kNumberTypeFloat32)
|
|
||||||
.AddOutputAttr(kNumberTypeFloat32),
|
|
||||||
MinimumGradCPUKernel);
|
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL(MinimumGrad,
|
|
||||||
KernelAttr()
|
|
||||||
.AddInputAttr(kNumberTypeInt64)
|
|
||||||
.AddInputAttr(kNumberTypeInt64)
|
|
||||||
.AddInputAttr(kNumberTypeInt64)
|
|
||||||
.AddOutputAttr(kNumberTypeInt64)
|
|
||||||
.AddOutputAttr(kNumberTypeInt64),
|
|
||||||
MinimumGradCPUKernel);
|
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL(MinimumGrad,
|
|
||||||
KernelAttr()
|
|
||||||
.AddInputAttr(kNumberTypeUInt64)
|
|
||||||
.AddInputAttr(kNumberTypeUInt64)
|
|
||||||
.AddInputAttr(kNumberTypeUInt64)
|
|
||||||
.AddOutputAttr(kNumberTypeUInt64)
|
|
||||||
.AddOutputAttr(kNumberTypeUInt64),
|
|
||||||
MinimumGradCPUKernel);
|
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL(MinimumGrad,
|
|
||||||
KernelAttr()
|
|
||||||
.AddInputAttr(kNumberTypeFloat64)
|
|
||||||
.AddInputAttr(kNumberTypeFloat64)
|
|
||||||
.AddInputAttr(kNumberTypeFloat64)
|
|
||||||
.AddOutputAttr(kNumberTypeFloat64)
|
|
||||||
.AddOutputAttr(kNumberTypeFloat64),
|
|
||||||
MinimumGradCPUKernel);
|
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MinimumGrad_CPU_KERNEL_H_
|
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MinimumGrad_CPU_KERNEL_H_
|
||||||
|
|
|
@ -64,19 +64,7 @@ class MirrorPadCPUKernel : public CPUKernel {
|
||||||
int num_paddings_;
|
int num_paddings_;
|
||||||
};
|
};
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL(
|
MS_REG_CPU_KERNEL(MirrorPad, KernelAttr(), MirrorPadCPUKernel);
|
||||||
MirrorPad,
|
|
||||||
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16),
|
|
||||||
MirrorPadCPUKernel);
|
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL(
|
|
||||||
MirrorPad,
|
|
||||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32),
|
|
||||||
MirrorPadCPUKernel);
|
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL(
|
|
||||||
MirrorPad, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32),
|
|
||||||
MirrorPadCPUKernel);
|
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MIRROR_PAD_CPU_KERNEL_H_
|
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MIRROR_PAD_CPU_KERNEL_H_
|
||||||
|
|
|
@ -80,20 +80,7 @@ class MirrorPadGradCPUKernel : public CPUKernel {
|
||||||
int num_paddings_;
|
int num_paddings_;
|
||||||
};
|
};
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL(
|
MS_REG_CPU_KERNEL(MirrorPadGrad, KernelAttr(), MirrorPadGradCPUKernel);
|
||||||
MirrorPadGrad,
|
|
||||||
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16),
|
|
||||||
MirrorPadGradCPUKernel);
|
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL(
|
|
||||||
MirrorPadGrad,
|
|
||||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32),
|
|
||||||
MirrorPadGradCPUKernel);
|
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL(
|
|
||||||
MirrorPadGrad,
|
|
||||||
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32),
|
|
||||||
MirrorPadGradCPUKernel);
|
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MIRROR_PAD_CPU_KERNEL_H_
|
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MIRROR_PAD_CPU_KERNEL_H_
|
||||||
|
|
|
@ -33,10 +33,7 @@ class ConvCPUKernel : public MKLCPUKernel {
|
||||||
const std::vector<AddressPtr> &outputs) override;
|
const std::vector<AddressPtr> &outputs) override;
|
||||||
};
|
};
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL(
|
MS_REG_CPU_KERNEL(Conv2D, KernelAttr(), ConvCPUKernel);
|
||||||
Conv2D,
|
|
||||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
|
||||||
ConvCPUKernel);
|
|
||||||
MS_REG_CPU_KERNEL(
|
MS_REG_CPU_KERNEL(
|
||||||
Conv3D,
|
Conv3D,
|
||||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||||
|
|
|
@ -18,6 +18,7 @@
|
||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
#include "backend/kernel_compiler/cpu/arithmetic_cpu_kernel.h"
|
||||||
#include "backend/kernel_compiler/cpu/mkldnn/mkl_cpu_kernel.h"
|
#include "backend/kernel_compiler/cpu/mkldnn/mkl_cpu_kernel.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
|
@ -36,9 +37,8 @@ class MulCPUKernel : public MKLCPUKernel {
|
||||||
bool need_swap_{false};
|
bool need_swap_{false};
|
||||||
};
|
};
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL(
|
MS_REG_CPU_KERNEL(Mul, KernelAttr(), MulCPUKernel);
|
||||||
Mul, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
MS_REG_CPU_KERNEL_T(Mul, KernelAttr(), ArithmeticCPUKernel, int32_t);
|
||||||
MulCPUKernel);
|
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
|
|
|
@ -38,13 +38,7 @@ class OneHotCPUKernel : public CPUKernel {
|
||||||
size_t axis_;
|
size_t axis_;
|
||||||
};
|
};
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL(OneHot,
|
MS_REG_CPU_KERNEL(OneHot, KernelAttr(), OneHotCPUKernel);
|
||||||
KernelAttr()
|
|
||||||
.AddInputAttr(kNumberTypeInt32)
|
|
||||||
.AddInputAttr(kNumberTypeFloat32)
|
|
||||||
.AddInputAttr(kNumberTypeFloat32)
|
|
||||||
.AddOutputAttr(kNumberTypeFloat32),
|
|
||||||
OneHotCPUKernel);
|
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
|
|
|
@ -44,39 +44,17 @@ class PackCpuFwdKernel : public CPUKernel {
|
||||||
std::unique_ptr<T *[]> inputs_host_;
|
std::unique_ptr<T *[]> inputs_host_;
|
||||||
};
|
};
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL_T(Stack,
|
MS_REG_CPU_KERNEL_T(Stack, KernelAttr(), PackCpuFwdKernel, int8_t)
|
||||||
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
|
MS_REG_CPU_KERNEL_T(Stack, KernelAttr(), PackCpuFwdKernel, int16_t)
|
||||||
PackCpuFwdKernel, int8_t)
|
MS_REG_CPU_KERNEL_T(Stack, KernelAttr(), PackCpuFwdKernel, int32_t)
|
||||||
MS_REG_CPU_KERNEL_T(Stack,
|
MS_REG_CPU_KERNEL_T(Stack, KernelAttr(), PackCpuFwdKernel, int64_t)
|
||||||
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
|
MS_REG_CPU_KERNEL_T(Stack, KernelAttr(), PackCpuFwdKernel, uint8_t)
|
||||||
PackCpuFwdKernel, int16_t)
|
MS_REG_CPU_KERNEL_T(Stack, KernelAttr(), PackCpuFwdKernel, uint16_t)
|
||||||
MS_REG_CPU_KERNEL_T(Stack,
|
MS_REG_CPU_KERNEL_T(Stack, KernelAttr(), PackCpuFwdKernel, uint32_t)
|
||||||
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
MS_REG_CPU_KERNEL_T(Stack, KernelAttr(), PackCpuFwdKernel, uint64_t)
|
||||||
PackCpuFwdKernel, int32_t)
|
MS_REG_CPU_KERNEL_T(Stack, KernelAttr(), PackCpuFwdKernel, float16)
|
||||||
MS_REG_CPU_KERNEL_T(Stack,
|
MS_REG_CPU_KERNEL_T(Stack, KernelAttr(), PackCpuFwdKernel, float)
|
||||||
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
MS_REG_CPU_KERNEL_T(Stack, KernelAttr(), PackCpuFwdKernel, bool)
|
||||||
PackCpuFwdKernel, int64_t)
|
|
||||||
MS_REG_CPU_KERNEL_T(Stack,
|
|
||||||
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
|
|
||||||
PackCpuFwdKernel, uint8_t)
|
|
||||||
MS_REG_CPU_KERNEL_T(Stack,
|
|
||||||
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
|
|
||||||
PackCpuFwdKernel, bool)
|
|
||||||
MS_REG_CPU_KERNEL_T(Stack,
|
|
||||||
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16),
|
|
||||||
PackCpuFwdKernel, uint16_t)
|
|
||||||
MS_REG_CPU_KERNEL_T(Stack,
|
|
||||||
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
|
|
||||||
PackCpuFwdKernel, uint32_t)
|
|
||||||
MS_REG_CPU_KERNEL_T(Stack,
|
|
||||||
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
|
|
||||||
PackCpuFwdKernel, uint64_t)
|
|
||||||
MS_REG_CPU_KERNEL_T(
|
|
||||||
Stack, KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
|
||||||
PackCpuFwdKernel, float16)
|
|
||||||
MS_REG_CPU_KERNEL_T(
|
|
||||||
Stack, KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
|
||||||
PackCpuFwdKernel, float)
|
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
#endif // MINDSPORE_PACK_CPU_KERNEL_H
|
#endif // MINDSPORE_PACK_CPU_KERNEL_H
|
||||||
|
|
|
@ -48,11 +48,7 @@ class PadCPUKernel : public CPUKernel {
|
||||||
std::vector<size_t> output_shape_;
|
std::vector<size_t> output_shape_;
|
||||||
};
|
};
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL(Pad, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), PadCPUKernel);
|
MS_REG_CPU_KERNEL(Pad, KernelAttr(), PadCPUKernel);
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL(Pad, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), PadCPUKernel);
|
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL(Pad, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), PadCPUKernel);
|
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_PAD_CPU_KERNEL_H_
|
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_PAD_CPU_KERNEL_H_
|
||||||
|
|
|
@ -41,12 +41,7 @@ class RangeCPUKernel : public CPUKernel {
|
||||||
int64_t delta_;
|
int64_t delta_;
|
||||||
};
|
};
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL(Range, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), RangeCPUKernel);
|
MS_REG_CPU_KERNEL(Range, KernelAttr(), RangeCPUKernel);
|
||||||
MS_REG_CPU_KERNEL(Range, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), RangeCPUKernel);
|
|
||||||
MS_REG_CPU_KERNEL(Range, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
|
||||||
RangeCPUKernel);
|
|
||||||
MS_REG_CPU_KERNEL(Range, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
|
||||||
RangeCPUKernel);
|
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
|
|
|
@ -41,47 +41,29 @@ class ReduceCPUKernel : public CPUKernel {
|
||||||
std::function<void(const T *, size_t, T *)> reduce_func_;
|
std::function<void(const T *, size_t, T *)> reduce_func_;
|
||||||
};
|
};
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL_T(ReduceMean, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
MS_REG_CPU_KERNEL_T(ReduceMean, KernelAttr(), ReduceCPUKernel, float);
|
||||||
ReduceCPUKernel, float);
|
MS_REG_CPU_KERNEL_T(ReduceMean, KernelAttr(), ReduceCPUKernel, double);
|
||||||
MS_REG_CPU_KERNEL_T(ReduceMean, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
MS_REG_CPU_KERNEL_T(ReduceMean, KernelAttr(), ReduceCPUKernel, int32_t);
|
||||||
ReduceCPUKernel, double);
|
MS_REG_CPU_KERNEL_T(ReduceMean, KernelAttr(), ReduceCPUKernel, int64_t);
|
||||||
MS_REG_CPU_KERNEL_T(ReduceMean, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
|
||||||
ReduceCPUKernel, int32_t);
|
|
||||||
MS_REG_CPU_KERNEL_T(ReduceMean, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
|
||||||
ReduceCPUKernel, int64_t);
|
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL_T(ReduceMax, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
MS_REG_CPU_KERNEL_T(ReduceMax, KernelAttr(), ReduceCPUKernel, float);
|
||||||
ReduceCPUKernel, float);
|
MS_REG_CPU_KERNEL_T(ReduceMax, KernelAttr(), ReduceCPUKernel, double);
|
||||||
MS_REG_CPU_KERNEL_T(ReduceMax, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
MS_REG_CPU_KERNEL_T(ReduceMax, KernelAttr(), ReduceCPUKernel, int32_t);
|
||||||
ReduceCPUKernel, double);
|
MS_REG_CPU_KERNEL_T(ReduceMax, KernelAttr(), ReduceCPUKernel, int64_t);
|
||||||
MS_REG_CPU_KERNEL_T(ReduceMax, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
|
||||||
ReduceCPUKernel, int32_t);
|
|
||||||
MS_REG_CPU_KERNEL_T(ReduceMax, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
|
||||||
ReduceCPUKernel, int64_t);
|
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL_T(ReduceSum, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
MS_REG_CPU_KERNEL_T(ReduceSum, KernelAttr(), ReduceCPUKernel, float);
|
||||||
ReduceCPUKernel, float);
|
MS_REG_CPU_KERNEL_T(ReduceSum, KernelAttr(), ReduceCPUKernel, double);
|
||||||
MS_REG_CPU_KERNEL_T(ReduceSum, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
MS_REG_CPU_KERNEL_T(ReduceSum, KernelAttr(), ReduceCPUKernel, int32_t);
|
||||||
ReduceCPUKernel, double);
|
MS_REG_CPU_KERNEL_T(ReduceSum, KernelAttr(), ReduceCPUKernel, int64_t);
|
||||||
MS_REG_CPU_KERNEL_T(ReduceSum, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
|
||||||
ReduceCPUKernel, int32_t);
|
|
||||||
MS_REG_CPU_KERNEL_T(ReduceSum, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
|
||||||
ReduceCPUKernel, int64_t);
|
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL_T(ReduceMin, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
MS_REG_CPU_KERNEL_T(ReduceMin, KernelAttr(), ReduceCPUKernel, float);
|
||||||
ReduceCPUKernel, float);
|
MS_REG_CPU_KERNEL_T(ReduceMin, KernelAttr(), ReduceCPUKernel, double);
|
||||||
MS_REG_CPU_KERNEL_T(ReduceMin, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
MS_REG_CPU_KERNEL_T(ReduceMin, KernelAttr(), ReduceCPUKernel, int32_t);
|
||||||
ReduceCPUKernel, double);
|
MS_REG_CPU_KERNEL_T(ReduceMin, KernelAttr(), ReduceCPUKernel, int64_t);
|
||||||
MS_REG_CPU_KERNEL_T(ReduceMin, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
|
||||||
ReduceCPUKernel, int32_t);
|
|
||||||
MS_REG_CPU_KERNEL_T(ReduceMin, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
|
||||||
ReduceCPUKernel, int64_t);
|
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL_T(ReduceAll, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
|
MS_REG_CPU_KERNEL_T(ReduceAll, KernelAttr(), ReduceCPUKernel, bool);
|
||||||
ReduceCPUKernel, bool);
|
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL_T(ReduceAny, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
|
MS_REG_CPU_KERNEL_T(ReduceAny, KernelAttr(), ReduceCPUKernel, bool);
|
||||||
ReduceCPUKernel, bool);
|
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_REDUCE_CPU_KERNEL_H_
|
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_REDUCE_CPU_KERNEL_H_
|
||||||
|
|
|
@ -57,24 +57,12 @@ class SplitCPUKernel : public CPUKernel {
|
||||||
TypeId dtype_{kTypeUnknown};
|
TypeId dtype_{kTypeUnknown};
|
||||||
};
|
};
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL_T(
|
MS_REG_CPU_KERNEL_T(Split, KernelAttr(), SplitCPUKernel, float);
|
||||||
Split, KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
MS_REG_CPU_KERNEL_T(Split, KernelAttr(), SplitCPUKernel, float16);
|
||||||
SplitCPUKernel, float);
|
MS_REG_CPU_KERNEL_T(Split, KernelAttr(), SplitCPUKernel, double);
|
||||||
MS_REG_CPU_KERNEL_T(
|
MS_REG_CPU_KERNEL_T(Split, KernelAttr(), SplitCPUKernel, int32_t);
|
||||||
Split, KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
MS_REG_CPU_KERNEL_T(Split, KernelAttr(), SplitCPUKernel, uint32_t);
|
||||||
SplitCPUKernel, float16);
|
MS_REG_CPU_KERNEL_T(Split, KernelAttr(), SplitCPUKernel, int64_t);
|
||||||
MS_REG_CPU_KERNEL_T(
|
|
||||||
Split, KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
|
||||||
SplitCPUKernel, double);
|
|
||||||
MS_REG_CPU_KERNEL_T(Split,
|
|
||||||
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
|
||||||
SplitCPUKernel, int32_t);
|
|
||||||
MS_REG_CPU_KERNEL_T(Split,
|
|
||||||
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
|
|
||||||
SplitCPUKernel, uint32_t);
|
|
||||||
MS_REG_CPU_KERNEL_T(Split,
|
|
||||||
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
|
||||||
SplitCPUKernel, int64_t);
|
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
|
|
|
@ -40,15 +40,9 @@ class TensorAddCPUKernel : public CPUKernel {
|
||||||
std::vector<size_t> output_shape_;
|
std::vector<size_t> output_shape_;
|
||||||
};
|
};
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL_T(
|
MS_REG_CPU_KERNEL_T(Add, KernelAttr(), TensorAddCPUKernel, int32_t);
|
||||||
Add, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
MS_REG_CPU_KERNEL_T(Add, KernelAttr(), TensorAddCPUKernel, uint32_t);
|
||||||
TensorAddCPUKernel, float);
|
MS_REG_CPU_KERNEL_T(Add, KernelAttr(), TensorAddCPUKernel, float);
|
||||||
MS_REG_CPU_KERNEL_T(
|
|
||||||
Add, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
|
||||||
TensorAddCPUKernel, int);
|
|
||||||
MS_REG_CPU_KERNEL_T(
|
|
||||||
Add, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
|
|
||||||
TensorAddCPUKernel, uint32_t);
|
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
|
|
|
@ -57,25 +57,7 @@ class TileCPUKernel : public CPUKernel {
|
||||||
size_t input_size_;
|
size_t input_size_;
|
||||||
};
|
};
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL(Tile, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8), TileCPUKernel);
|
MS_REG_CPU_KERNEL(Tile, KernelAttr(), TileCPUKernel);
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL(Tile, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16), TileCPUKernel);
|
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL(Tile, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), TileCPUKernel);
|
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL(Tile, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), TileCPUKernel);
|
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL(Tile, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), TileCPUKernel);
|
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL(Tile, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8), TileCPUKernel);
|
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL(Tile, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16), TileCPUKernel);
|
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL(Tile, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32), TileCPUKernel);
|
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL(Tile, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64), TileCPUKernel);
|
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL(Tile, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), TileCPUKernel);
|
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_TILE_CPU_KERNEL_H_
|
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_TILE_CPU_KERNEL_H_
|
||||||
|
|
|
@ -40,20 +40,7 @@ class TopKCPUKernel : public CPUKernel {
|
||||||
TypeId dtype_{kTypeUnknown};
|
TypeId dtype_{kTypeUnknown};
|
||||||
};
|
};
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL(TopK,
|
MS_REG_CPU_KERNEL(TopK, KernelAttr(), TopKCPUKernel)
|
||||||
KernelAttr()
|
|
||||||
.AddInputAttr(kNumberTypeFloat32)
|
|
||||||
.AddInputAttr(kNumberTypeInt32)
|
|
||||||
.AddOutputAttr(kNumberTypeFloat32)
|
|
||||||
.AddOutputAttr(kNumberTypeInt32),
|
|
||||||
TopKCPUKernel)
|
|
||||||
MS_REG_CPU_KERNEL(TopK,
|
|
||||||
KernelAttr()
|
|
||||||
.AddInputAttr(kNumberTypeFloat16)
|
|
||||||
.AddInputAttr(kNumberTypeInt32)
|
|
||||||
.AddOutputAttr(kNumberTypeFloat16)
|
|
||||||
.AddOutputAttr(kNumberTypeInt32),
|
|
||||||
TopKCPUKernel)
|
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_TOPK_CPU_KERNEL_H_
|
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_TOPK_CPU_KERNEL_H_
|
||||||
|
|
|
@ -52,36 +52,7 @@ class TransposeCPUFwdKernel : public CPUKernel {
|
||||||
std::unordered_map<TypeId, TypeKernel> launch_map_;
|
std::unordered_map<TypeId, TypeKernel> launch_map_;
|
||||||
TypeKernel launch_func_;
|
TypeKernel launch_func_;
|
||||||
};
|
};
|
||||||
MS_REG_CPU_KERNEL(Transpose,
|
MS_REG_CPU_KERNEL(Transpose, KernelAttr(), TransposeCPUFwdKernel);
|
||||||
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
|
||||||
TransposeCPUFwdKernel);
|
|
||||||
MS_REG_CPU_KERNEL(Transpose,
|
|
||||||
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
|
|
||||||
TransposeCPUFwdKernel);
|
|
||||||
MS_REG_CPU_KERNEL(Transpose,
|
|
||||||
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
|
|
||||||
TransposeCPUFwdKernel);
|
|
||||||
MS_REG_CPU_KERNEL(Transpose,
|
|
||||||
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
|
||||||
TransposeCPUFwdKernel);
|
|
||||||
MS_REG_CPU_KERNEL(Transpose,
|
|
||||||
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
|
||||||
TransposeCPUFwdKernel);
|
|
||||||
MS_REG_CPU_KERNEL(Transpose,
|
|
||||||
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
|
|
||||||
TransposeCPUFwdKernel);
|
|
||||||
MS_REG_CPU_KERNEL(Transpose,
|
|
||||||
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16),
|
|
||||||
TransposeCPUFwdKernel);
|
|
||||||
MS_REG_CPU_KERNEL(Transpose,
|
|
||||||
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
|
|
||||||
TransposeCPUFwdKernel);
|
|
||||||
MS_REG_CPU_KERNEL(Transpose,
|
|
||||||
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
|
|
||||||
TransposeCPUFwdKernel);
|
|
||||||
MS_REG_CPU_KERNEL(Transpose,
|
|
||||||
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
|
|
||||||
TransposeCPUFwdKernel);
|
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_TRANSPOSE_CPU_KERNEL_H_
|
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_TRANSPOSE_CPU_KERNEL_H_
|
||||||
|
|
|
@ -35,7 +35,8 @@ enum KernelType : int {
|
||||||
RT_KERNEL,
|
RT_KERNEL,
|
||||||
HCCL_KERNEL,
|
HCCL_KERNEL,
|
||||||
TBE_KERNEL,
|
TBE_KERNEL,
|
||||||
HOST_KERNEL
|
HOST_KERNEL,
|
||||||
|
CPU_KERNEL,
|
||||||
};
|
};
|
||||||
|
|
||||||
namespace kernel {
|
namespace kernel {
|
||||||
|
|
|
@ -25,7 +25,7 @@
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace kernel {
|
namespace kernel {
|
||||||
enum OpImplyType { kAKG = 0, kTBE = 1, kAICPU };
|
enum OpImplyType { kAKG = 0, kTBE = 1, kAICPU = 2, kCPU };
|
||||||
enum OpIOType { kInput = 0, kOutput };
|
enum OpIOType { kInput = 0, kOutput };
|
||||||
|
|
||||||
class OpAttr {
|
class OpAttr {
|
||||||
|
|
|
@ -50,6 +50,7 @@ constexpr auto kAiCore = "AiCore";
|
||||||
constexpr auto kCUDA = "CUDA";
|
constexpr auto kCUDA = "CUDA";
|
||||||
constexpr auto kTbe = "TBE";
|
constexpr auto kTbe = "TBE";
|
||||||
constexpr auto kAkg = "AKG";
|
constexpr auto kAkg = "AKG";
|
||||||
|
constexpr auto kCpu = "CPU";
|
||||||
constexpr auto kName = "name";
|
constexpr auto kName = "name";
|
||||||
constexpr auto kParamType = "param_type";
|
constexpr auto kParamType = "param_type";
|
||||||
constexpr auto kDtype = "dtype";
|
constexpr auto kDtype = "dtype";
|
||||||
|
@ -90,6 +91,9 @@ bool OpLib::RegOp(const std::string &json_string, const std::string &impl_path)
|
||||||
} else if (imply_type_string == kAiCPU) {
|
} else if (imply_type_string == kAiCPU) {
|
||||||
OpImplyType imply_type = kAICPU;
|
OpImplyType imply_type = kAICPU;
|
||||||
ret = DecodeOpInfo(op_json, imply_type, impl_path);
|
ret = DecodeOpInfo(op_json, imply_type, impl_path);
|
||||||
|
} else if (imply_type_string == kCpu) {
|
||||||
|
OpImplyType imply_type = kCPU;
|
||||||
|
ret = DecodeOpInfo(op_json, imply_type, impl_path);
|
||||||
} else {
|
} else {
|
||||||
MS_LOG(ERROR) << "Not support imply_type";
|
MS_LOG(ERROR) << "Not support imply_type";
|
||||||
}
|
}
|
||||||
|
|
|
@ -18,7 +18,11 @@
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
|
#include "backend/kernel_compiler/common_utils.h"
|
||||||
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
|
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
|
||||||
|
#include "backend/kernel_compiler/kernel_build_info.h"
|
||||||
|
#include "backend/kernel_compiler/oplib/opinfo.h"
|
||||||
|
#include "backend/kernel_compiler/oplib/oplib.h"
|
||||||
#include "utils/trace_base.h"
|
#include "utils/trace_base.h"
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace device {
|
namespace device {
|
||||||
|
@ -35,19 +39,17 @@ bool IsInputNotCNode(const CNodePtr &kernel_node, size_t input_index) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
void GetOutputInferFormatsAndDtypes(const CNodePtr &kernel_node, std::vector<std::string> *output_formats,
|
void GetOutputDtypes(const CNodePtr &kernel_node, std::vector<TypeId> *output_types) {
|
||||||
std::vector<TypeId> *output_types) {
|
|
||||||
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
|
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
|
||||||
for (size_t output_index = 0; output_index < output_num; ++output_index) {
|
for (size_t output_index = 0; output_index < output_num; ++output_index) {
|
||||||
TypeId dtype = kTypeUnknown;
|
TypeId dtype = kTypeUnknown;
|
||||||
dtype = AnfAlgo::GetOutputInferDataType(kernel_node, output_index);
|
dtype = AnfAlgo::GetOutputInferDataType(kernel_node, output_index);
|
||||||
output_formats->emplace_back(kOpFormat_DEFAULT);
|
|
||||||
output_types->emplace_back(dtype);
|
output_types->emplace_back(dtype);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void GetInputFormatsAndDtypes(const CNodePtr &kernel_node, std::vector<std::string> *input_formats,
|
void GetInputDtypes(const CNodePtr &kernel_node, std::vector<TypeId> *input_types,
|
||||||
std::vector<TypeId> *input_types, std::vector<size_t> *input_no_cnode_indexes) {
|
std::vector<size_t> *input_no_cnode_indexes) {
|
||||||
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||||
for (size_t input_index = 0; input_index < input_num; ++input_index) {
|
for (size_t input_index = 0; input_index < input_num; ++input_index) {
|
||||||
TypeId dtype = kTypeUnknown;
|
TypeId dtype = kTypeUnknown;
|
||||||
|
@ -57,7 +59,6 @@ void GetInputFormatsAndDtypes(const CNodePtr &kernel_node, std::vector<std::stri
|
||||||
} else {
|
} else {
|
||||||
dtype = AnfAlgo::GetPrevNodeOutputDeviceDataType(kernel_node, input_index);
|
dtype = AnfAlgo::GetPrevNodeOutputDeviceDataType(kernel_node, input_index);
|
||||||
}
|
}
|
||||||
input_formats->emplace_back(kOpFormat_DEFAULT);
|
|
||||||
input_types->emplace_back(dtype);
|
input_types->emplace_back(dtype);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -86,16 +87,13 @@ bool InputDtypeMatch(TypeId InputAttr, TypeId input_type, bool strict) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::pair<int, int> GetOutputDtypeFormatMatchedNum(const KernelAttr &kernel_attr,
|
int GetOutputDtypeMatchedNum(const KernelAttr &kernel_attr, const std::vector<TypeId> &output_types) {
|
||||||
const std::vector<std::string> &output_formats,
|
|
||||||
const std::vector<TypeId> &output_types) {
|
|
||||||
if (kernel_attr.GetOutputSize() != output_types.size()) {
|
if (kernel_attr.GetOutputSize() != output_types.size()) {
|
||||||
MS_LOG(DEBUG) << "required output num:" << kernel_attr.GetInputSize()
|
MS_LOG(DEBUG) << "required output num:" << kernel_attr.GetInputSize()
|
||||||
<< ", actual output num:" << output_types.size();
|
<< ", actual output num:" << output_types.size();
|
||||||
return std::make_pair(0, 0);
|
return 0;
|
||||||
}
|
}
|
||||||
int data_type_matched_num = 0;
|
int data_type_matched_num = 0;
|
||||||
int format_matched_num = 0;
|
|
||||||
auto output_num = output_types.size();
|
auto output_num = output_types.size();
|
||||||
for (size_t i = 0; i < output_num; ++i) {
|
for (size_t i = 0; i < output_num; ++i) {
|
||||||
if (kernel_attr.GetOutputAttr(i).first != output_types[i]) {
|
if (kernel_attr.GetOutputAttr(i).first != output_types[i]) {
|
||||||
|
@ -104,27 +102,17 @@ std::pair<int, int> GetOutputDtypeFormatMatchedNum(const KernelAttr &kernel_attr
|
||||||
} else {
|
} else {
|
||||||
data_type_matched_num++;
|
data_type_matched_num++;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (kernel_attr.GetOutputAttr(i).second != output_formats[i]) {
|
|
||||||
MS_LOG(DEBUG) << "required format:" << kernel_attr.GetOutputAttr(i).second
|
|
||||||
<< ", actual output format:" << output_formats[i];
|
|
||||||
} else {
|
|
||||||
format_matched_num++;
|
|
||||||
}
|
}
|
||||||
}
|
return data_type_matched_num;
|
||||||
return std::make_pair(data_type_matched_num, format_matched_num);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
std::pair<int, int> GetInputDtypeFormatMatchedNum(const KernelAttr &kernel_attr,
|
int GetInputDtypeFormatMatchedNum(const KernelAttr &kernel_attr, const std::vector<TypeId> &input_types,
|
||||||
const std::vector<std::string> &input_formats,
|
|
||||||
const std::vector<TypeId> &input_types,
|
|
||||||
const std::vector<size_t> &input_not_cnode_indexes, bool strict) {
|
const std::vector<size_t> &input_not_cnode_indexes, bool strict) {
|
||||||
if (kernel_attr.GetInputSize() != input_types.size()) {
|
if (kernel_attr.GetInputSize() != input_types.size()) {
|
||||||
MS_LOG(DEBUG) << "required input num:" << kernel_attr.GetInputSize() << ", actual input num:" << input_types.size();
|
MS_LOG(DEBUG) << "required input num:" << kernel_attr.GetInputSize() << ", actual input num:" << input_types.size();
|
||||||
return std::make_pair(0, 0);
|
return 0;
|
||||||
}
|
}
|
||||||
int data_type_matched_num = 0;
|
int data_type_matched_num = 0;
|
||||||
int format_matched_num = 0;
|
|
||||||
auto input_num = input_types.size();
|
auto input_num = input_types.size();
|
||||||
for (size_t i = 0; i < input_num; ++i) {
|
for (size_t i = 0; i < input_num; ++i) {
|
||||||
if (!InputDtypeMatch(kernel_attr.GetInputAttr(i).first, input_types[i], strict)) {
|
if (!InputDtypeMatch(kernel_attr.GetInputAttr(i).first, input_types[i], strict)) {
|
||||||
|
@ -132,10 +120,9 @@ std::pair<int, int> GetInputDtypeFormatMatchedNum(const KernelAttr &kernel_attr,
|
||||||
<< ", actual input dtype:" << input_types[i];
|
<< ", actual input dtype:" << input_types[i];
|
||||||
} else {
|
} else {
|
||||||
data_type_matched_num++;
|
data_type_matched_num++;
|
||||||
format_matched_num++;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return std::make_pair(data_type_matched_num, format_matched_num);
|
return data_type_matched_num;
|
||||||
}
|
}
|
||||||
|
|
||||||
void ExpandKernelAttr(const CNodePtr &kernel_node, KernelAttr *kernel_attr) {
|
void ExpandKernelAttr(const CNodePtr &kernel_node, KernelAttr *kernel_attr) {
|
||||||
|
@ -197,12 +184,9 @@ void KernelNotSupportException(const AnfNodePtr &kernel_node, const std::vector<
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
bool SelectKernel(const CNodePtr &kernel_node, KernelAttr *selected_kernel_attr,
|
bool SelectKernel(const CNodePtr &kernel_node, KernelAttr *selected_kernel_attr,
|
||||||
const std::vector<KernelAttr> &kernel_attrs, const std::vector<std::string> &input_formats,
|
const std::vector<KernelAttr> &kernel_attrs, const std::vector<TypeId> &input_types,
|
||||||
const std::vector<TypeId> &input_types, const std::vector<size_t> &input_not_cnode_indexes,
|
const std::vector<size_t> &input_not_cnode_indexes, const std::vector<TypeId> &output_types,
|
||||||
const std::vector<std::string> &infer_output_formats, const std::vector<TypeId> &infer_output_types,
|
|
||||||
std::pair<bool, bool> *matched, bool strict) {
|
std::pair<bool, bool> *matched, bool strict) {
|
||||||
int max_type_matched_num = -1;
|
|
||||||
int max_format_matched_num = -1;
|
|
||||||
for (auto kernel_attr : kernel_attrs) {
|
for (auto kernel_attr : kernel_attrs) {
|
||||||
if (kernel_attr.GetAllSame()) {
|
if (kernel_attr.GetAllSame()) {
|
||||||
ExpandKernelAttr(kernel_node, &kernel_attr);
|
ExpandKernelAttr(kernel_node, &kernel_attr);
|
||||||
|
@ -212,32 +196,14 @@ bool SelectKernel(const CNodePtr &kernel_node, KernelAttr *selected_kernel_attr,
|
||||||
MS_LOG(DEBUG) << "Output num is not equal!";
|
MS_LOG(DEBUG) << "Output num is not equal!";
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
std::pair<int, int> input_type_format_matched_num =
|
int input_dtype_matched_num =
|
||||||
GetInputDtypeFormatMatchedNum(kernel_attr, input_formats, input_types, input_not_cnode_indexes, strict);
|
GetInputDtypeFormatMatchedNum(kernel_attr, input_types, input_not_cnode_indexes, strict);
|
||||||
std::pair<int, int> output_type_format_matched_num =
|
int output_dtype_matched_num = GetOutputDtypeMatchedNum(kernel_attr, output_types);
|
||||||
GetOutputDtypeFormatMatchedNum(kernel_attr, infer_output_formats, infer_output_types);
|
|
||||||
// Data type first
|
|
||||||
if (input_type_format_matched_num.first > max_type_matched_num) {
|
|
||||||
max_type_matched_num = input_type_format_matched_num.first;
|
|
||||||
max_format_matched_num = input_type_format_matched_num.second;
|
|
||||||
*selected_kernel_attr = kernel_attr;
|
|
||||||
} else if (input_type_format_matched_num.first == max_type_matched_num &&
|
|
||||||
input_type_format_matched_num.second > max_format_matched_num) {
|
|
||||||
max_format_matched_num = input_type_format_matched_num.second;
|
|
||||||
*selected_kernel_attr = kernel_attr;
|
|
||||||
} else if (input_type_format_matched_num.first == max_type_matched_num &&
|
|
||||||
input_type_format_matched_num.second == max_format_matched_num) {
|
|
||||||
if (output_type_format_matched_num.first == SizeToInt(infer_output_types.size()) &&
|
|
||||||
output_type_format_matched_num.second == SizeToInt(infer_output_types.size())) {
|
|
||||||
*selected_kernel_attr = kernel_attr;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// All formats and data types matched
|
// All formats and data types matched
|
||||||
if (input_type_format_matched_num.first == SizeToInt(input_types.size()) &&
|
if (input_dtype_matched_num == SizeToInt(input_types.size())) {
|
||||||
input_type_format_matched_num.second == SizeToInt(input_types.size())) {
|
*selected_kernel_attr = kernel_attr;
|
||||||
matched->first = true;
|
matched->first = true;
|
||||||
if (output_type_format_matched_num.first == SizeToInt(infer_output_types.size()) &&
|
if (output_dtype_matched_num == SizeToInt(output_types.size())) {
|
||||||
output_type_format_matched_num.second == SizeToInt(infer_output_types.size())) {
|
|
||||||
matched->second = true;
|
matched->second = true;
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
@ -249,43 +215,50 @@ void SetKernelInfo(const CNodePtr &kernel_node) {
|
||||||
std::vector<std::string> input_formats;
|
std::vector<std::string> input_formats;
|
||||||
std::vector<TypeId> input_types;
|
std::vector<TypeId> input_types;
|
||||||
std::vector<size_t> input_not_cnode_indexes;
|
std::vector<size_t> input_not_cnode_indexes;
|
||||||
std::vector<std::string> output_formats;
|
std::vector<std::string> selected_output_formats;
|
||||||
std::vector<TypeId> output_types;
|
std::vector<TypeId> output_types;
|
||||||
std::vector<std::string> infer_output_formats;
|
std::vector<TypeId> selected_output_types;
|
||||||
std::vector<TypeId> infer_output_types;
|
|
||||||
MS_LOG(INFO) << "SetKernelInfo, CNode Name: " << AnfAlgo::GetCNodeName(kernel_node);
|
MS_LOG(INFO) << "SetKernelInfo, CNode Name: " << AnfAlgo::GetCNodeName(kernel_node);
|
||||||
auto kernel_attrs =
|
auto kernel_attrs =
|
||||||
kernel::CPUKernelFactory::GetInstance().GetSupportedKernelAttrList(AnfAlgo::GetCNodeName(kernel_node));
|
kernel::CPUKernelFactory::GetInstance().GetSupportedKernelAttrList(AnfAlgo::GetCNodeName(kernel_node));
|
||||||
if (kernel_attrs.empty()) {
|
if (kernel_attrs.empty() || (kernel_attrs[0].GetInputSize() == 0 && kernel_attrs[0].GetOutputSize() == 0)) {
|
||||||
MS_LOG(EXCEPTION) << "Operator[" << AnfAlgo::GetCNodeName(kernel_node)
|
MS_LOG(DEBUG) << "Operator[" << AnfAlgo::GetCNodeName(kernel_node) << "] will get ops attr info.";
|
||||||
<< "] is not support. Trace: " << trace::DumpSourceLines(kernel_node);
|
std::string op_name = AnfAlgo::GetCNodeName(kernel_node);
|
||||||
|
auto op_info_ptr = mindspore::kernel::OpLib::FindOp(op_name, kernel::OpImplyType::kCPU);
|
||||||
|
if (op_info_ptr == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Not find op[" << op_name << "] in cpu";
|
||||||
}
|
}
|
||||||
GetInputFormatsAndDtypes(kernel_node, &input_formats, &input_types, &input_not_cnode_indexes);
|
kernel_attrs.clear();
|
||||||
GetOutputInferFormatsAndDtypes(kernel_node, &infer_output_formats, &infer_output_types);
|
kernel::CPUKernelFactory::GetInstance().SetKernelAttrs(op_info_ptr, &kernel_attrs);
|
||||||
|
kernel::CPUKernelFactory::GetInstance().UpdateKernelAttrs(op_name, kernel_attrs);
|
||||||
|
}
|
||||||
|
GetInputDtypes(kernel_node, &input_types, &input_not_cnode_indexes);
|
||||||
|
GetOutputDtypes(kernel_node, &output_types);
|
||||||
KernelAttr selected_kernel_attr;
|
KernelAttr selected_kernel_attr;
|
||||||
std::pair<bool, bool> matched = std::make_pair(false, false);
|
std::pair<bool, bool> matched = std::make_pair(false, false);
|
||||||
if (!SelectKernel(kernel_node, &selected_kernel_attr, kernel_attrs, input_formats, input_types,
|
if (!SelectKernel(kernel_node, &selected_kernel_attr, kernel_attrs, input_types, input_not_cnode_indexes,
|
||||||
input_not_cnode_indexes, infer_output_formats, infer_output_types, &matched, true)) {
|
output_types, &matched, true)) {
|
||||||
if (AnfAlgo::GetCNodeName(kernel_node) == "Cast") {
|
if (AnfAlgo::GetCNodeName(kernel_node) == "Cast") {
|
||||||
KernelNotSupportException(kernel_node, input_types, infer_output_types);
|
KernelNotSupportException(kernel_node, input_types, output_types);
|
||||||
}
|
}
|
||||||
matched = std::make_pair(false, false);
|
matched = std::make_pair(false, false);
|
||||||
SelectKernel(kernel_node, &selected_kernel_attr, kernel_attrs, input_formats, input_types, input_not_cnode_indexes,
|
SelectKernel(kernel_node, &selected_kernel_attr, kernel_attrs, input_types, input_not_cnode_indexes, output_types,
|
||||||
infer_output_formats, infer_output_types, &matched, false);
|
&matched, false);
|
||||||
if (!matched.first) {
|
if (!matched.first) {
|
||||||
KernelNotSupportException(kernel_node, input_types, infer_output_types);
|
KernelNotSupportException(kernel_node, input_types, output_types);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (selected_kernel_attr.GetInputSize() > 0 &&
|
if (selected_kernel_attr.GetInputSize() > 0 &&
|
||||||
(matched.first || input_types.size() == input_not_cnode_indexes.size())) {
|
(matched.first || input_types.size() == input_not_cnode_indexes.size())) {
|
||||||
MS_LOG(INFO) << "Input format and dtype is matched";
|
MS_LOG(INFO) << "Input format and dtype is matched";
|
||||||
GetOutputFormatsAndDtypes(kernel_node, selected_kernel_attr, &output_formats, &output_types);
|
GetOutputFormatsAndDtypes(kernel_node, selected_kernel_attr, &selected_output_formats, &selected_output_types);
|
||||||
for (size_t i = 0; i < selected_kernel_attr.GetInputSize(); ++i) {
|
for (size_t index = 0; index < selected_kernel_attr.GetInputSize(); index++) {
|
||||||
input_types[SizeToInt(i)] = selected_kernel_attr.GetInputAttr(i).first;
|
input_types[index] = selected_kernel_attr.GetInputAttr(index).first;
|
||||||
|
input_formats.emplace_back(selected_kernel_attr.GetInputAttr(index).second);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
SetKernelBuildInfo(input_formats, input_types, output_formats, output_types, kernel_node.get());
|
SetKernelBuildInfo(input_formats, input_types, selected_output_formats, selected_output_types, kernel_node.get());
|
||||||
}
|
}
|
||||||
} // namespace cpu
|
} // namespace cpu
|
||||||
} // namespace device
|
} // namespace device
|
||||||
|
|
|
@ -23,7 +23,7 @@ Examples:
|
||||||
|
|
||||||
from .primitive import Primitive, PrimitiveWithInfer, PrimitiveWithCheck, prim_attr_register
|
from .primitive import Primitive, PrimitiveWithInfer, PrimitiveWithCheck, prim_attr_register
|
||||||
from .vm_impl_registry import get_vm_impl_fn, vm_impl_registry
|
from .vm_impl_registry import get_vm_impl_fn, vm_impl_registry
|
||||||
from .op_info_register import op_info_register, AkgGpuRegOp, AkgAscendRegOp, AiCPURegOp, TBERegOp, DataType
|
from .op_info_register import op_info_register, AkgGpuRegOp, AkgAscendRegOp, AiCPURegOp, TBERegOp, CpuRegOp, DataType
|
||||||
from .primitive import constexpr
|
from .primitive import constexpr
|
||||||
from . import composite, operations, functional
|
from . import composite, operations, functional
|
||||||
from . import signature
|
from . import signature
|
||||||
|
@ -36,7 +36,7 @@ __primitive__ = [
|
||||||
]
|
]
|
||||||
|
|
||||||
__all__ = ["get_vm_impl_fn", "vm_impl_registry",
|
__all__ = ["get_vm_impl_fn", "vm_impl_registry",
|
||||||
"op_info_register", "AkgGpuRegOp", "AkgAscendRegOp", "AiCPURegOp", "TBERegOp", "DataType",
|
"op_info_register", "AkgGpuRegOp", "AkgAscendRegOp", "AiCPURegOp", "TBERegOp", "CpuRegOp", "DataType",
|
||||||
"constexpr"]
|
"constexpr"]
|
||||||
__all__.extend(__primitive__)
|
__all__.extend(__primitive__)
|
||||||
__all__.extend(composite.__all__)
|
__all__.extend(composite.__all__)
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
import platform
|
import platform
|
||||||
from .aicpu import *
|
from .aicpu import *
|
||||||
|
from .cpu import *
|
||||||
if "Windows" not in platform.system():
|
if "Windows" not in platform.system():
|
||||||
from .akg import *
|
from .akg import *
|
||||||
from .tbe import *
|
from .tbe import *
|
||||||
|
|
|
@ -0,0 +1,63 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""cpu ops"""
|
||||||
|
from .cast import _cast_cpu
|
||||||
|
from .mul import _mul_cpu
|
||||||
|
from .sub import _sub_cpu
|
||||||
|
from .pow import _pow_cpu
|
||||||
|
from .real_div import _real_div_cpu
|
||||||
|
from .div import _div_cpu
|
||||||
|
from .concat import _concat_cpu
|
||||||
|
from .split import _split_cpu
|
||||||
|
from .adam import _adam_cpu
|
||||||
|
from .arg_max import _arg_max_cpu
|
||||||
|
from .arg_min_with_value import _arg_min_with_value_cpu
|
||||||
|
from .bias_add import _bias_add_cpu
|
||||||
|
from .bias_add_grad import _bias_add_grad_cpu
|
||||||
|
from .dropout import _dropout_cpu
|
||||||
|
from .dropout_grad import _dropout_grad_cpu
|
||||||
|
from .gather_d import _gather_cpu
|
||||||
|
from .gather_d_grad import _gather_d_grad_cpu
|
||||||
|
from .gather_v2 import _gather_v2_cpu
|
||||||
|
from .gather_nd import _gather_nd_cpu
|
||||||
|
from .maximum import _maximum_cpu
|
||||||
|
from .maximum_grad import _maximum_grad_cpu
|
||||||
|
from .conv2d import _conv2d_cpu
|
||||||
|
from .hsigmoid import _hsigmoid_cpu
|
||||||
|
from .hsigmoid_grad import _hsigmoid_grad_cpu
|
||||||
|
from .hswish import _hswish_cpu
|
||||||
|
from .hswish_grad import _hswish_grad_cpu
|
||||||
|
from .is_finite import _is_finite_cpu
|
||||||
|
from .layer_norm import _layer_norm_cpu
|
||||||
|
from .layer_norm_grad import _layer_norm_grad_cpu
|
||||||
|
from .minimum import _minimum_cpu
|
||||||
|
from .minimum_grad import _minimum_grad_cpu
|
||||||
|
from .equal_count import _equal_count_cpu
|
||||||
|
from .mirror_pad import _mirror_pad_cpu
|
||||||
|
from .mirror_pad_grad import _mirror_pad_grad_cpu
|
||||||
|
from .stack import _stack_cpu
|
||||||
|
from .reduce_mean import _reduce_mean_cpu
|
||||||
|
from .reduce_max import _reduce_max_cpu
|
||||||
|
from .reduce_sum import _reduce_sum_cpu
|
||||||
|
from .reduce_min import _reduce_min_cpu
|
||||||
|
from .reduce_all import _reduce_all_cpu
|
||||||
|
from .reduce_any import _reduce_any_cpu
|
||||||
|
from .transpose import _transpose_cpu
|
||||||
|
from .tile import _tile_cpu
|
||||||
|
from .top_k import _top_k_cpu
|
||||||
|
from .add import _add_cpu
|
||||||
|
from .one_hot import _one_hot_cpu
|
||||||
|
from .pad import _pad_cpu
|
||||||
|
from .range import _range_cpu
|
|
@ -0,0 +1,44 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""Concat op"""
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
|
||||||
|
|
||||||
|
adam_op_info = CpuRegOp("Adam") \
|
||||||
|
.input(0, "var", "required") \
|
||||||
|
.input(1, "m", "required") \
|
||||||
|
.input(2, "v", "required") \
|
||||||
|
.input(3, "beta1_power", "required") \
|
||||||
|
.input(4, "beta2_power", "required") \
|
||||||
|
.input(5, "lr", "required") \
|
||||||
|
.input(6, "beta1", "required") \
|
||||||
|
.input(7, "beta2", "required") \
|
||||||
|
.input(8, "epsilon", "required") \
|
||||||
|
.input(9, "gradient", "required") \
|
||||||
|
.output(0, "output0", "required") \
|
||||||
|
.output(1, "output0", "required") \
|
||||||
|
.output(2, "output0", "required") \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
|
||||||
|
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
|
||||||
|
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
|
||||||
|
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
|
||||||
|
DataType.F32_Default) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
|
||||||
|
@op_info_register(adam_op_info)
|
||||||
|
def _adam_cpu():
|
||||||
|
"""Adam cpu register"""
|
||||||
|
return
|
|
@ -0,0 +1,32 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""Add op"""
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
|
||||||
|
|
||||||
|
add_op_info = CpuRegOp("Add") \
|
||||||
|
.input(0, "x1", "required") \
|
||||||
|
.input(1, "x2", "required") \
|
||||||
|
.output(0, "y", "required") \
|
||||||
|
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
|
||||||
|
.dtype_format(DataType.U32_Default, DataType.U32_Default, DataType.U32_Default) \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
|
||||||
|
@op_info_register(add_op_info)
|
||||||
|
def _add_cpu():
|
||||||
|
"""Add cpu register"""
|
||||||
|
return
|
|
@ -0,0 +1,30 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""Argmax op"""
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
|
||||||
|
|
||||||
|
arg_max_op_info = CpuRegOp("Argmax") \
|
||||||
|
.input(0, "x", "required") \
|
||||||
|
.output(0, "y", "required") \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.I32_Default) \
|
||||||
|
.dtype_format(DataType.F16_Default, DataType.I32_Default) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
|
||||||
|
@op_info_register(arg_max_op_info)
|
||||||
|
def _arg_max_cpu():
|
||||||
|
"""Argmax cpu register"""
|
||||||
|
return
|
|
@ -0,0 +1,31 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""ArgMinWithValue op"""
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
|
||||||
|
|
||||||
|
arg_min_with_value_op_info = CpuRegOp("ArgMinWithValue") \
|
||||||
|
.input(0, "x", "required") \
|
||||||
|
.output(0, "indice", "required") \
|
||||||
|
.output(1, "values", "required") \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default) \
|
||||||
|
.dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
|
||||||
|
@op_info_register(arg_min_with_value_op_info)
|
||||||
|
def _arg_min_with_value_cpu():
|
||||||
|
"""ArgMinWithValue cpu register"""
|
||||||
|
return
|
|
@ -0,0 +1,30 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""BiasAdd op"""
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
|
||||||
|
|
||||||
|
bias_add_op_info = CpuRegOp("BiasAdd") \
|
||||||
|
.input(0, "x", "required") \
|
||||||
|
.input(1, "bias", "required") \
|
||||||
|
.output(0, "y", "required") \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
|
||||||
|
@op_info_register(bias_add_op_info)
|
||||||
|
def _bias_add_cpu():
|
||||||
|
"""BiasAdd cpu register"""
|
||||||
|
return
|
|
@ -0,0 +1,29 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""BiasAddGrad op"""
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
|
||||||
|
|
||||||
|
bias_add_grad_op_info = CpuRegOp("BiasAddGrad") \
|
||||||
|
.input(0, "x", "required") \
|
||||||
|
.output(0, "y", "required") \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
|
||||||
|
@op_info_register(bias_add_grad_op_info)
|
||||||
|
def _bias_add_grad_cpu():
|
||||||
|
"""BiasAddGrad cpu register"""
|
||||||
|
return
|
|
@ -0,0 +1,171 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""Cast op"""
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
|
||||||
|
|
||||||
|
cast_op_info = CpuRegOp("Cast") \
|
||||||
|
.input(0, "x", "required") \
|
||||||
|
.output(0, "y", "required") \
|
||||||
|
.dtype_format(DataType.U8_Default, DataType.U8_Default) \
|
||||||
|
.dtype_format(DataType.U8_Default, DataType.U16_Default) \
|
||||||
|
.dtype_format(DataType.U8_Default, DataType.U32_Default) \
|
||||||
|
.dtype_format(DataType.U8_Default, DataType.U64_Default) \
|
||||||
|
.dtype_format(DataType.U8_Default, DataType.I8_Default) \
|
||||||
|
.dtype_format(DataType.U8_Default, DataType.I16_Default) \
|
||||||
|
.dtype_format(DataType.U8_Default, DataType.I32_Default) \
|
||||||
|
.dtype_format(DataType.U8_Default, DataType.I64_Default) \
|
||||||
|
.dtype_format(DataType.U8_Default, DataType.F16_Default) \
|
||||||
|
.dtype_format(DataType.U8_Default, DataType.F32_Default) \
|
||||||
|
.dtype_format(DataType.U8_Default, DataType.F64_Default) \
|
||||||
|
.dtype_format(DataType.U8_Default, DataType.BOOL_Default) \
|
||||||
|
.dtype_format(DataType.U16_Default, DataType.U8_Default) \
|
||||||
|
.dtype_format(DataType.U16_Default, DataType.U16_Default) \
|
||||||
|
.dtype_format(DataType.U16_Default, DataType.U32_Default) \
|
||||||
|
.dtype_format(DataType.U16_Default, DataType.U64_Default) \
|
||||||
|
.dtype_format(DataType.U16_Default, DataType.I8_Default) \
|
||||||
|
.dtype_format(DataType.U16_Default, DataType.I16_Default) \
|
||||||
|
.dtype_format(DataType.U16_Default, DataType.I32_Default) \
|
||||||
|
.dtype_format(DataType.U16_Default, DataType.I64_Default) \
|
||||||
|
.dtype_format(DataType.U16_Default, DataType.F16_Default) \
|
||||||
|
.dtype_format(DataType.U16_Default, DataType.F32_Default) \
|
||||||
|
.dtype_format(DataType.U16_Default, DataType.F64_Default) \
|
||||||
|
.dtype_format(DataType.U16_Default, DataType.BOOL_Default) \
|
||||||
|
.dtype_format(DataType.U32_Default, DataType.U8_Default) \
|
||||||
|
.dtype_format(DataType.U32_Default, DataType.U16_Default) \
|
||||||
|
.dtype_format(DataType.U32_Default, DataType.U32_Default) \
|
||||||
|
.dtype_format(DataType.U32_Default, DataType.U64_Default) \
|
||||||
|
.dtype_format(DataType.U32_Default, DataType.I8_Default) \
|
||||||
|
.dtype_format(DataType.U32_Default, DataType.I16_Default) \
|
||||||
|
.dtype_format(DataType.U32_Default, DataType.I32_Default) \
|
||||||
|
.dtype_format(DataType.U32_Default, DataType.I64_Default) \
|
||||||
|
.dtype_format(DataType.U32_Default, DataType.F16_Default) \
|
||||||
|
.dtype_format(DataType.U32_Default, DataType.F32_Default) \
|
||||||
|
.dtype_format(DataType.U32_Default, DataType.F64_Default) \
|
||||||
|
.dtype_format(DataType.U32_Default, DataType.BOOL_Default) \
|
||||||
|
.dtype_format(DataType.U64_Default, DataType.U8_Default) \
|
||||||
|
.dtype_format(DataType.U64_Default, DataType.U16_Default) \
|
||||||
|
.dtype_format(DataType.U64_Default, DataType.U32_Default) \
|
||||||
|
.dtype_format(DataType.U64_Default, DataType.U64_Default) \
|
||||||
|
.dtype_format(DataType.U64_Default, DataType.I8_Default) \
|
||||||
|
.dtype_format(DataType.U64_Default, DataType.I16_Default) \
|
||||||
|
.dtype_format(DataType.U64_Default, DataType.I32_Default) \
|
||||||
|
.dtype_format(DataType.U64_Default, DataType.I64_Default) \
|
||||||
|
.dtype_format(DataType.U64_Default, DataType.F16_Default) \
|
||||||
|
.dtype_format(DataType.U64_Default, DataType.F32_Default) \
|
||||||
|
.dtype_format(DataType.U64_Default, DataType.F64_Default) \
|
||||||
|
.dtype_format(DataType.U64_Default, DataType.BOOL_Default) \
|
||||||
|
.dtype_format(DataType.I8_Default, DataType.U8_Default) \
|
||||||
|
.dtype_format(DataType.I8_Default, DataType.U16_Default) \
|
||||||
|
.dtype_format(DataType.I8_Default, DataType.U32_Default) \
|
||||||
|
.dtype_format(DataType.I8_Default, DataType.U64_Default) \
|
||||||
|
.dtype_format(DataType.I8_Default, DataType.I8_Default) \
|
||||||
|
.dtype_format(DataType.I8_Default, DataType.I16_Default) \
|
||||||
|
.dtype_format(DataType.I8_Default, DataType.I32_Default) \
|
||||||
|
.dtype_format(DataType.I8_Default, DataType.I64_Default) \
|
||||||
|
.dtype_format(DataType.I8_Default, DataType.F16_Default) \
|
||||||
|
.dtype_format(DataType.I8_Default, DataType.F32_Default) \
|
||||||
|
.dtype_format(DataType.I8_Default, DataType.F64_Default) \
|
||||||
|
.dtype_format(DataType.I8_Default, DataType.BOOL_Default) \
|
||||||
|
.dtype_format(DataType.I16_Default, DataType.U8_Default) \
|
||||||
|
.dtype_format(DataType.I16_Default, DataType.U16_Default) \
|
||||||
|
.dtype_format(DataType.I16_Default, DataType.U32_Default) \
|
||||||
|
.dtype_format(DataType.I16_Default, DataType.U64_Default) \
|
||||||
|
.dtype_format(DataType.I16_Default, DataType.I8_Default) \
|
||||||
|
.dtype_format(DataType.I16_Default, DataType.I16_Default) \
|
||||||
|
.dtype_format(DataType.I16_Default, DataType.I32_Default) \
|
||||||
|
.dtype_format(DataType.I16_Default, DataType.I64_Default) \
|
||||||
|
.dtype_format(DataType.I16_Default, DataType.F16_Default) \
|
||||||
|
.dtype_format(DataType.I16_Default, DataType.F32_Default) \
|
||||||
|
.dtype_format(DataType.I16_Default, DataType.F64_Default) \
|
||||||
|
.dtype_format(DataType.I16_Default, DataType.BOOL_Default) \
|
||||||
|
.dtype_format(DataType.I32_Default, DataType.U8_Default) \
|
||||||
|
.dtype_format(DataType.I32_Default, DataType.U16_Default) \
|
||||||
|
.dtype_format(DataType.I32_Default, DataType.U32_Default) \
|
||||||
|
.dtype_format(DataType.I32_Default, DataType.U64_Default) \
|
||||||
|
.dtype_format(DataType.I32_Default, DataType.I8_Default) \
|
||||||
|
.dtype_format(DataType.I32_Default, DataType.I16_Default) \
|
||||||
|
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
|
||||||
|
.dtype_format(DataType.I32_Default, DataType.I64_Default) \
|
||||||
|
.dtype_format(DataType.I32_Default, DataType.F16_Default) \
|
||||||
|
.dtype_format(DataType.I32_Default, DataType.F32_Default) \
|
||||||
|
.dtype_format(DataType.I32_Default, DataType.F64_Default) \
|
||||||
|
.dtype_format(DataType.I32_Default, DataType.BOOL_Default) \
|
||||||
|
.dtype_format(DataType.I64_Default, DataType.U8_Default) \
|
||||||
|
.dtype_format(DataType.I64_Default, DataType.U16_Default) \
|
||||||
|
.dtype_format(DataType.I64_Default, DataType.U32_Default) \
|
||||||
|
.dtype_format(DataType.I64_Default, DataType.U64_Default) \
|
||||||
|
.dtype_format(DataType.I64_Default, DataType.I8_Default) \
|
||||||
|
.dtype_format(DataType.I64_Default, DataType.I16_Default) \
|
||||||
|
.dtype_format(DataType.I64_Default, DataType.I32_Default) \
|
||||||
|
.dtype_format(DataType.I64_Default, DataType.I64_Default) \
|
||||||
|
.dtype_format(DataType.I64_Default, DataType.F16_Default) \
|
||||||
|
.dtype_format(DataType.I64_Default, DataType.F32_Default) \
|
||||||
|
.dtype_format(DataType.I64_Default, DataType.F64_Default) \
|
||||||
|
.dtype_format(DataType.I64_Default, DataType.BOOL_Default) \
|
||||||
|
.dtype_format(DataType.F16_Default, DataType.U8_Default) \
|
||||||
|
.dtype_format(DataType.F16_Default, DataType.U16_Default) \
|
||||||
|
.dtype_format(DataType.F16_Default, DataType.U32_Default) \
|
||||||
|
.dtype_format(DataType.F16_Default, DataType.U64_Default) \
|
||||||
|
.dtype_format(DataType.F16_Default, DataType.I8_Default) \
|
||||||
|
.dtype_format(DataType.F16_Default, DataType.I16_Default) \
|
||||||
|
.dtype_format(DataType.F16_Default, DataType.I32_Default) \
|
||||||
|
.dtype_format(DataType.F16_Default, DataType.I64_Default) \
|
||||||
|
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
|
||||||
|
.dtype_format(DataType.F16_Default, DataType.F32_Default) \
|
||||||
|
.dtype_format(DataType.F16_Default, DataType.F64_Default) \
|
||||||
|
.dtype_format(DataType.F16_Default, DataType.BOOL_Default) \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.U8_Default) \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.U16_Default) \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.U32_Default) \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.U64_Default) \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.I8_Default) \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.I16_Default) \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.I32_Default) \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.I64_Default) \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.F16_Default) \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.F64_Default) \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.BOOL_Default) \
|
||||||
|
.dtype_format(DataType.F64_Default, DataType.U8_Default) \
|
||||||
|
.dtype_format(DataType.F64_Default, DataType.U16_Default) \
|
||||||
|
.dtype_format(DataType.F64_Default, DataType.U32_Default) \
|
||||||
|
.dtype_format(DataType.F64_Default, DataType.U64_Default) \
|
||||||
|
.dtype_format(DataType.F64_Default, DataType.I8_Default) \
|
||||||
|
.dtype_format(DataType.F64_Default, DataType.I16_Default) \
|
||||||
|
.dtype_format(DataType.F64_Default, DataType.I32_Default) \
|
||||||
|
.dtype_format(DataType.F64_Default, DataType.I64_Default) \
|
||||||
|
.dtype_format(DataType.F64_Default, DataType.F16_Default) \
|
||||||
|
.dtype_format(DataType.F64_Default, DataType.F32_Default) \
|
||||||
|
.dtype_format(DataType.F64_Default, DataType.F64_Default) \
|
||||||
|
.dtype_format(DataType.F64_Default, DataType.BOOL_Default) \
|
||||||
|
.dtype_format(DataType.BOOL_Default, DataType.U8_Default) \
|
||||||
|
.dtype_format(DataType.BOOL_Default, DataType.U16_Default) \
|
||||||
|
.dtype_format(DataType.BOOL_Default, DataType.U32_Default) \
|
||||||
|
.dtype_format(DataType.BOOL_Default, DataType.U64_Default) \
|
||||||
|
.dtype_format(DataType.BOOL_Default, DataType.I8_Default) \
|
||||||
|
.dtype_format(DataType.BOOL_Default, DataType.I16_Default) \
|
||||||
|
.dtype_format(DataType.BOOL_Default, DataType.I32_Default) \
|
||||||
|
.dtype_format(DataType.BOOL_Default, DataType.I64_Default) \
|
||||||
|
.dtype_format(DataType.BOOL_Default, DataType.F16_Default) \
|
||||||
|
.dtype_format(DataType.BOOL_Default, DataType.F32_Default) \
|
||||||
|
.dtype_format(DataType.BOOL_Default, DataType.F64_Default) \
|
||||||
|
.dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
@op_info_register(cast_op_info)
|
||||||
|
def _cast_cpu():
|
||||||
|
"""Cast Cpu register"""
|
||||||
|
return
|
|
@ -0,0 +1,38 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""Concat op"""
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
|
||||||
|
|
||||||
|
concat_op_info = CpuRegOp("Concat") \
|
||||||
|
.input(0, "x", "required") \
|
||||||
|
.output(0, "y", "required") \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
|
||||||
|
.dtype_format(DataType.I8_Default, DataType.I8_Default) \
|
||||||
|
.dtype_format(DataType.I16_Default, DataType.I16_Default) \
|
||||||
|
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
|
||||||
|
.dtype_format(DataType.I64_Default, DataType.I64_Default) \
|
||||||
|
.dtype_format(DataType.U8_Default, DataType.U8_Default) \
|
||||||
|
.dtype_format(DataType.U16_Default, DataType.U16_Default) \
|
||||||
|
.dtype_format(DataType.U32_Default, DataType.U32_Default) \
|
||||||
|
.dtype_format(DataType.U64_Default, DataType.U64_Default) \
|
||||||
|
.dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
|
||||||
|
@op_info_register(concat_op_info)
|
||||||
|
def _concat_cpu():
|
||||||
|
"""Concat cpu register"""
|
||||||
|
return
|
|
@ -0,0 +1,30 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""Conv2D op"""
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
|
||||||
|
|
||||||
|
conv2d_op_info = CpuRegOp("Conv2D") \
|
||||||
|
.input(0, "x", "required") \
|
||||||
|
.input(1, "filter", "required") \
|
||||||
|
.output(0, "y", "required") \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
|
||||||
|
@op_info_register(conv2d_op_info)
|
||||||
|
def _conv2d_cpu():
|
||||||
|
"""Conv2D cpu register"""
|
||||||
|
return
|
|
@ -0,0 +1,31 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""Div op"""
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
|
||||||
|
|
||||||
|
div_op_info = CpuRegOp("Div") \
|
||||||
|
.input(0, "x", "required") \
|
||||||
|
.input(1, "y", "required") \
|
||||||
|
.output(0, "output", "required") \
|
||||||
|
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||||
|
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
|
||||||
|
@op_info_register(div_op_info)
|
||||||
|
def _div_cpu():
|
||||||
|
"""Div cpu register"""
|
||||||
|
return
|
|
@ -0,0 +1,30 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""Dropout op"""
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
|
||||||
|
|
||||||
|
dropout_op_info = CpuRegOp("Dropout") \
|
||||||
|
.input(0, "x", "required") \
|
||||||
|
.output(0, "output0", "required") \
|
||||||
|
.output(1, "output1", "required") \
|
||||||
|
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
|
||||||
|
@op_info_register(dropout_op_info)
|
||||||
|
def _dropout_cpu():
|
||||||
|
"""Dropout cpu register"""
|
||||||
|
return
|
|
@ -0,0 +1,30 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""DropoutGrad op"""
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
|
||||||
|
|
||||||
|
dropout_grad_op_info = CpuRegOp("DropoutGrad") \
|
||||||
|
.input(0, "x", "required") \
|
||||||
|
.input(1, "y", "required") \
|
||||||
|
.output(0, "output", "required") \
|
||||||
|
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
|
||||||
|
@op_info_register(dropout_grad_op_info)
|
||||||
|
def _dropout_grad_cpu():
|
||||||
|
"""DropoutGrad cpu register"""
|
||||||
|
return
|
|
@ -0,0 +1,30 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""EqualCount op"""
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
|
||||||
|
|
||||||
|
equal_count_op_info = CpuRegOp("EqualCount") \
|
||||||
|
.input(0, "x1", "required") \
|
||||||
|
.input(1, "x2", "required") \
|
||||||
|
.output(0, "y", "required") \
|
||||||
|
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
|
||||||
|
@op_info_register(equal_count_op_info)
|
||||||
|
def _equal_count_cpu():
|
||||||
|
"""EqualCount cpu register"""
|
||||||
|
return
|
|
@ -0,0 +1,49 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""GatherD op"""
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
|
||||||
|
|
||||||
|
gather_d_op_info = CpuRegOp("GatherD") \
|
||||||
|
.input(0, "input", "required") \
|
||||||
|
.input(1, "dim", "required") \
|
||||||
|
.input(2, "index", "required") \
|
||||||
|
.output(0, "output", "required") \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.I32_Default, \
|
||||||
|
DataType.I32_Default, DataType.F32_Default) \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.I32_Default, \
|
||||||
|
DataType.I64_Default, DataType.F32_Default) \
|
||||||
|
.dtype_format(DataType.F16_Default, DataType.I32_Default, \
|
||||||
|
DataType.I32_Default, DataType.F16_Default) \
|
||||||
|
.dtype_format(DataType.F16_Default, DataType.I32_Default, \
|
||||||
|
DataType.I64_Default, DataType.F16_Default) \
|
||||||
|
.dtype_format(DataType.I32_Default, DataType.I32_Default, \
|
||||||
|
DataType.I32_Default, DataType.I32_Default) \
|
||||||
|
.dtype_format(DataType.I32_Default, DataType.I32_Default, \
|
||||||
|
DataType.I64_Default, DataType.I32_Default) \
|
||||||
|
.dtype_format(DataType.I64_Default, DataType.I32_Default, \
|
||||||
|
DataType.I32_Default, DataType.I64_Default) \
|
||||||
|
.dtype_format(DataType.I64_Default, DataType.I32_Default, \
|
||||||
|
DataType.I64_Default, DataType.I64_Default) \
|
||||||
|
.dtype_format(DataType.BOOL_Default, DataType.I32_Default, \
|
||||||
|
DataType.I32_Default, DataType.BOOL_Default) \
|
||||||
|
.dtype_format(DataType.BOOL_Default, DataType.I32_Default, \
|
||||||
|
DataType.I64_Default, DataType.BOOL_Default) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
@op_info_register(gather_d_op_info)
|
||||||
|
def _gather_cpu():
|
||||||
|
"""GatherD cpu register"""
|
||||||
|
return
|
|
@ -0,0 +1,38 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""GatherDGrad op"""
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
|
||||||
|
|
||||||
|
gather_d_grad_op_info = CpuRegOp("GatherDGrad") \
|
||||||
|
.input(0, "index", "required") \
|
||||||
|
.input(1, "src", "required") \
|
||||||
|
.output(0, "output", "required") \
|
||||||
|
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
|
||||||
|
.dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.I64_Default) \
|
||||||
|
.dtype_format(DataType.I32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||||
|
.dtype_format(DataType.I32_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||||
|
.dtype_format(DataType.I32_Default, DataType.BOOL_Default, DataType.BOOL_Default) \
|
||||||
|
.dtype_format(DataType.I64_Default, DataType.I32_Default, DataType.I32_Default) \
|
||||||
|
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \
|
||||||
|
.dtype_format(DataType.I64_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||||
|
.dtype_format(DataType.I64_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||||
|
.dtype_format(DataType.I64_Default, DataType.BOOL_Default, DataType.BOOL_Default) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
@op_info_register(gather_d_grad_op_info)
|
||||||
|
def _gather_d_grad_cpu():
|
||||||
|
"""GatherDGrad cpu register"""
|
||||||
|
return
|
|
@ -0,0 +1,33 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""GatherNd op"""
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
|
||||||
|
|
||||||
|
gather_nd_op_info = CpuRegOp("GatherNd") \
|
||||||
|
.input(0, "x1", "required") \
|
||||||
|
.input(1, "x2", "required") \
|
||||||
|
.output(0, "y", "required") \
|
||||||
|
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
|
||||||
|
.dtype_format(DataType.I64_Default, DataType.I32_Default, DataType.I64_Default) \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default) \
|
||||||
|
.dtype_format(DataType.F64_Default, DataType.I32_Default, DataType.F64_Default) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
|
||||||
|
@op_info_register(gather_nd_op_info)
|
||||||
|
def _gather_nd_cpu():
|
||||||
|
"""GatherNd cpu register"""
|
||||||
|
return
|
|
@ -0,0 +1,40 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""GatherV2 op"""
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
|
||||||
|
|
||||||
|
gather_v2_op_info = CpuRegOp("Gather") \
|
||||||
|
.input(0, "x", "required") \
|
||||||
|
.input(1, "indices", "required") \
|
||||||
|
.output(0, "y", "required") \
|
||||||
|
.dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default) \
|
||||||
|
.dtype_format(DataType.U16_Default, DataType.I32_Default, DataType.U16_Default) \
|
||||||
|
.dtype_format(DataType.U32_Default, DataType.I32_Default, DataType.U32_Default) \
|
||||||
|
.dtype_format(DataType.U64_Default, DataType.I32_Default, DataType.U64_Default) \
|
||||||
|
.dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I8_Default) \
|
||||||
|
.dtype_format(DataType.I16_Default, DataType.I32_Default, DataType.I16_Default) \
|
||||||
|
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
|
||||||
|
.dtype_format(DataType.I64_Default, DataType.I32_Default, DataType.I64_Default) \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default) \
|
||||||
|
.dtype_format(DataType.F64_Default, DataType.I32_Default, DataType.F64_Default) \
|
||||||
|
.dtype_format(DataType.BOOL_Default, DataType.I32_Default, DataType.BOOL_Default) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
|
||||||
|
@op_info_register(gather_v2_op_info)
|
||||||
|
def _gather_v2_cpu():
|
||||||
|
"""GatherV2 cpu register"""
|
||||||
|
return
|
|
@ -0,0 +1,32 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""HSigmoid op"""
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
|
||||||
|
|
||||||
|
hsigmoid_op_info = CpuRegOp("HSigmoid") \
|
||||||
|
.input(0, "x") \
|
||||||
|
.output(0, "output") \
|
||||||
|
.dtype_format(DataType.I8_Default, DataType.I8_Default) \
|
||||||
|
.dtype_format(DataType.I16_Default, DataType.I16_Default) \
|
||||||
|
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
|
||||||
|
.dtype_format(DataType.I64_Default, DataType.I64_Default) \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
|
||||||
|
@op_info_register(hsigmoid_op_info)
|
||||||
|
def _hsigmoid_cpu():
|
||||||
|
"""HSigmoid cpu register"""
|
||||||
|
return
|
|
@ -0,0 +1,33 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""HSigmoidGrad op"""
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
|
||||||
|
|
||||||
|
hsigmoidgrad_op_info = CpuRegOp("HSigmoidGrad") \
|
||||||
|
.input(0, "y_grad") \
|
||||||
|
.input(1, "x") \
|
||||||
|
.output(0, "output") \
|
||||||
|
.dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.I8_Default) \
|
||||||
|
.dtype_format(DataType.I16_Default, DataType.I16_Default, DataType.I16_Default) \
|
||||||
|
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
|
||||||
|
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
|
||||||
|
@op_info_register(hsigmoidgrad_op_info)
|
||||||
|
def _hsigmoid_grad_cpu():
|
||||||
|
"""HSigmoidGrad cpu register"""
|
||||||
|
return
|
|
@ -0,0 +1,32 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""HSwish op"""
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
|
||||||
|
|
||||||
|
hswish_op_info = CpuRegOp("HSwish") \
|
||||||
|
.input(0, "x") \
|
||||||
|
.output(0, "output") \
|
||||||
|
.dtype_format(DataType.I8_Default, DataType.I8_Default) \
|
||||||
|
.dtype_format(DataType.I16_Default, DataType.I16_Default) \
|
||||||
|
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
|
||||||
|
.dtype_format(DataType.I64_Default, DataType.I64_Default) \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
|
||||||
|
@op_info_register(hswish_op_info)
|
||||||
|
def _hswish_cpu():
|
||||||
|
"""HSwish cpu register"""
|
||||||
|
return
|
|
@ -0,0 +1,33 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""HSwishGrad op"""
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
|
||||||
|
|
||||||
|
hswish_grad_op_info = CpuRegOp("HSwishGrad") \
|
||||||
|
.input(0, "y_grad") \
|
||||||
|
.input(1, "x") \
|
||||||
|
.output(0, "output") \
|
||||||
|
.dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.I8_Default) \
|
||||||
|
.dtype_format(DataType.I16_Default, DataType.I16_Default, DataType.I16_Default) \
|
||||||
|
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
|
||||||
|
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
|
||||||
|
@op_info_register(hswish_grad_op_info)
|
||||||
|
def _hswish_grad_cpu():
|
||||||
|
"""HSwishGrad cpu register"""
|
||||||
|
return
|
|
@ -0,0 +1,39 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""IsFinite op"""
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
|
||||||
|
|
||||||
|
is_finite_op_info = CpuRegOp("IsFinite") \
|
||||||
|
.input(0, "x", "required") \
|
||||||
|
.output(0, "y", "required") \
|
||||||
|
.dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \
|
||||||
|
.dtype_format(DataType.I8_Default, DataType.BOOL_Default) \
|
||||||
|
.dtype_format(DataType.I16_Default, DataType.BOOL_Default) \
|
||||||
|
.dtype_format(DataType.I32_Default, DataType.BOOL_Default) \
|
||||||
|
.dtype_format(DataType.I64_Default, DataType.BOOL_Default) \
|
||||||
|
.dtype_format(DataType.U8_Default, DataType.BOOL_Default) \
|
||||||
|
.dtype_format(DataType.U16_Default, DataType.BOOL_Default) \
|
||||||
|
.dtype_format(DataType.U32_Default, DataType.BOOL_Default) \
|
||||||
|
.dtype_format(DataType.U64_Default, DataType.BOOL_Default) \
|
||||||
|
.dtype_format(DataType.F16_Default, DataType.BOOL_Default) \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.BOOL_Default) \
|
||||||
|
.dtype_format(DataType.F64_Default, DataType.BOOL_Default) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
@op_info_register(is_finite_op_info)
|
||||||
|
def _is_finite_cpu():
|
||||||
|
"""IsFinite cpu register"""
|
||||||
|
return
|
|
@ -0,0 +1,36 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""LayerNorm op"""
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
|
||||||
|
|
||||||
|
layer_norm_op_info = CpuRegOp("LayerNorm") \
|
||||||
|
.input(0, "x", "required") \
|
||||||
|
.input(1, "gamma", "required") \
|
||||||
|
.input(2, "beta", "required") \
|
||||||
|
.output(0, "y", "required") \
|
||||||
|
.output(1, "mean", "required") \
|
||||||
|
.output(2, "variance", "required") \
|
||||||
|
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, \
|
||||||
|
DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, \
|
||||||
|
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
|
||||||
|
@op_info_register(layer_norm_op_info)
|
||||||
|
def _layer_norm_cpu():
|
||||||
|
"""LayerNorm cpu register"""
|
||||||
|
return
|
|
@ -0,0 +1,38 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""LayerNormGrad op"""
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
|
||||||
|
|
||||||
|
layer_norm_grad_op_info = CpuRegOp("LayerNormGrad") \
|
||||||
|
.input(0, "dy", "required") \
|
||||||
|
.input(1, "x", "required") \
|
||||||
|
.input(2, "variance", "required") \
|
||||||
|
.input(3, "mean", "required") \
|
||||||
|
.input(4, "gamma", "required") \
|
||||||
|
.output(0, "pd_x", "required") \
|
||||||
|
.output(1, "pd_gamma", "required") \
|
||||||
|
.output(2, "pd_beta", "required") \
|
||||||
|
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
|
||||||
|
DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
|
||||||
|
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
|
||||||
|
@op_info_register(layer_norm_grad_op_info)
|
||||||
|
def _layer_norm_grad_cpu():
|
||||||
|
"""LayerNormGrad TBE register"""
|
||||||
|
return
|
|
@ -0,0 +1,35 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""Maximum op"""
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
|
||||||
|
|
||||||
|
maximum_op_info = CpuRegOp("Maximum") \
|
||||||
|
.input(0, "x1", "required") \
|
||||||
|
.input(1, "x2", "required") \
|
||||||
|
.output(0, "y", "required") \
|
||||||
|
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
|
||||||
|
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \
|
||||||
|
.dtype_format(DataType.U32_Default, DataType.U32_Default, DataType.U32_Default) \
|
||||||
|
.dtype_format(DataType.U64_Default, DataType.U64_Default, DataType.U64_Default) \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||||
|
.dtype_format(DataType.F64_Default, DataType.F64_Default, DataType.F64_Default) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
|
||||||
|
@op_info_register(maximum_op_info)
|
||||||
|
def _maximum_cpu():
|
||||||
|
"""Maximum cpu register"""
|
||||||
|
return
|
|
@ -0,0 +1,43 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""MaximumGrad op"""
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
|
||||||
|
|
||||||
|
maximum_grad_op_info = CpuRegOp("MaximumGrad") \
|
||||||
|
.input(0, "grads", "required") \
|
||||||
|
.input(1, "x1", "required") \
|
||||||
|
.input(2, "x2", "required") \
|
||||||
|
.output(0, "y1", "required") \
|
||||||
|
.output(1, "y2", "required") \
|
||||||
|
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default,
|
||||||
|
DataType.I32_Default, DataType.I32_Default) \
|
||||||
|
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default,
|
||||||
|
DataType.I64_Default, DataType.I64_Default) \
|
||||||
|
.dtype_format(DataType.U32_Default, DataType.U32_Default, DataType.U32_Default,
|
||||||
|
DataType.U32_Default, DataType.U32_Default) \
|
||||||
|
.dtype_format(DataType.U64_Default, DataType.U64_Default, DataType.U64_Default,
|
||||||
|
DataType.U64_Default, DataType.U64_Default) \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
|
||||||
|
DataType.F32_Default, DataType.F32_Default) \
|
||||||
|
.dtype_format(DataType.F64_Default, DataType.F64_Default, DataType.F64_Default,
|
||||||
|
DataType.F64_Default, DataType.F64_Default) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
|
||||||
|
@op_info_register(maximum_grad_op_info)
|
||||||
|
def _maximum_grad_cpu():
|
||||||
|
"""MaximumGrad cpu register"""
|
||||||
|
return
|
|
@ -0,0 +1,36 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
"""Minimum op"""
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
|
||||||
|
|
||||||
|
minimum_op_info = CpuRegOp("Minimum") \
|
||||||
|
.input(0, "x1", "required") \
|
||||||
|
.input(1, "x2", "required") \
|
||||||
|
.output(0, "y", "required") \
|
||||||
|
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
|
||||||
|
.dtype_format(DataType.U32_Default, DataType.U32_Default, DataType.U32_Default) \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||||
|
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \
|
||||||
|
.dtype_format(DataType.U64_Default, DataType.U64_Default, DataType.U64_Default) \
|
||||||
|
.dtype_format(DataType.F64_Default, DataType.F64_Default, DataType.F64_Default) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
|
||||||
|
@op_info_register(minimum_op_info)
|
||||||
|
def _minimum_cpu():
|
||||||
|
"""Minimum cpu register"""
|
||||||
|
return
|
|
@ -0,0 +1,37 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""MinimumGrad op"""
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
|
||||||
|
|
||||||
|
minimum_grad_op_info = CpuRegOp("MinimumGrad") \
|
||||||
|
.input(0, "grads", "required") \
|
||||||
|
.input(1, "x1", "required") \
|
||||||
|
.input(2, "x2", "required") \
|
||||||
|
.output(0, "y1", "required") \
|
||||||
|
.output(1, "y2", "required") \
|
||||||
|
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, \
|
||||||
|
DataType.I32_Default, DataType.I32_Default) \
|
||||||
|
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, \
|
||||||
|
DataType.F16_Default, DataType.F16_Default) \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, \
|
||||||
|
DataType.F32_Default, DataType.F32_Default) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
|
||||||
|
@op_info_register(minimum_grad_op_info)
|
||||||
|
def _minimum_grad_cpu():
|
||||||
|
"""MinimumGrad cpu register"""
|
||||||
|
return
|
|
@ -0,0 +1,31 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""MirrorPad op"""
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
|
||||||
|
mirror_pad_op_info = CpuRegOp("MirrorPad") \
|
||||||
|
.input(0, "x", "required") \
|
||||||
|
.input(1, "paddings", "required") \
|
||||||
|
.output(0, "y", "required") \
|
||||||
|
.dtype_format(DataType.F16_Default, DataType.I64_Default, DataType.F16_Default) \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.I64_Default, DataType.F32_Default) \
|
||||||
|
.dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.I32_Default) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
|
||||||
|
@op_info_register(mirror_pad_op_info)
|
||||||
|
def _mirror_pad_cpu():
|
||||||
|
"""MirrorPad cpu register"""
|
||||||
|
return
|
|
@ -0,0 +1,31 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""MirrorPadGrad op"""
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
|
||||||
|
mirror_pad_grad_op_info = CpuRegOp("MirrorPadGrad") \
|
||||||
|
.input(0, "x", "required") \
|
||||||
|
.input(1, "paddings", "required") \
|
||||||
|
.output(0, "y", "required") \
|
||||||
|
.dtype_format(DataType.F16_Default, DataType.I64_Default, DataType.F16_Default) \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.I64_Default, DataType.F32_Default) \
|
||||||
|
.dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.I32_Default) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
|
||||||
|
@op_info_register(mirror_pad_grad_op_info)
|
||||||
|
def _mirror_pad_grad_cpu():
|
||||||
|
"""MirrorPadGrad cpu register"""
|
||||||
|
return
|
|
@ -0,0 +1,30 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""Mul op"""
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
|
||||||
|
|
||||||
|
mul_op_info = CpuRegOp("Mul") \
|
||||||
|
.input(0, "x", "required") \
|
||||||
|
.input(1, "y", "required") \
|
||||||
|
.output(0, "output", "required") \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||||
|
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
|
||||||
|
@op_info_register(mul_op_info)
|
||||||
|
def _mul_cpu():
|
||||||
|
"""Mul cpu register"""
|
||||||
|
return
|
|
@ -0,0 +1,31 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""OneHot op"""
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
|
||||||
|
|
||||||
|
one_hot_op_info = CpuRegOp("OneHot") \
|
||||||
|
.input(0, "x", "required") \
|
||||||
|
.input(1, "on_value", "required") \
|
||||||
|
.input(2, "off_value", "required") \
|
||||||
|
.output(0, "y", "required") \
|
||||||
|
.dtype_format(DataType.I32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
|
||||||
|
@op_info_register(one_hot_op_info)
|
||||||
|
def _one_hot_cpu():
|
||||||
|
"""OneHot cpu register"""
|
||||||
|
return
|
|
@ -0,0 +1,31 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""Pad op"""
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
|
||||||
|
|
||||||
|
pad_op_info = CpuRegOp("Pad") \
|
||||||
|
.input(0, "x", "required") \
|
||||||
|
.output(0, "y", "required") \
|
||||||
|
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
|
||||||
|
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
|
||||||
|
@op_info_register(pad_op_info)
|
||||||
|
def _pad_cpu():
|
||||||
|
"""Pad cpu register"""
|
||||||
|
return
|
|
@ -0,0 +1,31 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""Power op"""
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
|
||||||
|
|
||||||
|
pow_op_info = CpuRegOp("Pow") \
|
||||||
|
.input(0, "x", "required") \
|
||||||
|
.input(1, "y", "required") \
|
||||||
|
.output(0, "output", "required") \
|
||||||
|
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||||
|
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
|
||||||
|
@op_info_register(pow_op_info)
|
||||||
|
def _pow_cpu():
|
||||||
|
"""Pow cpu register"""
|
||||||
|
return
|
|
@ -0,0 +1,32 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""Range op"""
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
|
||||||
|
|
||||||
|
range_op_info = CpuRegOp("Range") \
|
||||||
|
.input(0, "x", "required") \
|
||||||
|
.output(0, "y", "required") \
|
||||||
|
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
|
||||||
|
.dtype_format(DataType.I64_Default, DataType.I64_Default) \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
|
||||||
|
.dtype_format(DataType.F64_Default, DataType.F64_Default) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
|
||||||
|
@op_info_register(range_op_info)
|
||||||
|
def _range_cpu():
|
||||||
|
"""Range cpu register"""
|
||||||
|
return
|
|
@ -0,0 +1,31 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""RealDiv op"""
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
|
||||||
|
|
||||||
|
real_div_op_info = CpuRegOp("RealDiv") \
|
||||||
|
.input(0, "x", "required") \
|
||||||
|
.input(1, "y", "required") \
|
||||||
|
.output(0, "output", "required") \
|
||||||
|
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||||
|
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
|
||||||
|
@op_info_register(real_div_op_info)
|
||||||
|
def _real_div_cpu():
|
||||||
|
"""RealDiv cpu register"""
|
||||||
|
return
|
|
@ -0,0 +1,29 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""ReduceAll op"""
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
|
||||||
|
|
||||||
|
reduce_all_op_info = CpuRegOp("ReduceAll") \
|
||||||
|
.input(0, "x", "required") \
|
||||||
|
.output(0, "y", "required") \
|
||||||
|
.dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
|
||||||
|
@op_info_register(reduce_all_op_info)
|
||||||
|
def _reduce_all_cpu():
|
||||||
|
"""ReduceAll cpu register"""
|
||||||
|
return
|
|
@ -0,0 +1,29 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""ReduceAny op"""
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
|
||||||
|
|
||||||
|
reduce_any_op_info = CpuRegOp("ReduceAny") \
|
||||||
|
.input(0, "x", "required") \
|
||||||
|
.output(0, "y", "required") \
|
||||||
|
.dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
|
||||||
|
@op_info_register(reduce_any_op_info)
|
||||||
|
def _reduce_any_cpu():
|
||||||
|
"""ReduceAny cpu register"""
|
||||||
|
return
|
|
@ -0,0 +1,32 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""ReduceMax op"""
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
|
||||||
|
|
||||||
|
reduce_max_op_info = CpuRegOp("ReduceMax") \
|
||||||
|
.input(0, "x", "required") \
|
||||||
|
.output(0, "y", "required") \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
|
||||||
|
.dtype_format(DataType.F64_Default, DataType.F64_Default) \
|
||||||
|
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
|
||||||
|
.dtype_format(DataType.I64_Default, DataType.I64_Default) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
|
||||||
|
@op_info_register(reduce_max_op_info)
|
||||||
|
def _reduce_max_cpu():
|
||||||
|
"""ReduceMax cpu register"""
|
||||||
|
return
|
|
@ -0,0 +1,32 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""ReduceMean op"""
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
|
||||||
|
|
||||||
|
reduce_mean_op_info = CpuRegOp("ReduceMean") \
|
||||||
|
.input(0, "x", "required") \
|
||||||
|
.output(0, "y", "required") \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
|
||||||
|
.dtype_format(DataType.F64_Default, DataType.F64_Default) \
|
||||||
|
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
|
||||||
|
.dtype_format(DataType.I64_Default, DataType.I64_Default) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
|
||||||
|
@op_info_register(reduce_mean_op_info)
|
||||||
|
def _reduce_mean_cpu():
|
||||||
|
"""ReduceMean cpu register"""
|
||||||
|
return
|
|
@ -0,0 +1,32 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""ReduceMin op"""
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
|
||||||
|
|
||||||
|
reduce_min_op_info = CpuRegOp("ReduceMin") \
|
||||||
|
.input(0, "x", "required") \
|
||||||
|
.output(0, "y", "required") \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
|
||||||
|
.dtype_format(DataType.F64_Default, DataType.F64_Default) \
|
||||||
|
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
|
||||||
|
.dtype_format(DataType.I64_Default, DataType.I64_Default) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
|
||||||
|
@op_info_register(reduce_min_op_info)
|
||||||
|
def _reduce_min_cpu():
|
||||||
|
"""ReduceMin cpu register"""
|
||||||
|
return
|
|
@ -0,0 +1,32 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""ReduceSum op"""
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
|
||||||
|
|
||||||
|
reduce_sum_op_info = CpuRegOp("ReduceSum") \
|
||||||
|
.input(0, "x", "required") \
|
||||||
|
.output(0, "y", "required") \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
|
||||||
|
.dtype_format(DataType.F64_Default, DataType.F64_Default) \
|
||||||
|
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
|
||||||
|
.dtype_format(DataType.I64_Default, DataType.I64_Default) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
|
||||||
|
@op_info_register(reduce_sum_op_info)
|
||||||
|
def _reduce_sum_cpu():
|
||||||
|
"""ReduceSum cpu register"""
|
||||||
|
return
|
|
@ -0,0 +1,34 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""Split op"""
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
|
||||||
|
|
||||||
|
split_op_info = CpuRegOp("Split") \
|
||||||
|
.input(0, "x", "required") \
|
||||||
|
.output(0, "y", "required") \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
|
||||||
|
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
|
||||||
|
.dtype_format(DataType.F64_Default, DataType.F64_Default) \
|
||||||
|
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
|
||||||
|
.dtype_format(DataType.U32_Default, DataType.U32_Default) \
|
||||||
|
.dtype_format(DataType.I64_Default, DataType.I64_Default) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
|
||||||
|
@op_info_register(split_op_info)
|
||||||
|
def _split_cpu():
|
||||||
|
"""Split cpu register"""
|
||||||
|
return
|
|
@ -0,0 +1,38 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""Stack op"""
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
|
||||||
|
|
||||||
|
stack_op_info = CpuRegOp("Stack") \
|
||||||
|
.input(0, "x", "required") \
|
||||||
|
.output(0, "y", "required") \
|
||||||
|
.dtype_format(DataType.I8_Default, DataType.I8_Default) \
|
||||||
|
.dtype_format(DataType.I16_Default, DataType.I16_Default) \
|
||||||
|
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
|
||||||
|
.dtype_format(DataType.I64_Default, DataType.I64_Default) \
|
||||||
|
.dtype_format(DataType.U8_Default, DataType.U8_Default) \
|
||||||
|
.dtype_format(DataType.U16_Default, DataType.U16_Default) \
|
||||||
|
.dtype_format(DataType.U32_Default, DataType.U32_Default) \
|
||||||
|
.dtype_format(DataType.U64_Default, DataType.U64_Default) \
|
||||||
|
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
|
||||||
|
.dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
@op_info_register(stack_op_info)
|
||||||
|
def _stack_cpu():
|
||||||
|
"""Stack cpu register"""
|
||||||
|
return
|
|
@ -0,0 +1,31 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""Sub op"""
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
|
||||||
|
|
||||||
|
sub_op_info = CpuRegOp("Sub") \
|
||||||
|
.input(0, "x", "required") \
|
||||||
|
.input(1, "y", "required") \
|
||||||
|
.output(0, "output", "required") \
|
||||||
|
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||||
|
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
|
||||||
|
@op_info_register(sub_op_info)
|
||||||
|
def _sub_cpu():
|
||||||
|
"""Sub cpu register"""
|
||||||
|
return
|
|
@ -0,0 +1,38 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""Tile op"""
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
|
||||||
|
|
||||||
|
tile_op_info = CpuRegOp("Tile") \
|
||||||
|
.input(0, "x1", "required") \
|
||||||
|
.output(0, "y", "required") \
|
||||||
|
.dtype_format(DataType.I8_Default, DataType.I8_Default) \
|
||||||
|
.dtype_format(DataType.I16_Default, DataType.I16_Default) \
|
||||||
|
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
|
||||||
|
.dtype_format(DataType.I64_Default, DataType.I64_Default) \
|
||||||
|
.dtype_format(DataType.U8_Default, DataType.U8_Default) \
|
||||||
|
.dtype_format(DataType.U16_Default, DataType.U16_Default) \
|
||||||
|
.dtype_format(DataType.U32_Default, DataType.U32_Default) \
|
||||||
|
.dtype_format(DataType.U64_Default, DataType.U64_Default) \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
|
||||||
|
.dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
|
||||||
|
@op_info_register(tile_op_info)
|
||||||
|
def _tile_cpu():
|
||||||
|
"""Tile cpu register"""
|
||||||
|
return
|
|
@ -0,0 +1,31 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""TopK op"""
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
|
||||||
|
|
||||||
|
top_k_op_info = CpuRegOp("TopK") \
|
||||||
|
.input(0, "input", "required") \
|
||||||
|
.input(1, "k", "required") \
|
||||||
|
.output(0, "values", "required") \
|
||||||
|
.output(1, "indices", "required") \
|
||||||
|
.dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default, DataType.I32_Default) \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default, DataType.I32_Default) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
@op_info_register(top_k_op_info)
|
||||||
|
def _top_k_cpu():
|
||||||
|
"""TopK cpu register"""
|
||||||
|
return
|
|
@ -0,0 +1,38 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""Transpose op"""
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
|
||||||
|
|
||||||
|
transpose_op_info = CpuRegOp("Transpose") \
|
||||||
|
.input(0, "x", "required") \
|
||||||
|
.output(0, "y", "required") \
|
||||||
|
.dtype_format(DataType.I8_Default, DataType.I8_Default) \
|
||||||
|
.dtype_format(DataType.I16_Default, DataType.I16_Default) \
|
||||||
|
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
|
||||||
|
.dtype_format(DataType.I64_Default, DataType.I64_Default) \
|
||||||
|
.dtype_format(DataType.U8_Default, DataType.U8_Default) \
|
||||||
|
.dtype_format(DataType.U16_Default, DataType.U16_Default) \
|
||||||
|
.dtype_format(DataType.U32_Default, DataType.U32_Default) \
|
||||||
|
.dtype_format(DataType.U64_Default, DataType.U64_Default) \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
|
||||||
|
.dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
|
||||||
|
@op_info_register(transpose_op_info)
|
||||||
|
def _transpose_cpu():
|
||||||
|
"""TransposeD cpu register"""
|
||||||
|
return
|
|
@ -213,6 +213,65 @@ class RegOp:
|
||||||
return op_info
|
return op_info
|
||||||
|
|
||||||
|
|
||||||
|
class CpuRegOp(RegOp):
|
||||||
|
"""Class for Cpu op info register"""
|
||||||
|
|
||||||
|
def __init__(self, op_name):
|
||||||
|
super(CpuRegOp, self).__init__(op_name)
|
||||||
|
self.imply_type = "CPU"
|
||||||
|
|
||||||
|
def input(self, index=None, name=None, param_type=None, **kwargs):
|
||||||
|
"""
|
||||||
|
Register Cpu op input information.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
index (int): Order of the input. Default: None.
|
||||||
|
name (str): Name of the input. Default: None.
|
||||||
|
param_type (str): Param type of the input. Default: None.
|
||||||
|
kwargs (dict): Other information of the input.
|
||||||
|
"""
|
||||||
|
param_list = [index, name, param_type]
|
||||||
|
key_list = ["index", "name", "param_type"]
|
||||||
|
fn_list = [self._is_int, self._is_string, self._is_string]
|
||||||
|
input_dict = self._check_param(param_list, key_list, fn_list, kwargs)
|
||||||
|
self.inputs.append(input_dict)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def output(self, index=None, name=None, param_type=None, **kwargs):
|
||||||
|
"""
|
||||||
|
Register AiCPU op output information.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
index (int): Order of the output. Default: None.
|
||||||
|
name (str): Name of the output. Default: None.
|
||||||
|
param_type (str): Param type of the output. Default: None.
|
||||||
|
kwargs (dict): Other information of the output.
|
||||||
|
"""
|
||||||
|
param_list = [index, name, param_type]
|
||||||
|
key_list = ["index", "name", "param_type"]
|
||||||
|
fn_list = [self._is_int, self._is_string, self._is_string]
|
||||||
|
output_dict = self._check_param(param_list, key_list, fn_list, kwargs)
|
||||||
|
self.outputs.append(output_dict)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def attr(self, name=None, value_type=None, value=None, **kwargs):
|
||||||
|
"""
|
||||||
|
Register AiCPU op attribute information.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name (str): Name of the attribute. Default: None.
|
||||||
|
value_type (str): Value type of the attribute. Default: None.
|
||||||
|
value (str): Value of the attribute. Default: None.
|
||||||
|
kwargs (dict): Other information of the attribute.
|
||||||
|
"""
|
||||||
|
param_list = [name, value_type, value]
|
||||||
|
key_list = ["name", "type", "value"]
|
||||||
|
fn_list = [self._is_string]
|
||||||
|
attr_dict = self._check_param(param_list, key_list, fn_list, kwargs)
|
||||||
|
self.attr_.append(attr_dict)
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
class AkgRegOp(RegOp):
|
class AkgRegOp(RegOp):
|
||||||
"""Class for Akg op info register."""
|
"""Class for Akg op info register."""
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue