!19742 add cpu masked_select type
Merge pull request !19742 from baihuawei/mask1.3
This commit is contained in:
commit
9e4cd7121f
|
@ -48,10 +48,30 @@ MS_REG_CPU_KERNEL_T(
|
|||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeFloat32),
|
||||
MaskedSelectCPUKernel, float);
|
||||
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
MaskedSelect,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeFloat16),
|
||||
MaskedSelectCPUKernel, float16);
|
||||
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
MaskedSelect,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeFloat64),
|
||||
MaskedSelectCPUKernel, double);
|
||||
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
MaskedSelect,
|
||||
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt32),
|
||||
MaskedSelectCPUKernel, int);
|
||||
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
MaskedSelect,
|
||||
KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt16),
|
||||
MaskedSelectCPUKernel, int16_t);
|
||||
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
MaskedSelect,
|
||||
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt64),
|
||||
MaskedSelectCPUKernel, int64_t);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MASKED_SELECTED_CPU_KERNEL_H_
|
||||
|
|
|
@ -51,6 +51,22 @@ MS_REG_CPU_KERNEL_T(MaskedSelectGrad,
|
|||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
MaskedSelectGradCPUKernel, float);
|
||||
|
||||
MS_REG_CPU_KERNEL_T(MaskedSelectGrad,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddInputAttr(kNumberTypeBool)
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddOutputAttr(kNumberTypeFloat64),
|
||||
MaskedSelectGradCPUKernel, double);
|
||||
|
||||
MS_REG_CPU_KERNEL_T(MaskedSelectGrad,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeBool)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat16),
|
||||
MaskedSelectGradCPUKernel, float16);
|
||||
|
||||
MS_REG_CPU_KERNEL_T(MaskedSelectGrad,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
|
@ -58,6 +74,22 @@ MS_REG_CPU_KERNEL_T(MaskedSelectGrad,
|
|||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeInt32),
|
||||
MaskedSelectGradCPUKernel, int);
|
||||
|
||||
MS_REG_CPU_KERNEL_T(MaskedSelectGrad,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeBool)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeInt64),
|
||||
MaskedSelectGradCPUKernel, int64_t);
|
||||
|
||||
MS_REG_CPU_KERNEL_T(MaskedSelectGrad,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt16)
|
||||
.AddInputAttr(kNumberTypeBool)
|
||||
.AddInputAttr(kNumberTypeInt16)
|
||||
.AddOutputAttr(kNumberTypeInt16),
|
||||
MaskedSelectGradCPUKernel, int16_t);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MASKED_SELECTED_GRAD_CPU_KERNEL_H_
|
||||
|
|
|
@ -676,7 +676,7 @@ bool AscendKernelRuntime::RunTask(const session::KernelGraph *graph) {
|
|||
|
||||
try {
|
||||
ModelRunner::Instance().RunModel(graph->graph_id());
|
||||
} catch (const std::exception &) {
|
||||
} catch (const std::exception &e) {
|
||||
DumpTaskExceptionInfo(graph);
|
||||
std::string file_name = "task_error_debug" + std::to_string(graph->graph_id()) + ".ir";
|
||||
auto graph_tmp = std::make_shared<session::KernelGraph>(*graph);
|
||||
|
@ -690,6 +690,7 @@ bool AscendKernelRuntime::RunTask(const session::KernelGraph *graph) {
|
|||
MS_LOG(INFO) << "Destroy tdt channel success.";
|
||||
}
|
||||
#endif
|
||||
MS_LOG(ERROR) << "RunModel error msg: " << e.what();
|
||||
return false;
|
||||
}
|
||||
task_fail_infoes_.clear();
|
||||
|
|
Loading…
Reference in New Issue