!19742 add cpu masked_select type

Merge pull request !19742 from baihuawei/mask1.3
This commit is contained in:
i-robot 2021-07-10 08:21:00 +00:00 committed by Gitee
commit 9e4cd7121f
3 changed files with 54 additions and 1 deletions

View File

@ -48,10 +48,30 @@ MS_REG_CPU_KERNEL_T(
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeFloat32),
MaskedSelectCPUKernel, float);
MS_REG_CPU_KERNEL_T(
MaskedSelect,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeFloat16),
MaskedSelectCPUKernel, float16);
MS_REG_CPU_KERNEL_T(
MaskedSelect,
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeFloat64),
MaskedSelectCPUKernel, double);
MS_REG_CPU_KERNEL_T(
MaskedSelect,
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt32),
MaskedSelectCPUKernel, int);
MS_REG_CPU_KERNEL_T(
MaskedSelect,
KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt16),
MaskedSelectCPUKernel, int16_t);
MS_REG_CPU_KERNEL_T(
MaskedSelect,
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt64),
MaskedSelectCPUKernel, int64_t);
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MASKED_SELECTED_CPU_KERNEL_H_

View File

@ -51,6 +51,22 @@ MS_REG_CPU_KERNEL_T(MaskedSelectGrad,
.AddOutputAttr(kNumberTypeFloat32),
MaskedSelectGradCPUKernel, float);
MS_REG_CPU_KERNEL_T(MaskedSelectGrad,
KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeBool)
.AddInputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeFloat64),
MaskedSelectGradCPUKernel, double);
MS_REG_CPU_KERNEL_T(MaskedSelectGrad,
KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeBool)
.AddInputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16),
MaskedSelectGradCPUKernel, float16);
MS_REG_CPU_KERNEL_T(MaskedSelectGrad,
KernelAttr()
.AddInputAttr(kNumberTypeInt32)
@ -58,6 +74,22 @@ MS_REG_CPU_KERNEL_T(MaskedSelectGrad,
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt32),
MaskedSelectGradCPUKernel, int);
MS_REG_CPU_KERNEL_T(MaskedSelectGrad,
KernelAttr()
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeBool)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeInt64),
MaskedSelectGradCPUKernel, int64_t);
MS_REG_CPU_KERNEL_T(MaskedSelectGrad,
KernelAttr()
.AddInputAttr(kNumberTypeInt16)
.AddInputAttr(kNumberTypeBool)
.AddInputAttr(kNumberTypeInt16)
.AddOutputAttr(kNumberTypeInt16),
MaskedSelectGradCPUKernel, int16_t);
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MASKED_SELECTED_GRAD_CPU_KERNEL_H_

View File

@ -676,7 +676,7 @@ bool AscendKernelRuntime::RunTask(const session::KernelGraph *graph) {
try {
ModelRunner::Instance().RunModel(graph->graph_id());
} catch (const std::exception &) {
} catch (const std::exception &e) {
DumpTaskExceptionInfo(graph);
std::string file_name = "task_error_debug" + std::to_string(graph->graph_id()) + ".ir";
auto graph_tmp = std::make_shared<session::KernelGraph>(*graph);
@ -690,6 +690,7 @@ bool AscendKernelRuntime::RunTask(const session::KernelGraph *graph) {
MS_LOG(INFO) << "Destroy tdt channel success.";
}
#endif
MS_LOG(ERROR) << "RunModel error msg: " << e.what();
return false;
}
task_fail_infoes_.clear();