From 60e4c1183fb8c77417b676eaa0e1a17b2452a498 Mon Sep 17 00:00:00 2001 From: hanhuifeng2020 Date: Thu, 7 Dec 2023 14:58:39 +0800 Subject: [PATCH] refactor some python enums in yaml to C++ --- .gitignore | 2 - .../graph_kernel/convert_input_and_attr.cc | 8 +- .../include/common/pybind_api/api_register.h | 4 + mindspore/ccsrc/pipeline/jit/ps/init.cc | 1 + .../pipeline/jit/ps/static_analysis/prim.cc | 5 + .../cpu/kernel/grid_sampler_2d_cpu_kernel.cc | 14 +- .../kernel/grid_sampler_2d_grad_cpu_kernel.cc | 10 +- .../cpu/kernel/grid_sampler_3d_cpu_kernel.cc | 10 +- .../kernel/grid_sampler_3d_grad_cpu_kernel.cc | 10 +- .../device/cpu/kernel/nllloss_cpu_kernel.cc | 8 +- .../device/cpu/kernel/nllloss_cpu_kernel.h | 4 +- .../cpu/kernel/nllloss_grad_cpu_kernel.cc | 6 +- .../cpu/kernel/nllloss_grad_cpu_kernel.h | 4 +- .../cpu/kernel/resize_linear_1d_cpu_kernel.cc | 14 +- .../cpu/kernel/resize_linear_1d_cpu_kernel.h | 7 +- .../resize_linear_1d_grad_cpu_kernel.cc | 14 +- .../kernel/resize_linear_1d_grad_cpu_kernel.h | 7 +- .../cuda_ops/loss_with_reduction_impl.cuh | 6 + .../kernel/math/fft_with_size_gpu_kernel.cc | 12 +- .../kernel/math/fft_with_size_gpu_kernel.h | 4 +- .../gpu/kernel/nn/grid_sampler_gpu_kernel.h | 2 +- .../kernel/nn/grid_sampler_grad_gpu_kernel.h | 2 +- .../gpu/kernel/nn/nll_loss_gpu_kernel.cc | 2 +- .../gpu/kernel/nn/nll_loss_grad_gpu_kernel.cc | 2 +- .../kernel/nn/resize_linear_1d_gpu_kernel.cc | 7 +- .../nn/resize_linear_1d_grad_gpu_kernel.cc | 7 +- .../optimizer/bce_with_logits_loss_fusion.cc | 1 - .../ccsrc/pybind_api/utils/op_enum_py.cc | 73 +++++ .../transform/graph_ir/op_adapter_base.h | 2 +- mindspore/core/mindapi/base/format.h | 7 +- mindspore/core/mindapi/base/type_id.h | 2 - mindspore/core/ops/op_enum.cc | 164 +++++++++++ mindspore/core/ops/op_enum.h | 49 ++++ mindspore/core/ops/ops_def/argmax_op.yaml | 4 +- mindspore/core/ops/ops_def/argmin_op.yaml | 4 +- .../core/ops/ops_def/avg_pool_grad_op.yaml | 4 +- mindspore/core/ops/ops_def/avg_pool_op.yaml | 4 +- .../ops/ops_def/batch_norm_grad_grad_op.yaml | 2 +- .../core/ops/ops_def/batch_norm_grad_op.yaml | 2 +- .../batch_norm_grad_with_activation_op.yaml | 2 +- ..._norm_grad_with_add_and_activation_op.yaml | 2 +- mindspore/core/ops/ops_def/batch_norm_op.yaml | 2 +- .../batch_norm_with_activation_op.yaml | 2 +- ...batch_norm_with_add_and_activation_op.yaml | 2 +- .../core/ops/ops_def/bias_add_grad_op.yaml | 2 +- mindspore/core/ops/ops_def/bias_add_op.yaml | 2 +- mindspore/core/ops/ops_def/eye_op.yaml | 4 +- .../core/ops/ops_def/fft_with_size_op.yaml | 2 +- .../ops/ops_def/grid_sampler_2d_grad_op.yaml | 4 +- .../core/ops/ops_def/grid_sampler_2d_op.yaml | 4 +- .../ops/ops_def/grid_sampler_3d_grad_op.yaml | 4 +- .../core/ops/ops_def/grid_sampler_3d_op.yaml | 6 +- .../inner/extract_image_patches_op.yaml | 2 +- .../ops_def/inner/prompt_k_v_cache_op.yaml | 2 +- .../core/ops/ops_def/nllloss_grad_op.yaml | 2 +- mindspore/core/ops/ops_def/nllloss_op.yaml | 2 +- .../core/ops/ops_def/randperm_v2_op.yaml | 4 +- .../ops/ops_def/resize_linear_1d_grad_op.yaml | 2 +- .../core/ops/ops_def/resize_linear_1d_op.yaml | 2 +- mindspore/core/ops/ops_func_impl/eig.cc | 7 +- .../core/ops/ops_func_impl/fft_with_size.cc | 1 - mindspore/core/ops/ops_func_impl/nllloss.cc | 8 +- mindspore/core/ops/other_ops.h | 3 + mindspore/core/ops/string_to_enum.cc | 65 +++++ mindspore/core/ops/string_to_enum.h | 39 +++ .../conv_bias_fusion_inout_test.cc | 1 - .../mindspore/ops/_vmap/vmap_grad_nn_ops.py | 6 +- .../python/mindspore/ops/_vmap/vmap_nn_ops.py | 8 +-
.../mindspore/ops/auto_generate/__init__.py | 3 +- .../operations/manually_defined/ops_def.py | 2 +- .../mindspore/ops_generate/arg_dtype_cast.py | 113 ++++---- .../mindspore/ops_generate/arg_handler.py | 15 +- .../python/mindspore/ops_generate/enum.yaml | 241 ----------------- .../python/mindspore/ops_generate/gen_ops.py | 98 ++----- .../ops_generate/gen_ops_inner_prim.py | 26 ++ .../mindspore/ops_generate/gen_utils.py | 46 ++-- tests/ut/cpp/ops/test_ops_fft_with_size.cc | 4 +- tests/ut/cpp/ops/test_ops_nllloss.cc | 256 ++++++++++++++++-- 78 files changed, 913 insertions(+), 573 deletions(-) create mode 100644 mindspore/ccsrc/pybind_api/utils/op_enum_py.cc create mode 100644 mindspore/core/ops/op_enum.cc create mode 100644 mindspore/core/ops/op_enum.h create mode 100644 mindspore/core/ops/string_to_enum.cc create mode 100644 mindspore/core/ops/string_to_enum.h delete mode 100644 mindspore/python/mindspore/ops_generate/enum.yaml diff --git a/.gitignore b/.gitignore index 34b7e2fee5a..13c35944be4 100644 --- a/.gitignore +++ b/.gitignore @@ -119,7 +119,6 @@ mindspore/core/ops/auto_generate/gen_lite_ops.h mindspore/core/ops/auto_generate/gen_lite_ops.cc mindspore/core/ops/auto_generate/gen_ops_name.h mindspore/core/ops/auto_generate/gen_ops_primitive.h -mindspore/core/ops/auto_generate/gen_enum_def.h mindspore/core/ops/auto_generate/gen_ops_def.cc mindspore/core/ops/auto_generate/gen_ops_def.h mindspore/python/mindspore/ops_generate/ops.yaml @@ -129,7 +128,6 @@ mindspore/python/mindspore/ops_generate/inner_ops_doc.yaml mindspore/python/mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py mindspore/python/mindspore/ops/auto_generate/gen_arg_handler.py mindspore/python/mindspore/ops/auto_generate/gen_arg_dtype_cast.py -mindspore/python/mindspore/ops/auto_generate/gen_enum_def.py mindspore/python/mindspore/ops/auto_generate/gen_inner_ops_def.py mindspore/python/mindspore/ops/auto_generate/gen_ops_def.py mindspore/python/mindspore/ops/auto_generate/gen_pyboost_func.py diff --git a/mindspore/ccsrc/backend/common/graph_kernel/convert_input_and_attr.cc b/mindspore/ccsrc/backend/common/graph_kernel/convert_input_and_attr.cc index f39907d9601..eb871c40cbb 100644 --- a/mindspore/ccsrc/backend/common/graph_kernel/convert_input_and_attr.cc +++ b/mindspore/ccsrc/backend/common/graph_kernel/convert_input_and_attr.cc @@ -107,8 +107,8 @@ using ArgHandlerFunc = std::function; ArgHandlerFunc GetArgHandlerFunc(const std::string &arg_handler) { static const std::unordered_map arg_handler_funcs = { - {"py_format_to_enum", EnumToFormat}, - {"dtype_to_enum", EnumToDtype}, + {"str_to_enum", EnumToFormat}, + {"dtype_to_type_id", EnumToDtype}, }; if (arg_handler_funcs.find(arg_handler) != arg_handler_funcs.end()) { return arg_handler_funcs.at(arg_handler); @@ -119,8 +119,8 @@ ArgHandlerFunc GetArgHandlerFunc(const std::string &arg_handler) { ArgHandlerFunc GetOppArgHandlerFunc(const std::string &arg_handler) { static const std::unordered_map opp_arg_handler_funcs = { - {"py_format_to_enum", FormatToEnum}, - {"dtype_to_enum", DtypeToEnum}, + {"str_to_enum", FormatToEnum}, + {"dtype_to_type_id", DtypeToEnum}, }; if (opp_arg_handler_funcs.find(arg_handler) != opp_arg_handler_funcs.end()) { return opp_arg_handler_funcs.at(arg_handler); diff --git a/mindspore/ccsrc/include/common/pybind_api/api_register.h b/mindspore/ccsrc/include/common/pybind_api/api_register.h index e80178d31e3..a7982975af3 100644 --- a/mindspore/ccsrc/include/common/pybind_api/api_register.h +++ 
b/mindspore/ccsrc/include/common/pybind_api/api_register.h @@ -70,6 +70,10 @@ namespace abstract { void RegPrimitiveFrontEval(); } #endif + +namespace ops { +void RegOpEnum(py::module *m); +} // namespace ops } // namespace mindspore #endif // MINDSPORE_CCSRC_INCLUDE_COMMON_PYBIND_API_API_REGISTER_H_ diff --git a/mindspore/ccsrc/pipeline/jit/ps/init.cc b/mindspore/ccsrc/pipeline/jit/ps/init.cc index b70c4b79aac..6f044c1bd1c 100644 --- a/mindspore/ccsrc/pipeline/jit/ps/init.cc +++ b/mindspore/ccsrc/pipeline/jit/ps/init.cc @@ -156,6 +156,7 @@ void RegModule(py::module *m) { mindspore::abstract::RegPrimitiveFrontEval(); #endif mindspore::expander::RegPackExpanderPy(m); + mindspore::ops::RegOpEnum(m); } void RegModuleHelper(py::module *m) { diff --git a/mindspore/ccsrc/pipeline/jit/ps/static_analysis/prim.cc b/mindspore/ccsrc/pipeline/jit/ps/static_analysis/prim.cc index a9aad3b67dc..81c1924b51f 100644 --- a/mindspore/ccsrc/pipeline/jit/ps/static_analysis/prim.cc +++ b/mindspore/ccsrc/pipeline/jit/ps/static_analysis/prim.cc @@ -137,6 +137,11 @@ AnfNodePtr GetNodeAfterArgHandler(const AnfNodePtr &node, const ops::OpInputArg return node; } const auto arg_handler_func = prim::GetPythonOps(op_arg.arg_handler_, parse::PYTHON_MOD_PRIMITIVE_ARG_HANDLER_MODULE); + if (arg_handler_func->isa()) { + auto arg_handler_fg = dyn_cast(arg_handler_func); + MS_EXCEPTION_IF_NULL(arg_handler_fg); + return fg->NewCNodeInOrder({NewValueNode(arg_handler_fg), node}); + } auto arg_handler_fg = dyn_cast(arg_handler_func); MS_EXCEPTION_IF_NULL(arg_handler_fg); arg_handler_fg->set_manager(fg->manager()); diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/grid_sampler_2d_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/grid_sampler_2d_cpu_kernel.cc index 960b2fcb6c6..5a7c835808d 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/grid_sampler_2d_cpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/grid_sampler_2d_cpu_kernel.cc @@ -16,7 +16,7 @@ #include "plugin/device/cpu/kernel/grid_sampler_2d_cpu_kernel.h" #include "plugin/device/cpu/hal/device/cpu_device_address.h" -#include "mindspore/core/ops/auto_generate/gen_enum_def.h" +#include "mindspore/core/ops/op_enum.h" namespace { const size_t kDataSizeThreshold = 64 * 1024; @@ -119,7 +119,7 @@ void GridSampler2DCpuKernelMod::ComputeTask(const T *x_addr, const T *grid_addr, auto x_ptr_NC = out_iter[kZero] * x_stride_[kZero]; auto output_ptr_NCDHW = out_iter[kZero] * output_stride_[kZero] + out_iter[kTwo] * output_stride_[kTwo] + out_iter[kThree] * output_stride_[kThree]; - if (interpolation_mode_ == static_cast(MsPyEnum::InterpolationMode::BILINEAR)) { + if (interpolation_mode_ == static_cast(ops::InterpolationMode::BILINEAR)) { int64_t x_tnw = static_cast(std::floor(x)); int64_t y_tnw = static_cast(std::floor(y)); int64_t x_tne = x_tnw + 1; @@ -153,7 +153,7 @@ void GridSampler2DCpuKernelMod::ComputeTask(const T *x_addr, const T *grid_addr, output_addr[output_ptr_NCDHW] += x_addr[x_index] * tse; } } - } else if (interpolation_mode_ == static_cast(MsPyEnum::InterpolationMode::NEAREST)) { + } else if (interpolation_mode_ == static_cast(ops::InterpolationMode::NEAREST)) { int64_t x_nearest = static_cast(std::round(x)); int64_t y_nearest = static_cast(std::round(y)); for (size_t c = 0; c < out_c; c++, x_ptr_NC += x_stride_[kOne], output_ptr_NCDHW += output_stride_[kOne]) { @@ -221,9 +221,9 @@ void GridSampler2DCpuKernelMod::Call2Half(const float16 *x_data_addr, float16 *y auto x_ptr_NC = y_iter[0] * x_stride[0]; auto y_ptr_NCHW = y_iter[0] * y_stride[0] + 
y_iter[2] * y_stride[2] + y_iter[3] * y_stride[3]; - if (interpolation_mode == static_cast(MsPyEnum::InterpolationMode::BILINEAR)) { + if (interpolation_mode == static_cast(ops::InterpolationMode::BILINEAR)) { BilinearHalf(x, y, x_data_addr, y_data_addr, y_c, x_dims, y_stride, x_stride, x_ptr_NC, y_ptr_NCHW); - } else if (interpolation_mode == static_cast(MsPyEnum::InterpolationMode::NEAREST)) { + } else if (interpolation_mode == static_cast(ops::InterpolationMode::NEAREST)) { NearestHalf(x, y, x_data_addr, y_data_addr, y_c, x_dims, y_stride, x_stride, x_ptr_NC, y_ptr_NCHW); } } @@ -333,9 +333,9 @@ T GridSampler2DCpuKernelMod::GridSamplerComputeSourceIndex(T coord, int64_t size } else { coord = ((coord + 1.f) * size - 1) / num2; } - if (padding_mode == static_cast(MsPyEnum::GridSamplerPaddingMode::BORDER)) { + if (padding_mode == static_cast(ops::GridSamplerPaddingMode::BORDER)) { coord = std::min(static_cast(size - 1), std::max(coord, static_cast(0))); - } else if (padding_mode == static_cast(MsPyEnum::GridSamplerPaddingMode::REFLECTION)) { + } else if (padding_mode == static_cast(ops::GridSamplerPaddingMode::REFLECTION)) { if (align_corners) { coord = ReflectCoordinates(coord, 0, num2 * (size - 1)); } else { diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/grid_sampler_2d_grad_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/grid_sampler_2d_grad_cpu_kernel.cc index da7850fcf94..1108f47b142 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/grid_sampler_2d_grad_cpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/grid_sampler_2d_grad_cpu_kernel.cc @@ -15,7 +15,7 @@ */ #include "plugin/device/cpu/kernel/grid_sampler_2d_grad_cpu_kernel.h" #include "plugin/device/cpu/hal/device/cpu_device_address.h" -#include "mindspore/core/ops/auto_generate/gen_enum_def.h" +#include "mindspore/core/ops/op_enum.h" namespace { const size_t kDataSizeThreshold = 64 * 1024; @@ -93,14 +93,14 @@ void GridSampler2DGradCpuKernelMod::ComputeTask(const std::vector(inputs[kTwo]->device_ptr()); auto dx_data_addr = static_cast(outputs[kZero]->device_ptr()); auto dgrid_data_addr = static_cast(outputs[kOne]->device_ptr()); - if (interpolation_mode_ == static_cast(MsPyEnum::InterpolationMode::BILINEAR)) { + if (interpolation_mode_ == static_cast(ops::InterpolationMode::BILINEAR)) { interp = GridSamplerInterpolation::Bilinear; - } else if (interpolation_mode_ == static_cast(MsPyEnum::InterpolationMode::NEAREST)) { + } else if (interpolation_mode_ == static_cast(ops::InterpolationMode::NEAREST)) { interp = GridSamplerInterpolation::Nearest; } - if (padding_mode_ == static_cast(MsPyEnum::GridSamplerPaddingMode::ZEROS)) { + if (padding_mode_ == static_cast(ops::GridSamplerPaddingMode::ZEROS)) { padding = GridSamplerPadding::Zeros; - } else if (padding_mode_ == static_cast(MsPyEnum::GridSamplerPaddingMode::BORDER)) { + } else if (padding_mode_ == static_cast(ops::GridSamplerPaddingMode::BORDER)) { padding = GridSamplerPadding::Border; } else { padding = GridSamplerPadding::Reflection; diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/grid_sampler_3d_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/grid_sampler_3d_cpu_kernel.cc index 7d6ff408ad8..fe290ad1ff0 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/grid_sampler_3d_cpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/grid_sampler_3d_cpu_kernel.cc @@ -15,7 +15,7 @@ */ #include "plugin/device/cpu/kernel/grid_sampler_3d_cpu_kernel.h" #include "plugin/device/cpu/hal/device/cpu_device_address.h" -#include 
"mindspore/core/ops/auto_generate/gen_enum_def.h" +#include "mindspore/core/ops/op_enum.h" namespace { const size_t kDataSizeThreshold = 64 * 1024; @@ -117,7 +117,7 @@ void GridSampler3DCpuKernelMod::ComputeTask(T *x_addr, T *grid_addr, T *output_a auto x_ptr_NC = out_iter[kZero] * x_stride_[kZero]; auto output_ptr_NCDHW = out_iter[kZero] * output_stride_[kZero] + out_iter[kTwo] * output_stride_[kTwo] + out_iter[kThree] * output_stride_[kThree] + out_iter[kFour] * output_stride_[kFour]; - if (interpolation_mode_ == static_cast(MsPyEnum::InterpolationMode::BILINEAR)) { + if (interpolation_mode_ == static_cast(ops::InterpolationMode::BILINEAR)) { int64_t x_tnw = static_cast(std::floor(x)); int64_t y_tnw = static_cast(std::floor(y)); int64_t z_tnw = static_cast(std::floor(z)); @@ -176,7 +176,7 @@ void GridSampler3DCpuKernelMod::ComputeTask(T *x_addr, T *grid_addr, T *output_a output_addr[output_ptr_NCDHW] += x_addr[x_index] * bse; } } - } else if (interpolation_mode_ == static_cast(MsPyEnum::InterpolationMode::NEAREST)) { + } else if (interpolation_mode_ == static_cast(ops::InterpolationMode::NEAREST)) { int64_t x_nearest = static_cast(std::round(x)); int64_t y_nearest = static_cast(std::round(y)); int64_t z_nearest = static_cast(std::round(z)); @@ -220,9 +220,9 @@ T GridSampler3DCpuKernelMod::grid_sampler_compute_source_index(T coord, int64_t } else { coord = ((coord + 1.f) * size - kOne) / kTwo; } - if (padding_mode == static_cast(MsPyEnum::GridSamplerPaddingMode::BORDER)) { + if (padding_mode == static_cast(ops::GridSamplerPaddingMode::BORDER)) { coord = std::min(static_cast(static_cast(size) - kOne), std::max(coord, static_cast(kZero))); - } else if (padding_mode == static_cast(MsPyEnum::GridSamplerPaddingMode::REFLECTION)) { + } else if (padding_mode == static_cast(ops::GridSamplerPaddingMode::REFLECTION)) { if (align_corners) { coord = reflect_coordinates(coord, static_cast(kZero), kTwo * (size - kOne)); } else { diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/grid_sampler_3d_grad_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/grid_sampler_3d_grad_cpu_kernel.cc index 2b4e6348f47..cc529b128e3 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/grid_sampler_3d_grad_cpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/grid_sampler_3d_grad_cpu_kernel.cc @@ -15,7 +15,7 @@ */ #include "plugin/device/cpu/kernel/grid_sampler_3d_grad_cpu_kernel.h" #include "plugin/device/cpu/hal/device/cpu_device_address.h" -#include "mindspore/core/ops/auto_generate/gen_enum_def.h" +#include "mindspore/core/ops/op_enum.h" namespace { const size_t kDataSizeThreshold = 64 * 1024; @@ -214,7 +214,7 @@ void GridSampler3DGradCpuKernelMod::ComputeTask(T *grad_addr, T *x_addr, T *grid x = grid_sampler_compute_source_index_set_grad(x, x_shape_[kFour], padding_mode, align_corners_, &gx_mult); y = grid_sampler_compute_source_index_set_grad(y, x_shape_[kThree], padding_mode, align_corners_, &gy_mult); z = grid_sampler_compute_source_index_set_grad(z, x_shape_[kTwo], padding_mode, align_corners_, &gz_mult); - if (interpolation_mode == static_cast(MsPyEnum::InterpolationMode::BILINEAR)) { + if (interpolation_mode == static_cast(ops::InterpolationMode::BILINEAR)) { size_t grad_ptr_NCDHW = n * grad_stride_[kZero] + d * grad_stride_[kTwo] + h * grad_stride_[kThree] + w * grad_stride_[kFour]; size_t dx_ptr_NC = n * dx_stride_[kZero], x_ptr_NC = n * x_stride_[kZero]; @@ -223,7 +223,7 @@ void GridSampler3DGradCpuKernelMod::ComputeTask(T *grad_addr, T *x_addr, T *grid std::vector mult = {gx_mult, gy_mult, 
gz_mult}; std::vector ptr = {grad_ptr_NCDHW, x_ptr_NC, dx_ptr_NC, dgrid_ptr_NDHW}; BilinearKernel(addr, location, mult, ptr); - } else if (interpolation_mode == static_cast(MsPyEnum::InterpolationMode::NEAREST)) { + } else if (interpolation_mode == static_cast(ops::InterpolationMode::NEAREST)) { int64_t x_nearest = static_cast(std::round(x)); int64_t y_nearest = static_cast(std::round(y)); int64_t z_nearest = static_cast(std::round(z)); @@ -281,10 +281,10 @@ T GridSampler3DGradCpuKernelMod::grid_sampler_compute_source_index_set_grad(T co *grad_x = static_cast(size) / kTwo; coord = ((coord + kOne) * size - kOne) / kTwo; } - if (padding_mode == static_cast(MsPyEnum::GridSamplerPaddingMode::BORDER)) { + if (padding_mode == static_cast(ops::GridSamplerPaddingMode::BORDER)) { coord = clip_coordinates_set_grad(coord, size, &grad_clip); *grad_x = (*grad_x) * grad_clip; - } else if (padding_mode == static_cast(MsPyEnum::GridSamplerPaddingMode::REFLECTION)) { + } else if (padding_mode == static_cast(ops::GridSamplerPaddingMode::REFLECTION)) { if (align_corners) { coord = reflect_coordinates_set_grad(coord, 0, (size - 1) * static_cast(kTwo), &grad_refl); } else { diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nllloss_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/nllloss_cpu_kernel.cc index d69aabb0d1b..bf960f941c4 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/nllloss_cpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nllloss_cpu_kernel.cc @@ -48,7 +48,7 @@ int NLLLossCpuKernelMod::Resize(const std::vector &inputs, const return ret; } auto reduction = inputs[kReductionIdx]->GetValueWithCheck(); - reduction_type_ = static_cast(reduction); + reduction_type_ = static_cast(reduction); ignore_index_ = inputs[kIgnoreIndexIdx]->GetValueWithCheck(); auto logits_shape = inputs[kIndex0]->GetShapeVector(); @@ -92,15 +92,15 @@ bool NLLLossCpuKernelMod::LaunchKernel(const std::vector float n_loss = -logits[index] * n_weight; tmp_total_weight += n_weight; total_loss += n_loss; - if (reduction_type_ == MsPyEnum::Reduction::NONE) { + if (reduction_type_ == Reduction::NONE) { loss[i] = n_loss; } } *total_weight = tmp_total_weight; - if (reduction_type_ == MsPyEnum::Reduction::SUM) { + if (reduction_type_ == Reduction::REDUCTION_SUM) { *loss = total_loss; - } else if (reduction_type_ == MsPyEnum::Reduction::MEAN) { + } else if (reduction_type_ == Reduction::MEAN) { *loss = total_loss / tmp_total_weight; } return true; diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nllloss_cpu_kernel.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nllloss_cpu_kernel.h index 34c6702685a..b542625eac9 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/nllloss_cpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nllloss_cpu_kernel.h @@ -23,7 +23,7 @@ #include "plugin/device/cpu/kernel/cpu_kernel.h" #include "plugin/factory/ms_factory.h" #include "nnacl/kernel/nllloss.h" -#include "mindspore/core/ops/auto_generate/gen_enum_def.h" +#include "mindapi/base/types.h" namespace mindspore { namespace kernel { @@ -55,7 +55,7 @@ class NLLLossCpuKernelMod : public NativeCpuKernelMod { NLLLossFunc kernel_func_; static std::vector> func_list_; NLLLossStruct nllloss_param_{}; - MsPyEnum::Reduction reduction_type_; + Reduction reduction_type_; int64_t ignore_index_; }; } // namespace kernel diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nllloss_grad_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/nllloss_grad_cpu_kernel.cc index ebfabe16279..e2c11ff27a4 100644 --- 
a/mindspore/ccsrc/plugin/device/cpu/kernel/nllloss_grad_cpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nllloss_grad_cpu_kernel.cc @@ -49,7 +49,7 @@ int NLLLossGradCpuKernelMod::Resize(const std::vector &inputs, return ret; } auto reduction = inputs[kReductionIdx]->GetValueWithCheck(); - reduction_type_ = static_cast(reduction); + reduction_type_ = static_cast(reduction); ignore_index_ = inputs[kIgnoreIndexIdx]->GetValueWithCheck(); auto logits_shape = inputs[0]->GetShapeVector(); nllloss_param_.batch_ = LongToInt(logits_shape[0]); @@ -89,9 +89,9 @@ bool NLLLossGradCpuKernelMod::LaunchKernel(const std::vector ResizeLinear1DCpuKernelMod::CoordinateTransformationFunc ResizeLinear1DCpuKernelMod::ChooseCoordinateTransformationFunc( - MsPyEnum::CoordinateTransformationMode coordinate_transformation_mode) const { - const std::unordered_map> coordinate_map{ - {MsPyEnum::CoordinateTransformationMode::ALIGN_CORNERS, AlignCornersFunc()}, - {MsPyEnum::CoordinateTransformationMode::HALF_PIXEL, HalfPixelFunc()}}; + CoordinateTransformMode coordinate_transformation_mode) const { + const std::unordered_map> coordinate_map{ + {CoordinateTransformMode::ALIGN_CORNERS, AlignCornersFunc()}, + {CoordinateTransformMode::HALF_PIXEL, HalfPixelFunc()}}; return coordinate_map.at(coordinate_transformation_mode); } @@ -162,9 +162,9 @@ int ResizeLinear1DCpuKernelMod::Resize(const std::vector &inputs out_width_ = LongToSize(output_shape[kIndex2]); coordinate_transformation_mode_ = - static_cast(inputs.at(kIndex2)->GetValueWithCheck()); - if (coordinate_transformation_mode_ != MsPyEnum::CoordinateTransformationMode::ALIGN_CORNERS && - coordinate_transformation_mode_ != MsPyEnum::CoordinateTransformationMode::HALF_PIXEL) { + static_cast(inputs.at(kIndex2)->GetValueWithCheck()); + if (coordinate_transformation_mode_ != CoordinateTransformMode::ALIGN_CORNERS && + coordinate_transformation_mode_ != CoordinateTransformMode::HALF_PIXEL) { MS_LOG_EXCEPTION << "For '" << kernel_name_ << "', coordinate_transformation_mode not support now."; } SetWorkSpaceSize(inputs); diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/resize_linear_1d_cpu_kernel.h b/mindspore/ccsrc/plugin/device/cpu/kernel/resize_linear_1d_cpu_kernel.h index 6da1759dacd..8e08a68d488 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/resize_linear_1d_cpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/resize_linear_1d_cpu_kernel.h @@ -24,7 +24,7 @@ #include #include "plugin/device/cpu/kernel/cpu_kernel.h" #include "plugin/factory/ms_factory.h" -#include "ops/auto_generate/gen_enum_def.h" +#include "mindapi/base/types.h" namespace mindspore::kernel { constexpr auto kUnknown = "Unknown"; @@ -60,7 +60,7 @@ class ResizeLinear1DCpuKernelMod : public NativeCpuKernelMod, public MatchKernel template CoordinateTransformationFunc ChooseCoordinateTransformationFunc( - MsPyEnum::CoordinateTransformationMode coordinate_transformation_mode) const; + CoordinateTransformMode coordinate_transformation_mode) const; template void ComputeInterpolationCaches(const size_t out_size, const size_t in_size, @@ -73,8 +73,7 @@ class ResizeLinear1DCpuKernelMod : public NativeCpuKernelMod, public MatchKernel size_t channel_{0}; size_t in_width_{0}; size_t out_width_{0}; - MsPyEnum::CoordinateTransformationMode coordinate_transformation_mode_ = - MsPyEnum::CoordinateTransformationMode::ALIGN_CORNERS; + CoordinateTransformMode coordinate_transformation_mode_ = CoordinateTransformMode::ALIGN_CORNERS; }; } // namespace mindspore::kernel diff --git 
a/mindspore/ccsrc/plugin/device/cpu/kernel/resize_linear_1d_grad_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/resize_linear_1d_grad_cpu_kernel.cc index 4ee61b47dc4..acff1096dcd 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/resize_linear_1d_grad_cpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/resize_linear_1d_grad_cpu_kernel.cc @@ -116,10 +116,10 @@ const std::vector ResizeLinear1DGradCpuKernelMod::CoordinateTransformationFunc ResizeLinear1DGradCpuKernelMod::ChooseCoordinateTransformationFunc( - MsPyEnum::CoordinateTransformationMode coordinate_transformation_mode) { - const std::unordered_map> coordinate_map{ - {MsPyEnum::CoordinateTransformationMode::ALIGN_CORNERS, AlignCornersFunc()}, - {MsPyEnum::CoordinateTransformationMode::HALF_PIXEL, HalfPixelFunc()}}; + CoordinateTransformMode coordinate_transformation_mode) { + const std::unordered_map> coordinate_map{ + {CoordinateTransformMode::ALIGN_CORNERS, AlignCornersFunc()}, + {CoordinateTransformMode::HALF_PIXEL, HalfPixelFunc()}}; return coordinate_map.at(coordinate_transformation_mode); } @@ -164,9 +164,9 @@ int ResizeLinear1DGradCpuKernelMod::Resize(const std::vector &in input_width_ = LongToSize(shape_[kIndex2]); coordinate_transformation_mode_ = - static_cast(inputs.at(kIndex2)->GetValueWithCheck()); - if (coordinate_transformation_mode_ != MsPyEnum::CoordinateTransformationMode::ALIGN_CORNERS && - coordinate_transformation_mode_ != MsPyEnum::CoordinateTransformationMode::HALF_PIXEL) { + static_cast(inputs.at(kIndex2)->GetValueWithCheck()); + if (coordinate_transformation_mode_ != CoordinateTransformMode::ALIGN_CORNERS && + coordinate_transformation_mode_ != CoordinateTransformMode::HALF_PIXEL) { MS_LOG_EXCEPTION << "For '" << kernel_name_ << "', coordinate_transformation_mode not support now."; } SetWorkSpaceSize(inputs); diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/resize_linear_1d_grad_cpu_kernel.h b/mindspore/ccsrc/plugin/device/cpu/kernel/resize_linear_1d_grad_cpu_kernel.h index 6ef212d40b2..55d9083c4dc 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/resize_linear_1d_grad_cpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/resize_linear_1d_grad_cpu_kernel.h @@ -24,7 +24,7 @@ #include #include "plugin/device/cpu/kernel/cpu_kernel.h" #include "plugin/factory/ms_factory.h" -#include "ops/auto_generate/gen_enum_def.h" +#include "mindapi/base/types.h" namespace mindspore::kernel { constexpr auto kUnknown = "Unknown"; @@ -66,7 +66,7 @@ class ResizeLinear1DGradCpuKernelMod : public NativeCpuKernelMod, template CoordinateTransformationFunc ChooseCoordinateTransformationFunc( - MsPyEnum::CoordinateTransformationMode coordinate_transformation_mode); + CoordinateTransformMode coordinate_transformation_mode); std::string kernel_type_{kUnknown}; bool align_corners_{false}; @@ -76,8 +76,7 @@ class ResizeLinear1DGradCpuKernelMod : public NativeCpuKernelMod, size_t channel_{0}; size_t input_width_{0}; size_t output_width_{0}; - MsPyEnum::CoordinateTransformationMode coordinate_transformation_mode_ = - MsPyEnum::CoordinateTransformationMode::ALIGN_CORNERS; + CoordinateTransformMode coordinate_transformation_mode_ = CoordinateTransformMode::ALIGN_CORNERS; }; } // namespace mindspore::kernel diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/loss_with_reduction_impl.cuh b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/loss_with_reduction_impl.cuh index 9ff27c562e5..3a699b7fc09 100644 --- 
a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/loss_with_reduction_impl.cuh +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/loss_with_reduction_impl.cuh @@ -19,12 +19,18 @@ #include #include #include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h" +#include "mindapi/base/types.h" enum class ReductionMode { kNone, kMean, kSum }; static std::map kReductionModeMap{ {"none", ReductionMode::kNone}, {"mean", ReductionMode::kMean}, {"sum", ReductionMode::kSum}}; +static std::map kEnumReductionModeMap{ + {static_cast(mindspore::Reduction::NONE), ReductionMode::kNone}, + {static_cast(mindspore::Reduction::MEAN), ReductionMode::kMean}, + {static_cast(mindspore::Reduction::REDUCTION_SUM), ReductionMode::kSum}}; + template CUDA_LIB_EXPORT cudaError_t BinaryCrossEntropyLoss(const int &input_size, const ReductionMode &reduction, const T *input_x, const T *input_y, const T *weight, T *loss, diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/math/fft_with_size_gpu_kernel.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/math/fft_with_size_gpu_kernel.cc index aea48922387..4aefc38bf45 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/math/fft_with_size_gpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/math/fft_with_size_gpu_kernel.cc @@ -55,21 +55,21 @@ enum class FFTNormMode { by_root_n, // same as above, but sqrt the product }; -FFTNormMode GetNormModeFromString(const MsPyEnum::NormMode &norm_type, const bool is_inverse) { - if (norm_type == MsPyEnum::NormMode::FORWARD) { +FFTNormMode GetNormModeFromString(const ops::NormMode &norm_type, const bool is_inverse) { + if (norm_type == ops::NormMode::FORWARD) { return is_inverse ? FFTNormMode::none : FFTNormMode::by_n; } - if (norm_type == MsPyEnum::NormMode::BACKWARD) { + if (norm_type == ops::NormMode::BACKWARD) { return is_inverse ? 
FFTNormMode::by_n : FFTNormMode::none; } - if (norm_type == MsPyEnum::NormMode::ORTHO) { + if (norm_type == ops::NormMode::ORTHO) { return FFTNormMode::by_root_n; } MS_LOG(ERROR) << "For 'FFTWithSize', the fft norm type " << norm_type << " is unsupported!"; return FFTNormMode::none; } -double GetNormScale(const MsPyEnum::NormMode &norm_type, const bool is_inverse, const int n) { +double GetNormScale(const ops::NormMode &norm_type, const bool is_inverse, const int n) { FFTNormMode norm_mode = GetNormModeFromString(norm_type, is_inverse); if (norm_mode == FFTNormMode::none) { return 1.0; @@ -115,7 +115,7 @@ bool FFTWithSizeGpuKernelMod::FFTVarietyInResize(const std::vectorGetValueWithCheck(); is_inverse_ = inputs[kIndex2]->GetValueWithCheck(); is_real_ = inputs[kIndex3]->GetValueWithCheck(); - norm_type_ = static_cast(inputs[kIndex4]->GetValueWithCheck()); + norm_type_ = static_cast(inputs[kIndex4]->GetValueWithCheck()); is_onesided_ = inputs[kIndex5]->GetValueWithCheck(); auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/math/fft_with_size_gpu_kernel.h b/mindspore/ccsrc/plugin/device/gpu/kernel/math/fft_with_size_gpu_kernel.h index ab161b6248c..c0cb4826e32 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/math/fft_with_size_gpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/math/fft_with_size_gpu_kernel.h @@ -26,7 +26,7 @@ #include #include #include -#include "ops/auto_generate/gen_enum_def.h" +#include "ops/op_enum.h" #include "plugin/device/gpu/kernel/gpu_kernel.h" #include "plugin/device/gpu/kernel/gpu_kernel_factory.h" #include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/fft_with_size_impl.cuh" @@ -89,7 +89,7 @@ class FFTWithSizeGpuKernelMod : public NativeGpuKernelMod { int64_t rank_{1}; bool is_inverse_{false}; // 0: forward, 1: inverse bool is_real_{false}; - MsPyEnum::NormMode norm_type_{MsPyEnum::NormMode::BACKWARD}; // forward, backward, ortho + ops::NormMode norm_type_{ops::NormMode::BACKWARD}; // forward, backward, ortho // is_onesided controls whether frequency is halved when signal is real, which means is_real_ is true. // The default value is true. cufft does not support full freq with real signal. We use cast as a temporary solution. 
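// For reference, the scaling rule selected by GetNormModeFromString/GetNormScale above
// (a summary of the code in this patch, with n the transformed signal length):
// BACKWARD leaves the forward FFT unscaled and scales the inverse FFT by 1/n;
// FORWARD does the opposite; ORTHO scales both directions by 1/sqrt(n).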
bool is_onesided_{true}; diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/nn/grid_sampler_gpu_kernel.h b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/grid_sampler_gpu_kernel.h index 8b8ccf9638a..be87748553e 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/nn/grid_sampler_gpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/grid_sampler_gpu_kernel.h @@ -26,7 +26,7 @@ #include #include "mindspore/core/ops/ops_func_impl/grid_sampler_2d.h" #include "mindspore/core/ops/ops_func_impl/grid_sampler_3d.h" -#include "mindspore/core/ops/auto_generate/gen_enum_def.h" +#include "mindspore/core/ops/op_enum.h" #include "plugin/device/gpu/kernel/gpu_kernel.h" #include "plugin/device/gpu/kernel/gpu_kernel_factory.h" #include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/grid_sampler_impl.cuh" diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/nn/grid_sampler_grad_gpu_kernel.h b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/grid_sampler_grad_gpu_kernel.h index fce3fac5245..84e05271546 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/nn/grid_sampler_grad_gpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/grid_sampler_grad_gpu_kernel.h @@ -24,7 +24,7 @@ #include #include "mindspore/core/ops/ops_func_impl/grid_sampler_2d_grad.h" #include "mindspore/core/ops/ops_func_impl/grid_sampler_3d_grad.h" -#include "mindspore/core/ops/auto_generate/gen_enum_def.h" +#include "mindspore/core/ops/op_enum.h" #include "plugin/device/gpu/kernel/gpu_kernel.h" #include "plugin/device/gpu/kernel/gpu_kernel_factory.h" #include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/grid_sampler_grad_impl.cuh" diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/nn/nll_loss_gpu_kernel.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/nll_loss_gpu_kernel.cc index efd3c48aaee..b8573a98dac 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/nn/nll_loss_gpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/nll_loss_gpu_kernel.cc @@ -42,7 +42,7 @@ int NLLLossGpuKernelMod::Resize(const std::vector &inputs, const return ret; } auto reduction = inputs[kReductionIdx]->GetValueWithCheck(); - reduction_ = static_cast(reduction); + reduction_ = kEnumReductionModeMap[reduction]; ignore_index_ = inputs[kIgnoreIndexIdx]->GetValueWithCheck(); auto logits_shape = inputs[kIndex0]->GetShapeVector(); label_size_ = logits_shape[0]; diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/nn/nll_loss_grad_gpu_kernel.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/nll_loss_grad_gpu_kernel.cc index b4783cd8f18..aaa1e4f7b07 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/nn/nll_loss_grad_gpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/nll_loss_grad_gpu_kernel.cc @@ -45,7 +45,7 @@ int NLLLossGradGpuKernelMod::Resize(const std::vector &inputs, return ret; } auto reduction = inputs[kReductionIdx]->GetValueWithCheck(); - reduction_ = static_cast(reduction); + reduction_ = kEnumReductionModeMap[reduction]; ignore_index_ = inputs[kIgnoreIndexIdx]->GetValueWithCheck(); auto logits_shape = inputs[kIndex0]->GetShapeVector(); diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/nn/resize_linear_1d_gpu_kernel.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/resize_linear_1d_gpu_kernel.cc index 5bd43d5a9d3..dfc175cd6d3 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/nn/resize_linear_1d_gpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/resize_linear_1d_gpu_kernel.cc @@ -16,7 +16,7 @@ #include "plugin/device/gpu/kernel/nn/resize_linear_1d_gpu_kernel.h" #include 
"mindspore/core/abstract/utils.h" -#include "ops/auto_generate/gen_enum_def.h" +#include "mindapi/base/types.h" namespace { constexpr const size_t kResizeLinear1DInputsNum = 3; @@ -58,10 +58,9 @@ int ResizeLinear1DGpuKernelMod::Resize(const std::vector &inputs out_width_ = output_shape_[kIndex2]; auto coordinate_transformation_mode = inputs.at(kIndex2)->GetValueWithCheck(); - if (coordinate_transformation_mode == static_cast(MsPyEnum::CoordinateTransformationMode::ALIGN_CORNERS)) { + if (coordinate_transformation_mode == static_cast(CoordinateTransformMode::ALIGN_CORNERS)) { mode_ = ResizeLinearCoordinateTransformationMode::ALIGN_CORNERS; - } else if (coordinate_transformation_mode == - static_cast(MsPyEnum::CoordinateTransformationMode::HALF_PIXEL)) { + } else if (coordinate_transformation_mode == static_cast(CoordinateTransformMode::HALF_PIXEL)) { mode_ = ResizeLinearCoordinateTransformationMode::HALF_PIXEL; } else { MS_LOG_EXCEPTION << "For '" << kernel_name_ << "', coordinate_transformation_mode not support now."; diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/nn/resize_linear_1d_grad_gpu_kernel.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/resize_linear_1d_grad_gpu_kernel.cc index 957ae55fd53..e4e09f38b37 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/nn/resize_linear_1d_grad_gpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/resize_linear_1d_grad_gpu_kernel.cc @@ -17,7 +17,7 @@ #include "plugin/device/gpu/kernel/nn/resize_linear_1d_grad_gpu_kernel.h" #include "mindspore/core/abstract/utils.h" #include "ops/ops_func_impl/resize_linear_1d_grad.h" -#include "ops/auto_generate/gen_enum_def.h" +#include "mindapi/base/types.h" namespace { constexpr const size_t kResizeLinear1DGradInputsNum = 3; @@ -64,10 +64,9 @@ int ResizeLinear1DGradGpuKernelMod::Resize(const std::vector &in workspace_size_list_.push_back(work_space_size * sizeof(float)); auto coordinate_transformation_mode = inputs.at(kIndex2)->GetValueWithCheck(); - if (coordinate_transformation_mode == static_cast(MsPyEnum::CoordinateTransformationMode::ALIGN_CORNERS)) { + if (coordinate_transformation_mode == static_cast(CoordinateTransformMode::ALIGN_CORNERS)) { mode_ = ResizeLinearCoordinateTransformationMode::ALIGN_CORNERS; - } else if (coordinate_transformation_mode == - static_cast(MsPyEnum::CoordinateTransformationMode::HALF_PIXEL)) { + } else if (coordinate_transformation_mode == static_cast(CoordinateTransformMode::HALF_PIXEL)) { mode_ = ResizeLinearCoordinateTransformationMode::HALF_PIXEL; } else { MS_LOG_EXCEPTION << "For '" << kernel_name_ << "', coordinate_transformation_mode not support now."; diff --git a/mindspore/ccsrc/plugin/device/gpu/optimizer/bce_with_logits_loss_fusion.cc b/mindspore/ccsrc/plugin/device/gpu/optimizer/bce_with_logits_loss_fusion.cc index 53b6631ce6b..f9b1ef7af21 100644 --- a/mindspore/ccsrc/plugin/device/gpu/optimizer/bce_with_logits_loss_fusion.cc +++ b/mindspore/ccsrc/plugin/device/gpu/optimizer/bce_with_logits_loss_fusion.cc @@ -20,7 +20,6 @@ #include #include "mindspore/core/ops/nn_ops.h" #include "mindspore/core/ops/math_ops.h" -#include "mindspore/core/ops/auto_generate/gen_enum_def.h" #include "mindspore/core/ops/op_utils.h" #include "include/backend/anf_runtime_algorithm.h" #include "include/common/utils/anfalgo.h" diff --git a/mindspore/ccsrc/pybind_api/utils/op_enum_py.cc b/mindspore/ccsrc/pybind_api/utils/op_enum_py.cc new file mode 100644 index 00000000000..33fbc994cfc --- /dev/null +++ b/mindspore/ccsrc/pybind_api/utils/op_enum_py.cc @@ -0,0 +1,73 @@ +/** 
+ * Copyright 2020-2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "mindspore/core/ops/op_enum.h" +#include "mindspore/core/ops/op_def.h" +#include "mindspore/core/mindapi/base/format.h" +#include "include/common/pybind_api/api_register.h" + +namespace mindspore::ops { +void RegOpEnum(py::module *m) { + auto m_sub = m->def_submodule("op_enum", "submodule for op enum"); + (void)m_sub.def("str_to_enum", &StringToEnumImpl, "string to enum value"); + (void)py::enum_<OP_DTYPE>(*m, "OpDtype", py::arithmetic()) + .value("DT_BEGIN", OP_DTYPE::DT_BEGIN) + .value("DT_BOOL", OP_DTYPE::DT_BOOL) + .value("DT_INT", OP_DTYPE::DT_INT) + .value("DT_FLOAT", OP_DTYPE::DT_FLOAT) + .value("DT_NUMBER", OP_DTYPE::DT_NUMBER) + .value("DT_TENSOR", OP_DTYPE::DT_TENSOR) + .value("DT_STR", OP_DTYPE::DT_STR) + .value("DT_ANY", OP_DTYPE::DT_ANY) + .value("DT_TUPLE_BOOL", OP_DTYPE::DT_TUPLE_BOOL) + .value("DT_TUPLE_INT", OP_DTYPE::DT_TUPLE_INT) + .value("DT_TUPLE_FLOAT", OP_DTYPE::DT_TUPLE_FLOAT) + .value("DT_TUPLE_NUMBER", OP_DTYPE::DT_TUPLE_NUMBER) + .value("DT_TUPLE_TENSOR", OP_DTYPE::DT_TUPLE_TENSOR) + .value("DT_TUPLE_STR", OP_DTYPE::DT_TUPLE_STR) + .value("DT_TUPLE_ANY", OP_DTYPE::DT_TUPLE_ANY) + .value("DT_LIST_BOOL", OP_DTYPE::DT_LIST_BOOL) + .value("DT_LIST_INT", OP_DTYPE::DT_LIST_INT) + .value("DT_LIST_FLOAT", OP_DTYPE::DT_LIST_FLOAT) + .value("DT_LIST_NUMBER", OP_DTYPE::DT_LIST_NUMBER) + .value("DT_LIST_TENSOR", OP_DTYPE::DT_LIST_TENSOR) + .value("DT_LIST_STR", OP_DTYPE::DT_LIST_STR) + .value("DT_LIST_ANY", OP_DTYPE::DT_LIST_ANY) + .value("DT_TYPE", OP_DTYPE::DT_TYPE) + .value("DT_END", OP_DTYPE::DT_END);
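+ // Illustrative usage (call sites hypothetical, names from this patch): C++ can resolve a
+ // name directly, e.g. int64_t v = StringToEnumImpl("bilinear"); // InterpolationMode::BILINEAR
+ // while Python reaches the same table through the binding registered above, e.g.
+ // mindspore._c_expression.op_enum.str_to_enum("bilinear").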
+ // The Format bindings below are still incomplete; the missing values will be added later. + (void)py::enum_<Format>(*m, "FormatEnum", py::arithmetic()) + .value("DEFAULT_FORMAT", Format::DEFAULT_FORMAT) + .value("NCHW", Format::NCHW) + .value("NHWC", Format::NHWC) + .value("NHWC4", Format::NHWC4) + .value("HWKC", Format::HWKC) + .value("HWCK", Format::HWCK) + .value("KCHW", Format::KCHW) + .value("CKHW", Format::CKHW) + .value("KHWC", Format::KHWC) + .value("CHWK", Format::CHWK) + .value("HW", Format::HW) + .value("HW4", Format::HW4) + .value("NC", Format::NC) + .value("NC4", Format::NC4) + .value("NC4HW4", Format::NC4HW4) + .value("NCDHW", Format::NCDHW) + .value("NWC", Format::NWC) + .value("NCW", Format::NCW) + .value("NDHWC", Format::NDHWC) + .value("NC8HW8", Format::NC8HW8); +} +} // namespace mindspore::ops diff --git a/mindspore/ccsrc/transform/graph_ir/op_adapter_base.h b/mindspore/ccsrc/transform/graph_ir/op_adapter_base.h index fd6cf66eca5..1806f9ee5bf 100644 --- a/mindspore/ccsrc/transform/graph_ir/op_adapter_base.h +++ b/mindspore/ccsrc/transform/graph_ir/op_adapter_base.h @@ -238,7 +238,7 @@ class GEPadMod { class GEReduction { public: static std::string ConvertEnumToString(int64_t id) { - static const std::vector<std::string> reductions = {"none", "mean", "sum", "add"}; + static const std::vector<std::string> reductions = {"sum", "mean", "none"}; if (id < 0 || id >= static_cast<int64_t>(reductions.size())) { MS_LOG(EXCEPTION) << "Invalid reduction " << id; return ""; diff --git a/mindspore/core/mindapi/base/format.h b/mindspore/core/mindapi/base/format.h index 07a0f7d14a0..7666c23d3a2 100644 --- a/mindspore/core/mindapi/base/format.h +++ b/mindspore/core/mindapi/base/format.h @@ -86,7 +86,7 @@ enum Format : int64_t { NUM_OF_FORMAT }; -inline std::string FormatEnumToString(mindspore::Format format) { +inline const std::vector<std::string> &GetFormatNames() { static std::vector<std::string> names = { "NCHW", "NHWC", @@ -148,6 +148,11 @@ inline std::string FormatEnumToString(mindspore::Format format) { "NYUV_A", "NCL", }; + return names; +} + +inline std::string FormatEnumToString(mindspore::Format format) { + const auto &names = GetFormatNames(); if (format == mindspore::Format::DEFAULT_FORMAT) { return "DefaultFormat"; } diff --git a/mindspore/core/mindapi/base/type_id.h b/mindspore/core/mindapi/base/type_id.h index b3e960b5edd..460ec7ca888 100644 --- a/mindspore/core/mindapi/base/type_id.h +++ b/mindspore/core/mindapi/base/type_id.h @@ -16,8 +16,6 @@ #ifndef MINDSPORE_CORE_MINDAPI_BASE_TYPE_ID_H_ #define MINDSPORE_CORE_MINDAPI_BASE_TYPE_ID_H_ -// TODO(chenfei): remove the TypeId by include the auto generated type id. -// #include "mindspore/core/ops/auto_generate/gen_enum_def.h" namespace mindspore { /// \brief TypeId defines data type identifiers. diff --git a/mindspore/core/ops/op_enum.cc b/mindspore/core/ops/op_enum.cc new file mode 100644 index 00000000000..9054a13d625 --- /dev/null +++ b/mindspore/core/ops/op_enum.cc @@ -0,0 +1,164 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +#include "ops/op_enum.h" + +#include +#include + +#include "mindapi/base/types.h" +#include "utils/check_convert_utils.h" +#include "mindapi/base/format.h" + +namespace mindspore { +namespace ops { + +namespace { +using StrToEnumMap = std::unordered_map; + +class RegStringToEnumHelper { + public: + template + std::string AddValues(T &&string_to_enum) { + for (const auto &kv : string_to_enum) { + if (string_to_enum_.find(kv.first) != string_to_enum_.end()) { + MS_LOG_EXCEPTION << kv.first << " has been registered!"; + } + } + string_to_enum_.merge(std::move(string_to_enum)); + return ""; + } + + const StrToEnumMap &GetValues() { return string_to_enum_; } + + private: + StrToEnumMap string_to_enum_; +}; +RegStringToEnumHelper reg_string_to_enum_helper; + +#define REG_STRING_TO_ENUM(enum_type, ...) \ + const auto op_enum_##enum_type = reg_string_to_enum_helper.AddValues(__VA_ARGS__); + +// Convert to uppercase uniformly +inline std::string StrToUpper(const std::string &str) { + auto res = str; + for (auto &c : res) { + c = std::toupper(c); + } + return res; +} + +// Format +inline std::unordered_map GetStringToFormatMap() { + const auto &names = GetFormatNames(); + std::unordered_map map{{"DEFAULT_FORMAT", static_cast(Format::DEFAULT_FORMAT)}}; + for (size_t i = 0; i < names.size(); ++i) { + map[StrToUpper(names[i])] = static_cast(i); + } + return map; +} +REG_STRING_TO_ENUM(format, GetStringToFormatMap()) + +// PadMode +StrToEnumMap StrToPadModeMap = {{"PAD", PadMode::PAD}, {"SAME", PadMode::SAME}, {"VALID", PadMode::VALID}}; +REG_STRING_TO_ENUM(pad_mode, StrToPadModeMap) + +// Reduction +StrToEnumMap StrToReductionMap = { + {"SUM", Reduction::REDUCTION_SUM}, {"MEAN", Reduction::MEAN}, {"NONE", Reduction::NONE}}; +REG_STRING_TO_ENUM(reduction, StrToReductionMap) + +// Activation +StrToEnumMap StrToActivationMap = {{"NO_ACTIVATION", ActivationType::NO_ACTIVATION}, + {"RELU", ActivationType::RELU}, + {"SIGMOID", ActivationType::SIGMOID}, + {"RELU6", ActivationType::RELU6}, + {"ELU", ActivationType::ELU}, + {"LEAKY_RELU", ActivationType::LEAKY_RELU}, + {"ABS", ActivationType::ABS}, + {"RELU1", ActivationType::RELU1}, + {"SOFTSIGN", ActivationType::SOFTSIGN}, + {"SOFTPLUS", ActivationType::SOFTPLUS}, + {"TANH", ActivationType::TANH}, + {"SELU", ActivationType::SELU}, + {"HSWISH", ActivationType::HSWISH}, + {"HSIGMOID", ActivationType::HSIGMOID}, + {"THRESHOLDRELU", ActivationType::THRESHOLDRELU}, + {"LINEAR", ActivationType::LINEAR}, + {"HARD_TANH", ActivationType::HARD_TANH}, + {"SIGN", ActivationType::SIGN}, + {"SWISH", ActivationType::SWISH}, + {"GELU", ActivationType::GELU}, + {"GLU", ActivationType::GLU}, + {"UNKNOWN", ActivationType::UNKNOWN}}; +REG_STRING_TO_ENUM(activation, StrToActivationMap) + +// GateOrder +REG_STRING_TO_ENUM(gate_order, StrToEnumMap{{"RZH", GateOrderMode::RZH}, {"ZRH", GateOrderMode::ZRH}}) + +// CoordinateTransformationMode +StrToEnumMap StrToCoordinateTransformationModeMap = {{"ASYMMETRIC", CoordinateTransformMode::ASYMMETRIC}, + {"ALIGN_CORNERS", CoordinateTransformMode::ALIGN_CORNERS}, + {"HALF_PIXEL", CoordinateTransformMode::HALF_PIXEL}, + {"CROP_AND_RESIZE", CoordinateTransformMode::CROP_AND_RESIZE}}; +REG_STRING_TO_ENUM(coordinate_transformation_mode, StrToCoordinateTransformationModeMap) + +// PaddingMode +StrToEnumMap StrToPaddingModeMap = {{"CONSTANT", PaddingMode::CONSTANT}, + {"REFLECT", PaddingMode::REFLECT}, + {"SYMMETRIC", PaddingMode::SYMMETRIC}, + {"MODE_RESERVED", PaddingMode::MODE_RESERVED}}; +REG_STRING_TO_ENUM(padding_mode, 
StrToPaddingModeMap) + +// Direction +REG_STRING_TO_ENUM(direction, StrToEnumMap{{"UNIDIRECTIONAL", Direction::UNIDIRECTIONAL}}) + +// CellType +REG_STRING_TO_ENUM(cell_type, StrToEnumMap{{"LSTM", CellType::LSTM}}) + +// Group +REG_STRING_TO_ENUM(group, StrToEnumMap{{"SYNC_BN_GROUP0", Group::SYNC_BN_GROUP0}}) + +// InterpolationMode +REG_STRING_TO_ENUM(interpolation_mode, + StrToEnumMap{{"BILINEAR", InterpolationMode::BILINEAR}, {"NEAREST", InterpolationMode::NEAREST}}) + +// NormMode +StrToEnumMap StrToNormModeMap = { + {"BACKWARD", NormMode::BACKWARD}, {"FORWARD", NormMode::FORWARD}, {"ORTHO", NormMode::ORTHO}}; +REG_STRING_TO_ENUM(norm_mode, StrToNormModeMap) + +// GridSamplerPaddingMode +StrToEnumMap StrToGridSamplerPaddingMode = {{"ZEROS", GridSamplerPaddingMode::ZEROS}, + {"BORDER", GridSamplerPaddingMode::BORDER}, + {"REFLECTION", GridSamplerPaddingMode::REFLECTION}}; +REG_STRING_TO_ENUM(grid_sampler_padding_mode, StrToGridSamplerPaddingMode) + +// KVCacheAlignMode +REG_STRING_TO_ENUM(k_v_cache_align_mode, + StrToEnumMap{{"LEFT", KVCacheAlignMode::LEFT}, {"RIGHT", KVCacheAlignMode::RIGHT}}) + +} // namespace + +int64_t StringToEnumImpl(const std::string &enum_string) { + const auto &string_to_enum_map = reg_string_to_enum_helper.GetValues(); + const auto enum_val_iter = string_to_enum_map.find(StrToUpper(enum_string)); + if (enum_val_iter == string_to_enum_map.end()) { + MS_LOG_EXCEPTION << "Cannot find '" << enum_string << "', please add it"; + } + return enum_val_iter->second; +} +} // namespace ops +} // namespace mindspore diff --git a/mindspore/core/ops/op_enum.h b/mindspore/core/ops/op_enum.h new file mode 100644 index 00000000000..ee9cb162ce6 --- /dev/null +++ b/mindspore/core/ops/op_enum.h @@ -0,0 +1,49 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_OPS_OP_ENUM_H_ +#define MINDSPORE_CORE_OPS_OP_ENUM_H_ +#include +#include +#include +#include + +#include "mindapi/base/macros.h" + +namespace mindspore { +namespace ops { +MIND_API int64_t StringToEnumImpl(const std::string &enum_string); +
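+// Consumption sketch (illustrative; mirrors the kernel changes elsewhere in this patch):
+// the int64 produced by StringToEnumImpl flows through the op's input and is cast back
+// to the typed enum at the kernel, e.g.
+//   auto interp = static_cast<InterpolationMode>(value);  // value: int64_t from str_to_enum
+//   if (interp == InterpolationMode::BILINEAR) { /* bilinear path */ }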
+// Add new enumerations here only when mindspore/core/mindapi/base/types.h and the other +// existing headers do not already define them. +// Plain `enum` is used instead of `enum class` because the backend currently represents +// these values as plain `int`, and `enum` converts to and compares with int more conveniently. +enum Direction : int64_t { UNIDIRECTIONAL = 0 }; + +enum CellType : int64_t { LSTM = 0 }; + +enum Group : int64_t { SYNC_BN_GROUP0 = 0 }; + +enum InterpolationMode : int64_t { BILINEAR = 0, NEAREST = 1 }; + +enum NormMode : int64_t { BACKWARD = 0, FORWARD = 1, ORTHO = 2 }; + +enum GridSamplerPaddingMode : int64_t { ZEROS = 0, BORDER = 1, REFLECTION = 2 }; + +enum KVCacheAlignMode : int64_t { RIGHT = 0, LEFT = 1 }; +} // namespace ops +} // namespace mindspore +#endif // MINDSPORE_CORE_OPS_OP_ENUM_H_ diff --git a/mindspore/core/ops/ops_def/argmax_op.yaml b/mindspore/core/ops/ops_def/argmax_op.yaml index ddf8bd4e05b..02f71e53bf6 100644 --- a/mindspore/core/ops/ops_def/argmax_op.yaml +++ b/mindspore/core/ops/ops_def/argmax_op.yaml @@ -8,10 +8,10 @@ argmax: default: -1 prim_init: True output_type: - dtype: int + dtype: TypeId default: mstype.int32 prim_init: True - arg_handler: dtype_to_enum + arg_handler: dtype_to_type_id returns: output: dtype: tensor diff --git a/mindspore/core/ops/ops_def/argmin_op.yaml b/mindspore/core/ops/ops_def/argmin_op.yaml index 020058f9831..a28bce009df 100644 --- a/mindspore/core/ops/ops_def/argmin_op.yaml +++ b/mindspore/core/ops/ops_def/argmin_op.yaml @@ -8,10 +8,10 @@ argmin: default: -1 prim_init: True output_type: - dtype: int + dtype: TypeId default: mstype.int32 prim_init: True - arg_handler: dtype_to_enum + arg_handler: dtype_to_type_id returns: output: dtype: tensor diff --git a/mindspore/core/ops/ops_def/avg_pool_grad_op.yaml b/mindspore/core/ops/ops_def/avg_pool_grad_op.yaml index 2a8ce486de3..cff9d351ee1 100644 --- a/mindspore/core/ops/ops_def/avg_pool_grad_op.yaml +++ b/mindspore/core/ops/ops_def/avg_pool_grad_op.yaml @@ -21,12 +21,12 @@ avg_pool_grad: dtype: int default: "'VALID'" prim_init: True - arg_handler: pad_mode_to_enum + arg_handler: str_to_enum data_format: dtype: int default: "'NCHW'" prim_init: True - arg_handler: py_format_to_enum + arg_handler: str_to_enum returns: output: dtype: tensor diff --git a/mindspore/core/ops/ops_def/avg_pool_op.yaml b/mindspore/core/ops/ops_def/avg_pool_op.yaml index 199e0589149..e5af17978aa 100644 --- a/mindspore/core/ops/ops_def/avg_pool_op.yaml +++ b/mindspore/core/ops/ops_def/avg_pool_op.yaml @@ -17,12 +17,12 @@ avg_pool: dtype: int default: "'VALID'" prim_init: True - arg_handler: pad_mode_to_enum + arg_handler: str_to_enum data_format: dtype: int default: "'NCHW'" prim_init: True - arg_handler: py_format_to_enum + arg_handler: str_to_enum returns: output: dtype: tensor diff --git a/mindspore/core/ops/ops_def/batch_norm_grad_grad_op.yaml b/mindspore/core/ops/ops_def/batch_norm_grad_grad_op.yaml index 97f9c22b39d..8734f50feae 100644 --- a/mindspore/core/ops/ops_def/batch_norm_grad_grad_op.yaml +++ b/mindspore/core/ops/ops_def/batch_norm_grad_grad_op.yaml @@ -28,7 +28,7 @@ batch_norm_grad_grad: dtype: int default: "'NCHW'" prim_init: True - arg_handler: py_format_to_enum + arg_handler: str_to_enum returns: dx: dtype: tensor diff --git a/mindspore/core/ops/ops_def/batch_norm_grad_op.yaml b/mindspore/core/ops/ops_def/batch_norm_grad_op.yaml index 6cb926cc0aa..6fe163ead62 100644 --- a/mindspore/core/ops/ops_def/batch_norm_grad_op.yaml +++ b/mindspore/core/ops/ops_def/batch_norm_grad_op.yaml @@ -24,7 +24,7 @@ batch_norm_grad: dtype: int default: "'NCHW'" prim_init: True - arg_handler: py_format_to_enum + arg_handler: str_to_enum returns: dx: dtype: tensor diff --git a/mindspore/core/ops/ops_def/batch_norm_grad_with_activation_op.yaml b/mindspore/core/ops/ops_def/batch_norm_grad_with_activation_op.yaml index 
fb44408df8f..27ec380fc67 100644 --- a/mindspore/core/ops/ops_def/batch_norm_grad_with_activation_op.yaml +++ b/mindspore/core/ops/ops_def/batch_norm_grad_with_activation_op.yaml @@ -28,7 +28,7 @@ batch_norm_grad_with_activation: dtype: int default: "'NCHW'" prim_init: True - arg_handler: py_format_to_enum + arg_handler: str_to_enum returns: dx: dtype: tensor diff --git a/mindspore/core/ops/ops_def/batch_norm_grad_with_add_and_activation_op.yaml b/mindspore/core/ops/ops_def/batch_norm_grad_with_add_and_activation_op.yaml index fb44408df8f..27ec380fc67 100644 --- a/mindspore/core/ops/ops_def/batch_norm_grad_with_add_and_activation_op.yaml +++ b/mindspore/core/ops/ops_def/batch_norm_grad_with_add_and_activation_op.yaml @@ -28,7 +28,7 @@ batch_norm_grad_with_activation: dtype: int default: "'NCHW'" prim_init: True - arg_handler: py_format_to_enum + arg_handler: str_to_enum returns: dx: dtype: tensor diff --git a/mindspore/core/ops/ops_def/batch_norm_op.yaml b/mindspore/core/ops/ops_def/batch_norm_op.yaml index 4e6d65db3d0..283c8c874dd 100644 --- a/mindspore/core/ops/ops_def/batch_norm_op.yaml +++ b/mindspore/core/ops/ops_def/batch_norm_op.yaml @@ -26,7 +26,7 @@ batch_norm: dtype: int default: "'NCHW'" prim_init: True - arg_handler: py_format_to_enum + arg_handler: str_to_enum returns: output_x: dtype: tensor diff --git a/mindspore/core/ops/ops_def/batch_norm_with_activation_op.yaml b/mindspore/core/ops/ops_def/batch_norm_with_activation_op.yaml index c9bbf29225e..b2bdbf4594d 100644 --- a/mindspore/core/ops/ops_def/batch_norm_with_activation_op.yaml +++ b/mindspore/core/ops/ops_def/batch_norm_with_activation_op.yaml @@ -26,7 +26,7 @@ batch_norm_with_activation: dtype: int default: "'NCHW'" prim_init: True - arg_handler: py_format_to_enum + arg_handler: str_to_enum returns: output_x: dtype: tensor diff --git a/mindspore/core/ops/ops_def/batch_norm_with_add_and_activation_op.yaml b/mindspore/core/ops/ops_def/batch_norm_with_add_and_activation_op.yaml index cb2a250a699..83aaa75c064 100644 --- a/mindspore/core/ops/ops_def/batch_norm_with_add_and_activation_op.yaml +++ b/mindspore/core/ops/ops_def/batch_norm_with_add_and_activation_op.yaml @@ -28,7 +28,7 @@ batch_norm_with_add_and_activation: dtype: int default: "'NCHW'" prim_init: True - arg_handler: py_format_to_enum + arg_handler: str_to_enum returns: output_x: dtype: tensor diff --git a/mindspore/core/ops/ops_def/bias_add_grad_op.yaml b/mindspore/core/ops/ops_def/bias_add_grad_op.yaml index 2fa5a36e71f..413480b2eb2 100644 --- a/mindspore/core/ops/ops_def/bias_add_grad_op.yaml +++ b/mindspore/core/ops/ops_def/bias_add_grad_op.yaml @@ -7,7 +7,7 @@ bias_add_grad: dtype: int default: "'NCHW'" prim_init: True - arg_handler: py_format_to_enum + arg_handler: str_to_enum returns: output: dtype: tensor diff --git a/mindspore/core/ops/ops_def/bias_add_op.yaml b/mindspore/core/ops/ops_def/bias_add_op.yaml index 9eedc2dd0d0..769f425108a 100644 --- a/mindspore/core/ops/ops_def/bias_add_op.yaml +++ b/mindspore/core/ops/ops_def/bias_add_op.yaml @@ -9,7 +9,7 @@ bias_add: dtype: int default: "'NCHW'" prim_init: True - arg_handler: py_format_to_enum + arg_handler: str_to_enum returns: output: dtype: tensor diff --git a/mindspore/core/ops/ops_def/eye_op.yaml b/mindspore/core/ops/ops_def/eye_op.yaml index cac388e8f0c..e4af983b5c0 100644 --- a/mindspore/core/ops/ops_def/eye_op.yaml +++ b/mindspore/core/ops/ops_def/eye_op.yaml @@ -6,8 +6,8 @@ eye: m: dtype: int dtype: - dtype: int - arg_handler: dtype_to_enum + dtype: TypeId + arg_handler: dtype_to_type_id returns: 
output: dtype: tensor diff --git a/mindspore/core/ops/ops_def/fft_with_size_op.yaml b/mindspore/core/ops/ops_def/fft_with_size_op.yaml index 32e4e48d73e..30d636da979 100644 --- a/mindspore/core/ops/ops_def/fft_with_size_op.yaml +++ b/mindspore/core/ops/ops_def/fft_with_size_op.yaml @@ -14,7 +14,7 @@ fft_with_size: prim_init: True norm: dtype: int - arg_handler: norm_mode_to_enum + arg_handler: str_to_enum default: "'backward'" prim_init: True onesided: diff --git a/mindspore/core/ops/ops_def/grid_sampler_2d_grad_op.yaml b/mindspore/core/ops/ops_def/grid_sampler_2d_grad_op.yaml index 1646d0cf0be..e1132fd6aec 100644 --- a/mindspore/core/ops/ops_def/grid_sampler_2d_grad_op.yaml +++ b/mindspore/core/ops/ops_def/grid_sampler_2d_grad_op.yaml @@ -11,12 +11,12 @@ grid_sampler_2d_grad: dtype: int default: "'bilinear'" prim_init: True - arg_handler: interpolation_mode_to_enum + arg_handler: str_to_enum padding_mode: dtype: int default: "'zeros'" prim_init: True - arg_handler: grid_sampler_padding_mode_to_enum + arg_handler: str_to_enum align_corners: dtype: bool default: False diff --git a/mindspore/core/ops/ops_def/grid_sampler_2d_op.yaml b/mindspore/core/ops/ops_def/grid_sampler_2d_op.yaml index 864c5b682a5..38bb4900abd 100644 --- a/mindspore/core/ops/ops_def/grid_sampler_2d_op.yaml +++ b/mindspore/core/ops/ops_def/grid_sampler_2d_op.yaml @@ -9,12 +9,12 @@ grid_sampler_2d: dtype: int default: "'bilinear'" prim_init: True - arg_handler: interpolation_mode_to_enum + arg_handler: str_to_enum padding_mode: dtype: int default: "'zeros'" prim_init: True - arg_handler: grid_sampler_padding_mode_to_enum + arg_handler: str_to_enum align_corners: dtype: bool default: False diff --git a/mindspore/core/ops/ops_def/grid_sampler_3d_grad_op.yaml b/mindspore/core/ops/ops_def/grid_sampler_3d_grad_op.yaml index fd43ee705c2..926a060d162 100644 --- a/mindspore/core/ops/ops_def/grid_sampler_3d_grad_op.yaml +++ b/mindspore/core/ops/ops_def/grid_sampler_3d_grad_op.yaml @@ -11,12 +11,12 @@ grid_sampler_3d_grad: dtype: int default: "'bilinear'" prim_init: True - arg_handler: interpolation_mode_to_enum + arg_handler: str_to_enum padding_mode: dtype: int default: "'zeros'" prim_init: True - arg_handler: grid_sampler_padding_mode_to_enum + arg_handler: str_to_enum align_corners: dtype: bool default: False diff --git a/mindspore/core/ops/ops_def/grid_sampler_3d_op.yaml b/mindspore/core/ops/ops_def/grid_sampler_3d_op.yaml index 2096b39d6d3..5e532ab7cf3 100644 --- a/mindspore/core/ops/ops_def/grid_sampler_3d_op.yaml +++ b/mindspore/core/ops/ops_def/grid_sampler_3d_op.yaml @@ -9,18 +9,18 @@ grid_sampler_3d: dtype: int default: "'bilinear'" prim_init: True - arg_handler: interpolation_mode_to_enum + arg_handler: str_to_enum padding_mode: dtype: int default: "'zeros'" prim_init: True - arg_handler: grid_sampler_padding_mode_to_enum + arg_handler: str_to_enum align_corners: dtype: bool default: False prim_init: True returns: - output: + output: dtype: tensor class: name: GridSampler3D diff --git a/mindspore/core/ops/ops_def/inner/extract_image_patches_op.yaml b/mindspore/core/ops/ops_def/inner/extract_image_patches_op.yaml index bd8b47b34f1..44e3d6b3c12 100644 --- a/mindspore/core/ops/ops_def/inner/extract_image_patches_op.yaml +++ b/mindspore/core/ops/ops_def/inner/extract_image_patches_op.yaml @@ -22,7 +22,7 @@ extract_image_patches: dtype: int default: "'VALID'" prim_init: True - arg_handler: pad_mode_to_enum + arg_handler: str_to_enum returns: output: dtype: tensor diff --git 
a/mindspore/core/ops/ops_def/inner/prompt_k_v_cache_op.yaml b/mindspore/core/ops/ops_def/inner/prompt_k_v_cache_op.yaml index e15c2b2d761..d0787639a98 100644 --- a/mindspore/core/ops/ops_def/inner/prompt_k_v_cache_op.yaml +++ b/mindspore/core/ops/ops_def/inner/prompt_k_v_cache_op.yaml @@ -19,7 +19,7 @@ prompt_k_v_cache: dtype: int default: 0 prim_init: True - arg_handler: k_v_cache_align_mode_to_enum + arg_handler: str_to_enum labels: side_effect_mem: True returns: diff --git a/mindspore/core/ops/ops_def/nllloss_grad_op.yaml b/mindspore/core/ops/ops_def/nllloss_grad_op.yaml index 6d740daa042..0de2dee2f77 100644 --- a/mindspore/core/ops/ops_def/nllloss_grad_op.yaml +++ b/mindspore/core/ops/ops_def/nllloss_grad_op.yaml @@ -15,7 +15,7 @@ nllloss_grad: dtype: int default: "'mean'" prim_init: True - arg_handler: reduction_to_enum + arg_handler: str_to_enum ignore_index: dtype: int default: -100 diff --git a/mindspore/core/ops/ops_def/nllloss_op.yaml b/mindspore/core/ops/ops_def/nllloss_op.yaml index 8710b6702c5..a3b97e6375c 100644 --- a/mindspore/core/ops/ops_def/nllloss_op.yaml +++ b/mindspore/core/ops/ops_def/nllloss_op.yaml @@ -11,7 +11,7 @@ nllloss: dtype: int default: "'mean'" prim_init: True - arg_handler: reduction_to_enum + arg_handler: str_to_enum ignore_index: dtype: int default: -100 diff --git a/mindspore/core/ops/ops_def/randperm_v2_op.yaml b/mindspore/core/ops/ops_def/randperm_v2_op.yaml index 633e4ed8f56..8c5fdf0d8a7 100644 --- a/mindspore/core/ops/ops_def/randperm_v2_op.yaml +++ b/mindspore/core/ops/ops_def/randperm_v2_op.yaml @@ -13,10 +13,10 @@ randperm_v2: default: 0 prim_init: True dtype: - dtype: int + dtype: TypeId default: mstype.int64 prim_init: True - arg_handler: dtype_to_enum + arg_handler: dtype_to_type_id returns: output: dtype: tensor diff --git a/mindspore/core/ops/ops_def/resize_linear_1d_grad_op.yaml b/mindspore/core/ops/ops_def/resize_linear_1d_grad_op.yaml index 947287a07eb..b8dd0dbce94 100644 --- a/mindspore/core/ops/ops_def/resize_linear_1d_grad_op.yaml +++ b/mindspore/core/ops/ops_def/resize_linear_1d_grad_op.yaml @@ -9,7 +9,7 @@ resize_linear_1d_grad: dtype: int default: "'align_corners'" prim_init: True - arg_handler: coordinate_transformation_mode_to_enum + arg_handler: str_to_enum returns: output: dtype: tensor diff --git a/mindspore/core/ops/ops_def/resize_linear_1d_op.yaml b/mindspore/core/ops/ops_def/resize_linear_1d_op.yaml index 45af2b7992a..a1b5e758cb0 100644 --- a/mindspore/core/ops/ops_def/resize_linear_1d_op.yaml +++ b/mindspore/core/ops/ops_def/resize_linear_1d_op.yaml @@ -10,7 +10,7 @@ resize_linear_1d: dtype: int default: "'align_corners'" prim_init: True - arg_handler: coordinate_transformation_mode_to_enum + arg_handler: str_to_enum returns: output: dtype: tensor diff --git a/mindspore/core/ops/ops_func_impl/eig.cc b/mindspore/core/ops/ops_func_impl/eig.cc index ba9eb06b2ef..67448ab4775 100644 --- a/mindspore/core/ops/ops_func_impl/eig.cc +++ b/mindspore/core/ops/ops_func_impl/eig.cc @@ -22,7 +22,8 @@ namespace mindspore { namespace ops { namespace { -constexpr size_t kDefaultRank = 2; +// As kDefaultRank causes a name conflict during the lite compile, use kDefaultShapeSize instead.
+constexpr size_t kDefaultShapeSize = 2; constexpr size_t kRowIndex = 2; constexpr size_t kColIndex = 1; } // namespace @@ -32,7 +33,7 @@ void EigCheckShapeValid(const ShapeVector &input_shape) { return; } - if (input_shape.size() < kDefaultRank) { + if (input_shape.size() < kDefaultShapeSize) { MS_EXCEPTION(ValueError) << "For Eig, x should be at lease 2" << ", but got a " << input_shape.size() << "-D Tensor."; } @@ -52,7 +53,7 @@ void EigCheckShapeValid(const ShapeVector &input_shape) { BaseShapePtr EigFuncImpl::InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const { auto input_shape = input_args[kInputIndex0]->GetShape()->GetShapeVector(); - std::vector<BaseShapePtr> shapes_list(kDefaultRank); + std::vector<BaseShapePtr> shapes_list(kDefaultShapeSize); EigCheckShapeValid(input_shape); /* infer eigen_value shape */ diff --git a/mindspore/core/ops/ops_func_impl/fft_with_size.cc b/mindspore/core/ops/ops_func_impl/fft_with_size.cc index fb5b743ca91..76b2d20e879 100644 --- a/mindspore/core/ops/ops_func_impl/fft_with_size.cc +++ b/mindspore/core/ops/ops_func_impl/fft_with_size.cc @@ -20,7 +20,6 @@ #include #include "ops/op_utils.h" #include "utils/check_convert_utils.h" -#include "ops/auto_generate/gen_enum_def.h" namespace mindspore { namespace ops { diff --git a/mindspore/core/ops/ops_func_impl/nllloss.cc b/mindspore/core/ops/ops_func_impl/nllloss.cc index 9dbe13b6634..d8d6e2c9d20 100644 --- a/mindspore/core/ops/ops_func_impl/nllloss.cc +++ b/mindspore/core/ops/ops_func_impl/nllloss.cc @@ -18,7 +18,7 @@ #include #include #include "abstract/dshape.h" -#include "ops/auto_generate/gen_enum_def.h" +#include "mindapi/base/types.h" #include "ops/op_name.h" #include "ops/op_utils.h" #include "utils/check_convert_utils.h" @@ -86,15 +86,15 @@ BaseShapePtr NLLLossFuncImpl::InferShape(const PrimitivePtr &primitive, std::vector<abstract::BaseShapePtr>{loss_shape_ptr, total_weight_shape_ptr}); } - auto reduce_value_enum = static_cast<MsPyEnum::Reduction>(reduction_opt.value()); - if ((reduce_value_enum == MsPyEnum::Reduction::SUM) || (reduce_value_enum == MsPyEnum::Reduction::MEAN)) { + auto reduce_value_enum = static_cast<Reduction>(reduction_opt.value()); + if ((reduce_value_enum == Reduction::REDUCTION_SUM) || (reduce_value_enum == Reduction::MEAN)) { // shape () means 0D tensor. auto loss_shape_ptr = std::make_shared<abstract::Shape>(loss_shape); return std::make_shared<abstract::TupleShape>( std::vector<abstract::BaseShapePtr>{loss_shape_ptr, total_weight_shape_ptr}); } - if (reduce_value_enum == MsPyEnum::Reduction::NONE) { + if (reduce_value_enum == Reduction::NONE) { if (logits_shape.size() == DIM_2) { loss_shape.push_back( std::max({logits_shape[kInputIndex0], labels_shape[kInputIndex0], abstract::Shape::kShapeDimAny})); diff --git a/mindspore/core/ops/other_ops.h b/mindspore/core/ops/other_ops.h index adb9b35e0fe..5f85f9552cd 100644 --- a/mindspore/core/ops/other_ops.h +++ b/mindspore/core/ops/other_ops.h @@ -103,6 +103,9 @@ GVAR_DEF(PrimitivePtr, kPrimNPUAntiQuant, std::make_shared<Primitive>("AscendAnt // Fusion Inference OP GVAR_DEF(PrimitivePtr, kPrimFFN, std::make_shared<Primitive>("FFN")); + +// ToEnum OP +GVAR_DEF(PrimitivePtr, kPrimStringToEnum, std::make_shared<Primitive>("StringToEnum")); } // namespace prim } // namespace mindspore diff --git a/mindspore/core/ops/string_to_enum.cc b/mindspore/core/ops/string_to_enum.cc new file mode 100644 index 00000000000..54c4e95928e --- /dev/null +++ b/mindspore/core/ops/string_to_enum.cc @@ -0,0 +1,65 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ops/string_to_enum.h" + +#include +#include +#include + +#include "abstract/ops/primitive_infer_map.h" +#include "ir/dtype/tensor_type.h" +#include "ir/primitive.h" +#include "utils/log_adapter.h" +#include "utils/check_convert_utils.h" +#include "ops/other_ops.h" +#include "mindapi/src/helper.h" +#include "ops/op_enum.h" + +namespace mindspore { +namespace ops { + +MIND_API_OPERATOR_IMPL(StringToEnum, BaseOperator); +class MIND_API StringToEnumInfer : public abstract::OpInferBase { + public: + BaseShapePtr InferShape(const PrimitivePtr &primitive, + const std::vector<AbstractBasePtr> &input_args) const override { + return std::make_shared<abstract::NoShape>(); + } + + TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const override { + return kInt64; + } + + ValuePtr InferValue(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const override { + MS_EXCEPTION_IF_NULL(primitive); + const auto &op_name = primitive->name(); + (void)CheckAndConvertUtils::CheckInteger("StringToEnum infer", int64_t(input_args.size()), kEqual, 1, op_name); + MS_EXCEPTION_IF_NULL(input_args[0]); + const auto &input_value = input_args[0]->GetValue(); + MS_EXCEPTION_IF_NULL(input_value); + if (!input_value->isa<StringImm>()) { + MS_LOG_EXCEPTION << "Currently, for " << op_name << ", input value must be a string value"; + } + const auto &enum_str = GetValue<std::string>(input_value); + const auto enum_int = StringToEnumImpl(enum_str); + return MakeValue(enum_int); + } +}; + +REGISTER_PRIMITIVE_OP_INFER_IMPL(StringToEnum, prim::kPrimStringToEnum, StringToEnumInfer, true); +} // namespace ops +} // namespace mindspore diff --git a/mindspore/core/ops/string_to_enum.h b/mindspore/core/ops/string_to_enum.h new file mode 100644 index 00000000000..4b0e34ffea9 --- /dev/null +++ b/mindspore/core/ops/string_to_enum.h @@ -0,0 +1,39 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_OPS_STRING_TO_ENUM_H_ +#define MINDSPORE_CORE_OPS_STRING_TO_ENUM_H_ +#include + +#include "mindapi/base/types.h" +#include "ops/base_operator.h" + +namespace mindspore { +namespace ops { +constexpr auto kNameStringToEnum = "StringToEnum"; + +/// \brief Returns the enum value of the input string. +class MIND_API StringToEnum : public BaseOperator { + public: + MIND_API_BASE_MEMBER(StringToEnum); + /// \brief Constructor. + StringToEnum() : BaseOperator(kNameStringToEnum) { InitIOName({"x"}, {"output"}); } + /// \brief Init. Refer to the parameters of Python API @ref mindspore.ops.StringToEnum for the inputs.
+ void Init() const {} +}; +} // namespace ops +} // namespace mindspore +#endif // MINDSPORE_CORE_OPS_STRING_TO_ENUM_H_ diff --git a/mindspore/lite/test/ut/tools/optimizer/fusion/fusion_inout_test/conv_bias_fusion_inout_test.cc b/mindspore/lite/test/ut/tools/optimizer/fusion/fusion_inout_test/conv_bias_fusion_inout_test.cc index 2685712b8da..141090bb50f 100644 --- a/mindspore/lite/test/ut/tools/optimizer/fusion/fusion_inout_test/conv_bias_fusion_inout_test.cc +++ b/mindspore/lite/test/ut/tools/optimizer/fusion/fusion_inout_test/conv_bias_fusion_inout_test.cc @@ -20,7 +20,6 @@ #include "test/ut/tools/optimizer/fusion/fusion_inout_test/conv_fusion_inout_test.h" #include "plugin/device/cpu/kernel/nnacl/op_base.h" #include "ops/auto_generate/gen_lite_ops.h" -#include "ops/auto_generate/gen_enum_def.h" namespace mindspore { class ConvBiasFusionInoutTest : public ConvFusionInoutTest { diff --git a/mindspore/python/mindspore/ops/_vmap/vmap_grad_nn_ops.py b/mindspore/python/mindspore/ops/_vmap/vmap_grad_nn_ops.py index e129b2790b1..824a5e57487 100644 --- a/mindspore/python/mindspore/ops/_vmap/vmap_grad_nn_ops.py +++ b/mindspore/python/mindspore/ops/_vmap/vmap_grad_nn_ops.py @@ -25,7 +25,7 @@ from mindspore.ops.primitive import _primexpr from mindspore.ops.function import _VmapGeneralRule from mindspore.ops._vmap.vmap_base import vmap_rules_getters, vmap_general_preprocess, _raise_value_error, \ _bdim_at_front, _vmap_clone_prim, _bdim_at_any, _handle_broadcasting -from mindspore.ops.auto_generate.gen_enum_def import PyFormat +from mindspore._c_expression import FormatEnum as Format @vmap_rules_getters.register(G.NLLLossGrad) @@ -306,7 +306,7 @@ def get_adaptive_avgpool2d_vmap_rule(prim, axis_size): @vmap_rules_getters.register(G.BatchNormGradGrad) def get_batchnorm_grad_grad_vmap_rule(prim, axis_size): """VmapRule for `BatchNormGradGrad` operation.""" - NCHW = PyFormat.NCHW.value + NCHW = Format.NCHW def vmap_rule(x_bdim, dy_bdim, scale_bdim, mean_bdim, variance_bdim, dout_dx_bdim, dout_dscale_bdim, dout_dbias_bdim, is_training_bdim, epsilon_bdim, data_format_bdim): @@ -384,7 +384,7 @@ def get_batchnorm_grad_vmap_rule(prim, axis_size): bn_min_dim = 3 bn_max_dim = 5 prim_name = prim.name - NHWC = PyFormat.NHWC.value + NHWC = Format.NHWC def vmap_rule(grad_bdim, x_bdim, scale_bdim, rsv_1_bdim, rsv_2_bdim, rsv_3_bdim, training_bdim, epsilon_bdim, format_bdim): diff --git a/mindspore/python/mindspore/ops/_vmap/vmap_nn_ops.py b/mindspore/python/mindspore/ops/_vmap/vmap_nn_ops.py index 5eb73bf9bda..c88ecaf7592 100644 --- a/mindspore/python/mindspore/ops/_vmap/vmap_nn_ops.py +++ b/mindspore/python/mindspore/ops/_vmap/vmap_nn_ops.py @@ -29,7 +29,7 @@ from mindspore.ops._vmap.vmap_base import vmap_rules_getters, vmap_general_prepr _bdim_at_any, _bdim_at_front, _bdim_at_back, _handle_broadcasting, get_unary_grad_vmap_rule, _raise_value_error, \ _vmap_clone_prim, _get_reduce_batch_axis from mindspore.ops.primitive import Primitive -from mindspore.ops.auto_generate.gen_enum_def import PyFormat +from mindspore._c_expression import FormatEnum as Format @vmap_rules_getters.register(P.ApplyAdaMax) @@ -372,7 +372,7 @@ def get_bias_add_vmap_rule(prim, axis_size): @constexpr def get_channal_pos_in_x(d_format, n_dims): - if d_format == PyFormat.NHWC: + if d_format == Format.NHWC: return n_dims return 2 @@ -424,7 +424,7 @@ def get_bias_add_grad_vmap_rule(prim, axis_size): """VmapRule for `BiasAddGrad` operation.""" @constexpr def get_channal_pos(d_format, x_rank): - if d_format == PyFormat.NHWC: + if d_format == 
Format.NHWC: return x_rank return 2 @@ -1249,7 +1249,7 @@ def get_batchnorm_vmap_rule(prim, axis_size): bn_min_dim = 3 bn_max_dim = 5 prim_name = "BatchNorm" - NCHW = PyFormat.NCHW.value + NCHW = Format.NCHW def vmap_rule(*inputs): is_all_none, result = vmap_general_preprocess(prim, *inputs) diff --git a/mindspore/python/mindspore/ops/auto_generate/__init__.py b/mindspore/python/mindspore/ops/auto_generate/__init__.py index 58b6f9ab612..59f180fb05e 100644 --- a/mindspore/python/mindspore/ops/auto_generate/__init__.py +++ b/mindspore/python/mindspore/ops/auto_generate/__init__.py @@ -19,12 +19,11 @@ Primitive operator classes and operator functional. A collection of operators to build neural networks or to compute functions. """ -from . import gen_ops_def, gen_arg_handler, gen_enum_def, gen_arg_dtype_cast +from . import gen_ops_def, gen_arg_handler, gen_arg_dtype_cast from .gen_ops_def import * from .gen_arg_handler import * from .gen_arg_dtype_cast import * -from .gen_enum_def import * from ..operations.manually_defined.ops_def import * from .gen_inner_ops_def import * diff --git a/mindspore/python/mindspore/ops/operations/manually_defined/ops_def.py b/mindspore/python/mindspore/ops/operations/manually_defined/ops_def.py index 2f65836a01f..b63e4646e64 100644 --- a/mindspore/python/mindspore/ops/operations/manually_defined/ops_def.py +++ b/mindspore/python/mindspore/ops/operations/manually_defined/ops_def.py @@ -572,7 +572,7 @@ class BatchNorm(Primitive): self.is_training = is_training self.epsilon = epsilon self.momentum = momentum - self.data_format = handler.py_format_to_enum(data_format) + self.data_format = handler.str_to_enum(data_format) def __call__(self, *args): return super().__call__(*args, self.is_training, self.epsilon, diff --git a/mindspore/python/mindspore/ops_generate/arg_dtype_cast.py b/mindspore/python/mindspore/ops_generate/arg_dtype_cast.py index 4e7f852ba5c..ad433711fc3 100755 --- a/mindspore/python/mindspore/ops_generate/arg_dtype_cast.py +++ b/mindspore/python/mindspore/ops_generate/arg_dtype_cast.py @@ -20,7 +20,7 @@ import mindspore as ms from mindspore import ops from mindspore.common.tensor import Tensor from mindspore.ops.operations._sequence_ops import TensorToScalar, TensorToTuple -from mindspore.ops.auto_generate.gen_enum_def import OpDtype +from mindspore._c_expression import OpDtype tensor_to_tuple_ = TensorToTuple() @@ -67,54 +67,53 @@ def tuple_to_tensor(data): def list_to_tensor(data): return ops.tuple_to_array(list_to_tuple(data)) +# Using OpDtype.xxx directly in GRAPH_MODE can cause problems, so convert it to int first.
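The comment above is the crux of this hunk: pybind11 enum members such as `OpDtype.DT_INT` are C++-backed objects that, per the comment, do not survive GRAPH_MODE compilation, so the module freezes them into plain ints once at import time and compares ints from then on. A minimal editorial sketch of the pattern, under the assumption that `OpDtype` behaves like an ordinary pybind11 enum (`is_scalar_id` is a hypothetical stand-in for the `is_number` helper below):

```python
# Sketch of the caching pattern adopted in arg_dtype_cast.py: freeze pybind11
# enum members into plain Python ints once, at import time.
from mindspore._c_expression import OpDtype

DT_INT_VAL = int(OpDtype.DT_INT)      # plain int, safe to bake into a graph
DT_FLOAT_VAL = int(OpDtype.DT_FLOAT)
DT_BOOL_VAL = int(OpDtype.DT_BOOL)

def is_scalar_id(type_id):
    # hypothetical helper mirroring is_number(): int-to-int comparison only,
    # so no enum object is ever captured by the compiled graph
    return type_id in (DT_INT_VAL, DT_FLOAT_VAL, DT_BOOL_VAL)
```

The same defensive idea reappears in `type_it` further down, which re-coerces its `src_type`/`dst_type` arguments with `int(...)` so that either a cached int or a raw `OpDtype` member is accepted.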
# type -PY_DT_TYPE = OpDtype.PY_DT_TYPE.value - +DT_TYPE_VAL = int(OpDtype.DT_TYPE) # scalar -PY_DT_INT = OpDtype.PY_DT_INT.value -PY_DT_FLOAT = OpDtype.PY_DT_FLOAT.value -PY_DT_BOOL = OpDtype.PY_DT_BOOL.value -PY_DT_NUMBER = OpDtype.PY_DT_NUMBER.value +DT_INT_VAL = int(OpDtype.DT_INT) +DT_FLOAT_VAL = int(OpDtype.DT_FLOAT) +DT_BOOL_VAL = int(OpDtype.DT_BOOL) +DT_NUMBER_VAL = int(OpDtype.DT_NUMBER) # tuple -PY_DT_TUPLE_BOOL = OpDtype.PY_DT_TUPLE_BOOL.value -PY_DT_TUPLE_INT = OpDtype.PY_DT_TUPLE_INT.value -PY_DT_TUPLE_FLOAT = OpDtype.PY_DT_TUPLE_FLOAT.value -PY_DT_TUPLE_NUMBER = OpDtype.PY_DT_TUPLE_NUMBER.value -PY_DT_TUPLE_TENSOR = OpDtype.PY_DT_TUPLE_TENSOR.value -PY_DT_TUPLE_STR = OpDtype.PY_DT_TUPLE_STR.value -PY_DT_TUPLE_ANY = OpDtype.PY_DT_TUPLE_ANY.value +DT_TUPLE_BOOL_VAL = int(OpDtype.DT_TUPLE_BOOL) +DT_TUPLE_INT_VAL = int(OpDtype.DT_TUPLE_INT) +DT_TUPLE_FLOAT_VAL = int(OpDtype.DT_TUPLE_FLOAT) +DT_TUPLE_NUMBER_VAL = int(OpDtype.DT_TUPLE_NUMBER) +DT_TUPLE_TENSOR_VAL = int(OpDtype.DT_TUPLE_TENSOR) +DT_TUPLE_STR_VAL = int(OpDtype.DT_TUPLE_STR) +DT_TUPLE_ANY_VAL = int(OpDtype.DT_TUPLE_ANY) # list -PY_DT_LIST_BOOL = OpDtype.PY_DT_LIST_BOOL.value -PY_DT_LIST_INT = OpDtype.PY_DT_LIST_INT.value -PY_DT_LIST_FLOAT = OpDtype.PY_DT_LIST_FLOAT.value -PY_DT_LIST_NUMBER = OpDtype.PY_DT_LIST_NUMBER.value -PY_DT_LIST_TENSOR = OpDtype.PY_DT_LIST_TENSOR.value -PY_DT_LIST_STR = OpDtype.PY_DT_LIST_STR.value -PY_DT_LIST_ANY = OpDtype.PY_DT_LIST_ANY.value +DT_LIST_BOOL_VAL = int(OpDtype.DT_LIST_BOOL) +DT_LIST_INT_VAL = int(OpDtype.DT_LIST_INT) +DT_LIST_FLOAT_VAL = int(OpDtype.DT_LIST_FLOAT) +DT_LIST_NUMBER_VAL = int(OpDtype.DT_LIST_NUMBER) +DT_LIST_TENSOR_VAL = int(OpDtype.DT_LIST_TENSOR) +DT_LIST_STR_VAL = int(OpDtype.DT_LIST_STR) +DT_LIST_ANY_VAL = int(OpDtype.DT_LIST_ANY) # tensor -PY_DT_TENSOR = OpDtype.PY_DT_TENSOR.value - +DT_TENSOR_VAL = int(OpDtype.DT_TENSOR) dtype_to_string = { - PY_DT_INT: "int", - PY_DT_FLOAT: "float", - PY_DT_BOOL: "bool", - PY_DT_NUMBER: "number", - PY_DT_TENSOR: "Tensor", - PY_DT_TUPLE_BOOL: "tuple of bool", - PY_DT_TUPLE_INT: "tuple of int", - PY_DT_TUPLE_FLOAT: "tuple of float", - PY_DT_TUPLE_NUMBER: "tuple of number", - PY_DT_TUPLE_TENSOR: "tuple of tensor", - PY_DT_TUPLE_STR: "tuple of string", - PY_DT_TUPLE_ANY: "tuple of Any", - PY_DT_LIST_BOOL: "list of bool", - PY_DT_LIST_INT: "list of int", - PY_DT_LIST_FLOAT: "list of float", - PY_DT_LIST_NUMBER: "list of number", - PY_DT_LIST_TENSOR: "list of Tensor", - PY_DT_LIST_STR: "list of string", - PY_DT_LIST_ANY: "list of Any" + DT_INT_VAL: "int", + DT_FLOAT_VAL: "float", + DT_BOOL_VAL: "bool", + DT_NUMBER_VAL: "number", + DT_TENSOR_VAL: "Tensor", + DT_TUPLE_BOOL_VAL: "tuple of bool", + DT_TUPLE_INT_VAL: "tuple of int", + DT_TUPLE_FLOAT_VAL: "tuple of float", + DT_TUPLE_NUMBER_VAL: "tuple of number", + DT_TUPLE_TENSOR_VAL: "tuple of tensor", + DT_TUPLE_STR_VAL: "tuple of string", + DT_TUPLE_ANY_VAL: "tuple of Any", + DT_LIST_BOOL_VAL: "list of bool", + DT_LIST_INT_VAL: "list of int", + DT_LIST_FLOAT_VAL: "list of float", + DT_LIST_NUMBER_VAL: "list of number", + DT_LIST_TENSOR_VAL: "list of Tensor", + DT_LIST_STR_VAL: "list of string", + DT_LIST_ANY_VAL: "list of Any" } @@ -122,34 +121,35 @@ def is_tuple(type_id): """ Check type id is tuple. 
""" - return type_id in (PY_DT_TUPLE_BOOL, PY_DT_TUPLE_INT, PY_DT_TUPLE_FLOAT, PY_DT_TUPLE_NUMBER, - PY_DT_TUPLE_TENSOR, PY_DT_TUPLE_STR, PY_DT_TUPLE_ANY) + return type_id in (DT_TUPLE_BOOL_VAL, DT_TUPLE_INT_VAL, DT_TUPLE_FLOAT_VAL, DT_TUPLE_NUMBER_VAL, + DT_TUPLE_TENSOR_VAL, DT_TUPLE_STR_VAL, DT_TUPLE_ANY_VAL) def is_list(type_id): """ Check type id is list. """ - return type_id in (PY_DT_LIST_BOOL, PY_DT_LIST_INT, PY_DT_LIST_FLOAT, PY_DT_LIST_NUMBER, PY_DT_LIST_TENSOR, - PY_DT_LIST_STR, PY_DT_LIST_ANY) + return type_id in (DT_LIST_BOOL_VAL, DT_LIST_INT_VAL, DT_LIST_FLOAT_VAL, DT_LIST_NUMBER_VAL, + DT_LIST_TENSOR_VAL, + DT_LIST_STR_VAL, DT_LIST_ANY_VAL) def is_number(type_id): """ Check type id is number. """ - return type_id in (PY_DT_INT, PY_DT_FLOAT, PY_DT_BOOL, PY_DT_NUMBER) + return type_id in (DT_INT_VAL, DT_FLOAT_VAL, DT_BOOL_VAL, DT_NUMBER_VAL) def is_instance_of(data, type_id): """ Instead isinstance(obj, type). """ - if type_id == PY_DT_INT: + if type_id == DT_INT_VAL: return isinstance(data, int) - if type_id == PY_DT_FLOAT: + if type_id == DT_FLOAT_VAL: return isinstance(data, float) - if type_id == PY_DT_BOOL: + if type_id == DT_BOOL_VAL: return isinstance(data, bool) if is_number(type_id): return isinstance(data, (int, float, bool)) @@ -157,7 +157,7 @@ def is_instance_of(data, type_id): return isinstance(data, tuple) if is_list(type_id): return isinstance(data, list) - if type_id == PY_DT_TENSOR: + if type_id == DT_TENSOR_VAL: return isinstance(data, Tensor) return False @@ -192,7 +192,7 @@ def do_type_cast(data, dst_type): """Type conversion.""" if is_instance_of(data, dst_type): return data - if dst_type == PY_DT_FLOAT: + if dst_type == DT_FLOAT_VAL: if isinstance(data, int): return int_to_float(data) elif is_tuple(dst_type): @@ -202,7 +202,7 @@ def do_type_cast(data, dst_type): return list_to_tuple(data) if isinstance(data, Tensor): return tensor_to_tuple(data) - elif dst_type == PY_DT_TENSOR: + elif dst_type == DT_TENSOR_VAL: if isinstance(data, (int, float, bool)): return scalar_to_tensor(data) if isinstance(data, tuple): @@ -211,7 +211,7 @@ def do_type_cast(data, dst_type): return list_to_tensor(data) elif is_number(dst_type): if isinstance(data, Tensor): - if dst_type == PY_DT_INT: + if dst_type == DT_INT_VAL: data = ops.cast(data, ms.int64) ret = TensorToScalar()(data) return ret @@ -222,6 +222,11 @@ def type_it(data, src_type, dst_type): """ cast operator argument data type. 
""" + if not isinstance(src_type, tuple): + src_type = int(src_type) + else: + src_type = tuple((int(t) for t in src_type)) + dst_type = int(dst_type) if not is_instance_in(data, src_type) and not is_instance_of(data, dst_type): support_list = get_support_dtype_list(src_type, dst_type) raise TypeError(f"For type conversion here, only support <{support_list}>, but got {type(data)}.") diff --git a/mindspore/python/mindspore/ops_generate/arg_handler.py b/mindspore/python/mindspore/ops_generate/arg_handler.py index 748813b95e1..b916d66f033 100644 --- a/mindspore/python/mindspore/ops_generate/arg_handler.py +++ b/mindspore/python/mindspore/ops_generate/arg_handler.py @@ -14,10 +14,7 @@ # ============================================================================ """Operator argument handle function.""" -from mindspore.ops_generate.gen_ops_inner_prim import DtypeToEnum - -ops_dtype_to_enum = DtypeToEnum() - +from mindspore.ops_generate.gen_ops_inner_prim import DtypeToEnum, StringToEnum def to_kernel_size(kernel_size): """ @@ -131,9 +128,9 @@ def to_3d_paddings(pad): return pad raise ValueError(f"For arg 'pad', the value is invalid: {pad}.") +dtype_to_type_id = DtypeToEnum() -def dtype_to_enum(dtype): - """ - convert mindspore.dtype to enum. - """ - return ops_dtype_to_enum(dtype) +# string to enum +# A function for converting str type to enum type are written here, +# but the backend supports str input, and converting str input to enum input is not necessary. +str_to_enum = StringToEnum() diff --git a/mindspore/python/mindspore/ops_generate/enum.yaml b/mindspore/python/mindspore/ops_generate/enum.yaml deleted file mode 100644 index b4f6beb8023..00000000000 --- a/mindspore/python/mindspore/ops_generate/enum.yaml +++ /dev/null @@ -1,241 +0,0 @@ - -# #enum TypeId -# type_id: -# kTypeUnknown: 0 -# #Meta types. 
-# kMetaTypeBegin: 0 -# kMetaTypeType: 1 -# kMetaTypeAny: 2 -# kMetaTypeObject: 3 -# kMetaTypeTypeType: 4 -# kMetaTypeProblem: 5 -# kMetaTypeExternal: 6 -# kMetaTypeNone: 7 -# kMetaTypeNull: 8 -# kMetaTypeEllipsis: 9 -# kMetaTypeEnd: 10 -# # -# # Object types -# # -# kObjectTypeBegin: 10 -# kObjectTypeNumber: 11 -# kObjectTypeString: 12 -# kObjectTypeList: 13 -# kObjectTypeTuple: 14 -# kObjectTypeSlice: 15 -# kObjectTypeKeyword: 16 -# kObjectTypeTensorType: 17 -# kObjectTypeRowTensorType: 18 -# kObjectTypeCOOTensorType: 19 -# kObjectTypeUndeterminedType: 20 -# kObjectTypeClass: 21 -# kObjectTypeDictionary: 22 -# kObjectTypeFunction: 23 -# kObjectTypeJTagged: 24 -# kObjectTypeSymbolicKeyType: 25 -# kObjectTypeEnvType: 26 -# kObjectTypeRefKey: 27 -# kObjectTypeRef: 28 -# kObjectTypeEnd: 29 -# # -# # Number Types -# # -# kNumberTypeBegin: 29 -# kNumberTypeBool: 30 -# kNumberTypeInt: 31 -# kNumberTypeInt8: 32 -# kNumberTypeInt16: 33 -# kNumberTypeInt32: 34 -# kNumberTypeInt64: 35 -# kNumberTypeUInt: 36 -# kNumberTypeUInt8: 37 -# kNumberTypeUInt16: 38 -# kNumberTypeUInt32: 39 -# kNumberTypeUInt64: 40 -# kNumberTypeFloat: 41 -# kNumberTypeFloat16: 42 -# kNumberTypeFloat32: 43 -# kNumberTypeFloat64: 44 -# kNumberTypeBFloat16: 45 -# kNumberTypeDouble: 46 -# kNumberTypeComplex: 47 -# kNumberTypeComplex64: 48 -# kNumberTypeComplex128: 49 -# kNumberTypeInt4: 50 -# kNumberTypeGLUInt: 51 -# kNumberTypeEnd: 52 -# # -# # Monad Types -# # -# kMonadTypeBegin: 52 -# kObjectTypeMonad: 53 -# kObjectTypeUMonad: 54 -# kObjectTypeIOMonad: 55 -# kMonadTypeEnd: 56 -# # -# # Sparse Types -# # -# kSparseTypeBegin: 56 -# kObjectTypeCSRTensorType: 57 -# kObjectTypeSparseTensorType: 58 -# kObjectTypeMapTensorType: 59 -# kSparseTypeEnd: 60 - -# #enum OpDtype -op_dtype: - PY_DT_BEGIN: 0 - PY_DT_BOOL: 1 - PY_DT_INT: 2 - PY_DT_FLOAT: 3 - PY_DT_NUMBER: 4 - PY_DT_TENSOR: 5 - PY_DT_STR: 6 - PY_DT_ANY: 7 - PY_DT_TUPLE_BOOL: 8 - PY_DT_TUPLE_INT: 9 - PY_DT_TUPLE_FLOAT: 10 - PY_DT_TUPLE_NUMBER: 11 - PY_DT_TUPLE_TENSOR: 12 - PY_DT_TUPLE_STR: 13 - PY_DT_TUPLE_ANY: 14 - PY_DT_LIST_BOOL: 15 - PY_DT_LIST_INT: 16 - PY_DT_LIST_FLOAT: 17 - PY_DT_LIST_NUMBER: 18 - PY_DT_LIST_TENSOR: 19 - PY_DT_LIST_STR: 20 - PY_DT_LIST_ANY: 21 - PY_DT_TYPE: 22 - PY_DT_END: 23 - -# enum Format -# the value must be consistent with enum defined in `mindspore/core/mindapi/base/format.h` -py_format: - NCHW: 0 - NHWC: 1 - NHWC4: 2 - HWKC: 3 - HWCK: 4 - KCHW: 5 - CKHW: 6 - KHWC: 7 - CHWK: 8 - HW: 9 - HW4: 10 - NC: 11 - NC4: 12 - NC4HW4: 13 - NCDHW: 14 - NWC: 15 - NCW: 16 - NDHWC: 17 - NC8HW8: 18 - FRACTAL_NZ: 19 - ND: 20 - NC1HWC0: 21 - FRACTAL_Z: 22 - NC1C0HWPAD: 23 - NHWC1C0: 24 - FSR_NCHW: 25 - FRACTAL_DECONV: 26 - C1HWNC0: 27 - FRACTAL_DECONV_TRANSPOSE: 28 - FRACTAL_DECONV_SP_STRIDE_TRANS: 29 - NC1HWC0_C04: 30 - FRACTAL_Z_C04: 31 - CHWN: 32 - FRACTAL_DECONV_SP_STRIDE8_TRANS: 33 - HWCN: 34 - NC1KHKWHWC0: 35 - BN_WEIGHT: 36 - FILTER_HWCK: 37 - LOOKUP_LOOKUPS: 38 - LOOKUP_KEYS: 39 - LOOKUP_VALUE: 40 - LOOKUP_OUTPUT: 41 - LOOKUP_HITS: 42 - C1HWNCoC0: 43 - MD: 44 - FRACTAL_ZZ: 45 - DHWCN: 46 - NDC1HWC0: 47 - FRACTAL_Z_3D: 48 - CN: 49 - DHWNC: 50 - FRACTAL_Z_3D_TRANSPOSE: 51 - FRACTAL_ZN_LSTM: 52 - FRACTAL_Z_G: 53 - ND_RNN_BIAS: 54 - FRACTAL_ZN_RNN: 55 - NYUV: 56 - NYUV_A: 57 - NCL: 58 - -# enum PadMode -pad_mode: - PAD: 0 - SAME: 1 - VALID: 2 - -# enum Reduction -reduction: - NONE: 0 - MEAN: 1 - SUM: 2 - ADD: 3 - -# enum Direction -direction: - UNIDIRECTIONAL: 0 - -# enum Activation -activation: - TANH: 0 - -# enum GateOrder -gate_order: - RZH: 0 - ZRH: 1 - -# 
enum CellType -cell_type: - LSTM: 0 - -# enum Group -group: - SYNC_BN_GROUP0: 0 - -# enum InterpolationMode -interpolation_mode: - BILINEAR: 0 - NEAREST: 1 - -# enum NormMode -norm_mode: - BACKWARD: 0 - FORWARD: 1 - ORTHO: 2 - -# enum GridSamplerPaddingMode -grid_sampler_padding_mode: - ZEROS: 0 - BORDER: 1 - REFLECTION: 2 - -# enum PadFillingMode -pad_filling_mode: - CONSTANT: 0 - REFLECT: 1 - EDGE: 2 - CIRCULAR: 3 - SYMMETRIC: 4 - -# enum CoordinateTransformationMode -coordinate_transformation_mode: - ALIGN_CORNERS: 0 - HALF_PIXEL: 1 - -# enum KVCacheAlignMode -k_v_cache_align_mode: - RIGHT: 0 - LEFT: 1 diff --git a/mindspore/python/mindspore/ops_generate/gen_ops.py b/mindspore/python/mindspore/ops_generate/gen_ops.py index d28b8002e88..8a57f2b96cb 100644 --- a/mindspore/python/mindspore/ops_generate/gen_ops.py +++ b/mindspore/python/mindspore/ops_generate/gen_ops.py @@ -20,7 +20,8 @@ import re import shutil import pathlib import gen_utils -from gen_utils import py_licence_str, cc_license_str, check_change_and_replace_file, merge_files, safe_load_yaml +from gen_utils import (py_licence_str, cc_license_str, check_change_and_replace_file, merge_files, + safe_load_yaml, convert_dtype_str) from pyboost_utils import get_pyboost_name, is_pyboost_enable, AclnnUtils, get_dtypes import template from template import CppTemplate @@ -353,6 +354,12 @@ def {func_name}({', '.join(arg for arg in func_args)}): return gen_py +def get_dtype(arg_info): + dtype = arg_info.get('dtype') + # Currently, TypeId is represented by int + if dtype == 'TypeId': + dtype = 'int' + return dtype def process_args(args): """ @@ -365,7 +372,7 @@ def process_args(args): init_args_with_default = [] args_handlers = {} for arg_name, arg_info in args.items(): - dtype = arg_info.get('dtype') + dtype = get_dtype(arg_info) default_value = arg_info.get('default') has_default = 'default' in arg_info.keys() is_optional = arg_info.get('default') == "None" if has_default else False @@ -621,7 +628,7 @@ namespace mindspore::ops { if not is_prim_init: continue - dtype = arg_info.get('dtype') + dtype = get_dtype(arg_info) if dtype == "str": dtype = "std::string" if dtype == "tuple[int]": @@ -671,8 +678,8 @@ std::unordered_map gOpDefTable = {{""" for i, (arg_name, arg_info) in enumerate(args.items()): args_dict[arg_name] = i cc_index_str += f"""{{"{arg_name}", {i}}},\n""" - dtype = arg_info.get('dtype') - cc_dtype_str = 'DT_' + dtype.replace('[', '_').replace(']', '').upper() + dtype = get_dtype(arg_info) + cc_dtype_str = convert_dtype_str(dtype) is_prim_init = 1 if arg_info.get('prim_init') else 0 arg_handler = arg_info.get('arg_handler') @@ -723,7 +730,7 @@ from mindspore.common._decorator import deprecated from mindspore.ops._primitive_cache import _get_cache_prim from mindspore.ops.auto_generate.gen_arg_dtype_cast import type_it from mindspore.ops.auto_generate.gen_arg_handler import * -from mindspore.ops.auto_generate.gen_enum_def import OpDtype +from mindspore._c_expression import OpDtype from mindspore.common._stub_tensor import _convert_stub """ @@ -891,67 +898,10 @@ def generate_aclnn_reg_file(work_path, yaml_str): check_change_and_replace_file(register_file, tmp_register_file) -eum_py_header = f""" -\"\"\"Operator argument enum definition.\"\"\" - -from enum import Enum -""" - -eum_cc_header = f""" -#ifndef MINDSPORE_CORE_OPS_GEN_ENUM_DEF_ -#define MINDSPORE_CORE_OPS_GEN_ENUM_DEF_ - -#include - -namespace mindspore::MsPyEnum {{ -""" - -eum_cc_end = f"""}} // namespace mindspore::MsPyEnum -#endif // MINDSPORE_CORE_OPS_GEN_ENUM_DEF_ 
-""" - - -def generate_enum_code(yaml_data): +def generate_arg_handler_files(work_path): """ - Generate python and c++ enum definition + Generate arg handler files. """ - gen_eum_py_func = '' - gen_eum_py_def = eum_py_header - gen_eum_cc_def = eum_cc_header - for enum_name, enum_data in yaml_data.items(): - class_name = ''.join(word.capitalize() for word in enum_name.split('_')) - gen_eum_py_func += f"""\n -def {enum_name}_to_enum({enum_name}_str): - \""" - convert {enum_name} string to enum. - \""" - if not isinstance({enum_name}_str, str): - raise TypeError(f"The {enum_name} should be string, but got {{{enum_name}_str}}") - {enum_name}_str = {enum_name}_str.upper()\n""" - gen_eum_py_def += f"""\n\nclass {class_name}(Enum):\n""" - gen_eum_cc_def += f"""enum {class_name} : int64_t {{\n""" - - for enum_key, enum_value in enum_data.items(): - gen_eum_py_func += f""" if {enum_name}_str == "{enum_key}": - return {enum_value}\n""" - gen_eum_py_def += f""" {enum_key} = {enum_value}\n""" - gen_eum_cc_def += f""" {enum_key} = {enum_value},\n""" - - gen_eum_py_func += f""" raise ValueError(f"Invalid {class_name}: {{{enum_name}_str}}")\n""" - gen_eum_cc_def += f"""}};\n\n""" - gen_eum_cc_def += eum_cc_end - - return gen_eum_py_func, gen_eum_py_def, gen_eum_cc_def - - -def generate_enum_files(work_path): - """ - Generate python function and c++ definition from enum yaml. - """ - enum_yaml_path = os.path.join(work_path, 'mindspore/python/mindspore/ops_generate/enum.yaml') - yaml_str = safe_load_yaml(enum_yaml_path) - py_enum_func, py_enum_def, cc_enum_def = generate_enum_code(yaml_str) - dst_dir = os.path.join(work_path, 'mindspore/python/mindspore/ops/auto_generate') src_arg_handler_path = os.path.join(work_path, 'mindspore/python/mindspore/ops_generate/arg_handler.py') dst_arg_handler_path = os.path.join(dst_dir, 'gen_arg_handler.py') @@ -959,8 +909,6 @@ def generate_enum_files(work_path): if not os.path.exists(dst_dir): os.makedirs(dst_dir) shutil.copy(src_arg_handler_path, tmp_dst_arg_handler_path) - with open(tmp_dst_arg_handler_path, 'a') as py_file: - py_file.write(py_enum_func) check_change_and_replace_file(dst_arg_handler_path, tmp_dst_arg_handler_path) src_arg_dtype_cast_path = os.path.join(work_path, 'mindspore/python/mindspore/ops_generate/arg_dtype_cast.py') @@ -969,18 +917,6 @@ def generate_enum_files(work_path): shutil.copy(src_arg_dtype_cast_path, tmp_arg_dtype_cast_path) check_change_and_replace_file(dst_arg_dtype_cast_path, tmp_arg_dtype_cast_path) - enum_def_py_path = os.path.join(work_path, 'mindspore/python/mindspore/ops/auto_generate/gen_enum_def.py') - tmp_enum_def_py_path = os.path.join(work_path, 'mindspore/python/mindspore/ops/auto_generate/tmp_gen_enum_def.py') - with open(tmp_enum_def_py_path, 'w') as cc_file: - cc_file.write(py_licence_str + py_enum_def) - check_change_and_replace_file(enum_def_py_path, tmp_enum_def_py_path) - - enum_def_cc_path = os.path.join(work_path, 'mindspore/core/ops/auto_generate/gen_enum_def.h') - tmp_enum_def_cc_path = os.path.join(work_path, 'mindspore/core/ops/auto_generate/tmp_gen_enum_def.h') - with open(tmp_enum_def_cc_path, 'w') as cc_file: - cc_file.write(cc_license_str + cc_enum_def) - check_change_and_replace_file(enum_def_cc_path, tmp_enum_def_cc_path) - def main(): current_path = os.path.dirname(os.path.abspath(__file__)) @@ -1005,8 +941,8 @@ def main(): cc_path = os.path.join(work_path, 'mindspore/core/ops/auto_generate/') pathlib.Path(cc_path).mkdir(parents=True, exist_ok=True) - # generate enum code from enum.yaml - 
generate_enum_files(work_path) + # generate arg_handler files + generate_arg_handler_files(work_path) # generate ops python files generate_ops_py_files(work_path, safe_load_yaml(ops_yaml_path), safe_load_yaml(doc_yaml_path), "gen") diff --git a/mindspore/python/mindspore/ops_generate/gen_ops_inner_prim.py b/mindspore/python/mindspore/ops_generate/gen_ops_inner_prim.py index a9bd646633c..e8c1f2cf922 100644 --- a/mindspore/python/mindspore/ops_generate/gen_ops_inner_prim.py +++ b/mindspore/python/mindspore/ops_generate/gen_ops_inner_prim.py @@ -16,6 +16,7 @@ from mindspore.ops.primitive import Primitive, prim_attr_register from mindspore._c_expression import typing +from mindspore._c_expression import op_enum class DtypeToEnum(Primitive): @@ -41,3 +42,28 @@ class DtypeToEnum(Primitive): if not isinstance(dtype, typing.Type): raise TypeError(f"For dtype_to_enum function, the input should be mindpsore dtype, but got {dtype}.") return typing.type_to_type_id(dtype) + + +class StringToEnum(Primitive): + r""" + Convert string to enum. + + Inputs: + - **enum_str** (str) - The str data. + + Outputs: + An integer. + + Supported Platforms: + ``CPU`` + """ + + @prim_attr_register + def __init__(self): + """Initialize""" + + def __call__(self, enum_str): + """Run in PyNative mode""" + if not isinstance(enum_str, str): + raise TypeError(f"For StringToEnum op, the input should be a str, but got {type(enum_str)}.") + return op_enum.str_to_enum(enum_str) diff --git a/mindspore/python/mindspore/ops_generate/gen_utils.py b/mindspore/python/mindspore/ops_generate/gen_utils.py index 6e5faf26777..1351857ce17 100644 --- a/mindspore/python/mindspore/ops_generate/gen_utils.py +++ b/mindspore/python/mindspore/ops_generate/gen_utils.py @@ -53,30 +53,36 @@ cc_license_str = f"""/** * limitations under the License. */""" +def convert_dtype_str(dtype_str): + """ + Convert dtype str to expression in ops file + """ + return 'DT_' + dtype_str.replace('[', '_').replace(']', '').upper() + def get_type_str(type_str): """ Get the unified type str for operator arg dtype. """ # add more type here - type_kind_dict = { - 'int': 'OpDtype.PY_DT_INT', - 'float': 'OpDtype.PY_DT_FLOAT', - 'bool': 'OpDtype.PY_DT_BOOL', - 'number': 'OpDtype.PY_DT_NUMBER', - 'tuple[int]': 'OpDtype.PY_DT_TUPLE_ANY', - 'tuple[float]': 'OpDtype.PY_DT_TUPLE_ANY', - 'tuple[bool]': 'OpDtype.PY_DT_TUPLE_ANY', - 'tuple[tensor]': 'OpDtype.PY_DT_TUPLE_ANY', - 'list[int]': 'OpDtype.PY_DT_LIST_ANY', - 'list[float]': 'OpDtype.PY_DT_LIST_ANY', - 'list[bool]': 'OpDtype.PY_DT_LIST_ANY', - 'list[tensor]': 'OpDtype.PY_DT_LIST_ANY', - 'tensor': 'OpDtype.PY_DT_TENSOR', - 'type': 'OpDtype.PY_DT_TYPE', + type_kind_set = { + 'int', + 'float', + 'bool', + 'number', + 'tuple[int]', + 'tuple[float]', + 'tuple[bool]', + 'tuple[tensor]', + 'list[int]', + 'list[float]', + 'list[bool]', + 'list[tensor]', + 'tensor', + 'type', } - if type_str in type_kind_dict: - return type_kind_dict[type_str] + if type_str in type_kind_set: + return "OpDtype." 
+ convert_dtype_str(type_str) raise TypeError(f"""Unsupported type {type_str} for args.""") @@ -158,10 +164,10 @@ def get_assign_str_by_type_it(arg_info, arg_name, dtype): type_cast_tuple = tuple(ct.strip() for ct in type_cast.split(",")) assign_str += f'type_it({arg_name}, ' if len(type_cast_tuple) == 1: - assign_str += get_type_str(type_cast_tuple[0]) + '.value, ' + assign_str += get_type_str(type_cast_tuple[0]) + ', ' else: - assign_str += '(' + ', '.join(get_type_str(ct) + '.value' for ct in type_cast_tuple) + '), ' - assign_str += get_type_str(dtype) + '.value)' + assign_str += '(' + ', '.join(get_type_str(ct) for ct in type_cast_tuple) + '), ' + assign_str += get_type_str(dtype) + ')' else: assign_str = arg_name return assign_str diff --git a/tests/ut/cpp/ops/test_ops_fft_with_size.cc b/tests/ut/cpp/ops/test_ops_fft_with_size.cc index 7993de5277b..677eec3f363 100644 --- a/tests/ut/cpp/ops/test_ops_fft_with_size.cc +++ b/tests/ut/cpp/ops/test_ops_fft_with_size.cc @@ -16,7 +16,7 @@ #include "ops/test_value_utils.h" #include "ops/auto_generate/gen_ops_name.h" -#include "ops/auto_generate/gen_enum_def.h" +#include "ops/op_enum.h" #include "ops/test_ops.h" #include "ops/ops_func_impl/fft_with_size.h" @@ -42,7 +42,7 @@ TEST_P(TestFFT, dyn_shape) { auto signal_ndim = param.signal_ndim == -1 ? Any->ToAbstract() : CreatePyInt(param.signal_ndim)->ToAbstract(); auto inverse = CreateScalar(param.inverse)->ToAbstract(); auto real = CreateScalar(param.real)->ToAbstract(); - auto norm = CreateScalar(static_cast<int64_t>(MsPyEnum::NormMode::BACKWARD))->ToAbstract(); + auto norm = CreateScalar(static_cast<int64_t>(ops::NormMode::BACKWARD))->ToAbstract(); auto onesided = CreateScalar(param.onesided)->ToAbstract(); auto signal_sizes = param.signal_sizes->ToAbstract(); diff --git a/tests/ut/cpp/ops/test_ops_nllloss.cc b/tests/ut/cpp/ops/test_ops_nllloss.cc index 11b6c8f9c50..00fc12aa7b7 100644 --- a/tests/ut/cpp/ops/test_ops_nllloss.cc +++ b/tests/ut/cpp/ops/test_ops_nllloss.cc @@ -29,6 +29,7 @@ #include "ops/test_value_utils.h" #include "utils/ms_context.h" #include "utils/tensor_construct_utils.h" +#include "mindapi/base/types.h" namespace mindspore { namespace ops { @@ -99,31 +100,238 @@ TEST_P(TestNLLLoss, dyn_shape) { INSTANTIATE_TEST_CASE_P( TestNLLLossGroup, TestNLLLoss, testing::Values( - NLLLossParams{{-1, 3}, kFloat32, {-1}, kInt32, {3}, kFloat32, 0, true, {{-1}, {}}, {kFloat32, kFloat32}}, - NLLLossParams{{-1, 3}, kFloat32, {-1}, kInt32, {3}, kFloat32, 1, true, {{}, {}}, {kFloat32, kFloat32}}, - NLLLossParams{{-1, 3}, kFloat32, {-1}, kInt32, {3}, kFloat32, 2, true, {{}, {}}, {kFloat32, kFloat32}}, - NLLLossParams{{-1, 3}, kFloat32, {-2}, kInt32, {-2}, kFloat32, 0, true, {{-1}, {}}, {kFloat32, kFloat32}}, - NLLLossParams{{-1, 3}, kFloat32, {-2}, kInt32, {-2}, kFloat32, 1, true, {{}, {}}, {kFloat32, kFloat32}}, - NLLLossParams{{-1, 3}, kFloat32, {-2}, kInt32, {-2}, kFloat32, 2, true, {{}, {}}, {kFloat32, kFloat32}}, - NLLLossParams{{-2}, kFloat32, {-2}, kInt32, {-2}, kFloat32, 0, true, {{-1}, {}}, {kFloat32, kFloat32}}, - NLLLossParams{{-2}, kFloat32, {-2}, kInt32, {-2}, kFloat32, 1, true, {{}, {}}, {kFloat32, kFloat32}}, - NLLLossParams{{-2}, kFloat32, {-2}, kInt32, {-2}, kFloat32, 2, true, {{}, {}}, {kFloat32, kFloat32}}, - NLLLossParams{{2, 3}, kFloat32, {2}, kInt32, {3}, kFloat32, 0, true, {{2}, {}}, {kFloat32, kFloat32}}, - NLLLossParams{{2, 3}, kFloat32, {2}, kInt32, {3}, kFloat32, 1, true, {{}, {}}, {kFloat32, kFloat32}}, - NLLLossParams{{2, 3}, kFloat32, {2}, kInt32, {3}, kFloat32, 2, true, {{}, {}},
{kFloat32, kFloat32}}, - NLLLossParams{{2, 3}, kFloat32, {2}, kInt32, {3}, kFloat32, -1, true, {{-1}, {}}, {kFloat32, kFloat32}}, - NLLLossParams{{-1, -1}, kFloat32, {-1}, kInt32, {-1}, kFloat32, 0, true, {{-1}, {}}, {kFloat32, kFloat32}}, - NLLLossParams{{-1, -1}, kFloat32, {-1}, kInt32, {-1}, kFloat32, 1, true, {{}, {}}, {kFloat32, kFloat32}}, - NLLLossParams{{-1, -1}, kFloat32, {-1}, kInt32, {-1}, kFloat32, 2, true, {{}, {}}, {kFloat32, kFloat32}}, + NLLLossParams{{-1, 3}, + kFloat32, + {-1}, + kInt32, + {3}, + kFloat32, + static_cast<int64_t>(Reduction::NONE), + true, + {{-1}, {}}, + {kFloat32, kFloat32}}, + NLLLossParams{{-1, 3}, + kFloat32, + {-1}, + kInt32, + {3}, + kFloat32, + static_cast<int64_t>(Reduction::MEAN), + true, + {{}, {}}, + {kFloat32, kFloat32}}, + NLLLossParams{{-1, 3}, + kFloat32, + {-1}, + kInt32, + {3}, + kFloat32, + static_cast<int64_t>(Reduction::REDUCTION_SUM), + true, + {{}, {}}, + {kFloat32, kFloat32}}, + NLLLossParams{{-1, 3}, + kFloat32, + {-2}, + kInt32, + {-2}, + kFloat32, + static_cast<int64_t>(Reduction::NONE), + true, + {{-1}, {}}, + {kFloat32, kFloat32}}, + NLLLossParams{{-1, 3}, + kFloat32, + {-2}, + kInt32, + {-2}, + kFloat32, + static_cast<int64_t>(Reduction::MEAN), + true, + {{}, {}}, + {kFloat32, kFloat32}}, + NLLLossParams{{-1, 3}, + kFloat32, + {-2}, + kInt32, + {-2}, + kFloat32, + static_cast<int64_t>(Reduction::REDUCTION_SUM), + true, + {{}, {}}, + {kFloat32, kFloat32}}, + NLLLossParams{{-2}, + kFloat32, + {-2}, + kInt32, + {-2}, + kFloat32, + static_cast<int64_t>(Reduction::NONE), + true, + {{-1}, {}}, + {kFloat32, kFloat32}}, + NLLLossParams{{-2}, + kFloat32, + {-2}, + kInt32, + {-2}, + kFloat32, + static_cast<int64_t>(Reduction::MEAN), + true, + {{}, {}}, + {kFloat32, kFloat32}}, + NLLLossParams{{-2}, + kFloat32, + {-2}, + kInt32, + {-2}, + kFloat32, + static_cast<int64_t>(Reduction::REDUCTION_SUM), + true, + {{}, {}}, + {kFloat32, kFloat32}}, + NLLLossParams{{2, 3}, + kFloat32, + {2}, + kInt32, + {3}, + kFloat32, + static_cast<int64_t>(Reduction::NONE), + true, + {{2}, {}}, + {kFloat32, kFloat32}}, + NLLLossParams{{2, 3}, + kFloat32, + {2}, + kInt32, + {3}, + kFloat32, + static_cast<int64_t>(Reduction::MEAN), + true, + {{}, {}}, + {kFloat32, kFloat32}}, + NLLLossParams{{2, 3}, + kFloat32, + {2}, + kInt32, + {3}, + kFloat32, + static_cast<int64_t>(Reduction::REDUCTION_SUM), + true, + {{}, {}}, + {kFloat32, kFloat32}}, + NLLLossParams{{2, 3}, kFloat32, {2}, kInt32, {3}, kFloat32, -1, true, {{-1}, {}}, {kFloat32, kFloat32}}, + NLLLossParams{{-1, -1}, + kFloat32, + {-1}, + kInt32, + {-1}, + kFloat32, + static_cast<int64_t>(Reduction::NONE), + true, + {{-1}, {}}, + {kFloat32, kFloat32}}, + NLLLossParams{{-1, -1}, + kFloat32, + {-1}, + kInt32, + {-1}, + kFloat32, + static_cast<int64_t>(Reduction::MEAN), + true, + {{}, {}}, + {kFloat32, kFloat32}}, + NLLLossParams{{-1, -1}, + kFloat32, + {-1}, + kInt32, + {-1}, + kFloat32, + static_cast<int64_t>(Reduction::REDUCTION_SUM), + true, + {{}, {}}, + {kFloat32, kFloat32}}, NLLLossParams{{-1, -1}, kFloat32, {-1}, kInt32, {-1}, kFloat32, -1, true, {{-1}, {}}, {kFloat32, kFloat32}}, - NLLLossParams{{-2}, kFloat32, {-1}, kInt32, {-1}, kFloat32, 0, true, {{-1}, {}}, {kFloat32, kFloat32}}, - NLLLossParams{{-2}, kFloat32, {-1}, kInt32, {-1}, kFloat32, 1, true, {{}, {}}, {kFloat32, kFloat32}}, - NLLLossParams{{-2}, kFloat32, {-1}, kInt32, {-1}, kFloat32, 2, true, {{}, {}}, {kFloat32, kFloat32}}, + NLLLossParams{{-2}, + kFloat32, + {-1}, + kInt32, + {-1}, + kFloat32, + static_cast<int64_t>(Reduction::NONE), + true, + {{-1}, {}}, + {kFloat32, kFloat32}}, + NLLLossParams{{-2}, + kFloat32, + {-1}, + kInt32, + {-1}, + kFloat32, + static_cast<int64_t>(Reduction::MEAN), + true, + {{}, {}}, + {kFloat32, kFloat32}}, + NLLLossParams{{-2}, + kFloat32, + {-1}, + kInt32, + {-1}, + kFloat32, + static_cast<int64_t>(Reduction::REDUCTION_SUM), + true, + {{}, {}}, + {kFloat32, kFloat32}}, NLLLossParams{{-2}, kFloat32, {-1}, kInt32, {-1}, kFloat32, -1, true, {{-1}, {}}, {kFloat32, kFloat32}}, - NLLLossParams{{2, 3, 4}, kFloat32, {2}, kInt32, {3}, kFloat32, 0, false, {{1}, {}}, {kFloat32, kFloat32}}, - NLLLossParams{{2, 3}, kFloat32, {2, 3}, kInt32, {3}, kFloat32, 0, false, {{1}, {}}, {kFloat32, kFloat32}}, - NLLLossParams{{2, 3}, kFloat32, {2}, kInt32, {2, 3}, kFloat32, 0, false, {{1}, {}}, {kFloat32, kFloat32}}, - NLLLossParams{{2, 3}, kFloat32, {2}, kInt32, {2}, kFloat32, 0, false, {{1}, {}}, {kFloat32, kFloat32}}, - NLLLossParams{{2, 3}, kFloat32, {3}, kInt32, {2}, kFloat32, 0, false, {{1}, {}}, {kFloat32, kFloat32}})); + NLLLossParams{{2, 3, 4}, + kFloat32, + {2}, + kInt32, + {3}, + kFloat32, + static_cast<int64_t>(Reduction::NONE), + false, + {{1}, {}}, + {kFloat32, kFloat32}}, + NLLLossParams{{2, 3}, + kFloat32, + {2, 3}, + kInt32, + {3}, + kFloat32, + static_cast<int64_t>(Reduction::NONE), + false, + {{1}, {}}, + {kFloat32, kFloat32}}, + NLLLossParams{{2, 3}, + kFloat32, + {2}, + kInt32, + {2, 3}, + kFloat32, + static_cast<int64_t>(Reduction::NONE), + false, + {{1}, {}}, + {kFloat32, kFloat32}}, + NLLLossParams{{2, 3}, + kFloat32, + {2}, + kInt32, + {2}, + kFloat32, + static_cast<int64_t>(Reduction::NONE), + false, + {{1}, {}}, + {kFloat32, kFloat32}}, + NLLLossParams{{2, 3}, + kFloat32, + {3}, + kInt32, + {2}, + kFloat32, + static_cast<int64_t>(Reduction::NONE), + false, + {{1}, {}}, + {kFloat32, kFloat32}})); } // namespace ops } // namespace mindspore
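Taken together, the user-visible effect of the refactor is that every string-valued operator argument now funnels through the single generic `str_to_enum` handler (backed by the new `StringToEnum` primitive and the `op_enum.str_to_enum` pybind export), and every dtype-valued argument through `dtype_to_type_id`. A short editorial sketch of the expected call pattern; the import path is an assumption based on the copy step in `generate_arg_handler_files` above, and the concrete return values are read off `op_enum.h` and the deleted `enum.yaml` table:

```python
# Sketch: expected argument normalization after this patch.
# Import path assumed from the arg_handler.py -> gen_arg_handler.py copy step.
import mindspore as ms
from mindspore.ops.auto_generate.gen_arg_handler import str_to_enum, dtype_to_type_id

mode = str_to_enum("bilinear")    # expected 0, InterpolationMode::BILINEAR in op_enum.h
pad = str_to_enum("reflection")   # expected 2, GridSamplerPaddingMode::REFLECTION
tid = dtype_to_type_id(ms.int32)  # TypeId integer (kNumberTypeInt32, 34 in the old table)
```

In graph mode the same conversion is constant-folded by `StringToEnumInfer::InferValue`, so a kernel normally sees only the int64 enum value rather than the string (though, per the arg_handler comment, the backend can also accept strings directly).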