refractor some python enums in yaml to C++

This commit is contained in:
hanhuifeng2020 2023-12-07 14:58:39 +08:00
parent d8301d1dba
commit 60e4c1183f
78 changed files with 913 additions and 573 deletions

2
.gitignore vendored
View File

@ -119,7 +119,6 @@ mindspore/core/ops/auto_generate/gen_lite_ops.h
mindspore/core/ops/auto_generate/gen_lite_ops.cc
mindspore/core/ops/auto_generate/gen_ops_name.h
mindspore/core/ops/auto_generate/gen_ops_primitive.h
mindspore/core/ops/auto_generate/gen_enum_def.h
mindspore/core/ops/auto_generate/gen_ops_def.cc
mindspore/core/ops/auto_generate/gen_ops_def.h
mindspore/python/mindspore/ops_generate/ops.yaml
@ -129,7 +128,6 @@ mindspore/python/mindspore/ops_generate/inner_ops_doc.yaml
mindspore/python/mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py
mindspore/python/mindspore/ops/auto_generate/gen_arg_handler.py
mindspore/python/mindspore/ops/auto_generate/gen_arg_dtype_cast.py
mindspore/python/mindspore/ops/auto_generate/gen_enum_def.py
mindspore/python/mindspore/ops/auto_generate/gen_inner_ops_def.py
mindspore/python/mindspore/ops/auto_generate/gen_ops_def.py
mindspore/python/mindspore/ops/auto_generate/gen_pyboost_func.py

View File

@ -107,8 +107,8 @@ using ArgHandlerFunc = std::function<ValuePtr(const ValuePtr &)>;
ArgHandlerFunc GetArgHandlerFunc(const std::string &arg_handler) {
static const std::unordered_map<std::string, ArgHandlerFunc> arg_handler_funcs = {
{"py_format_to_enum", EnumToFormat},
{"dtype_to_enum", EnumToDtype},
{"str_to_enum", EnumToFormat},
{"dtype_to_type_id", EnumToDtype},
};
if (arg_handler_funcs.find(arg_handler) != arg_handler_funcs.end()) {
return arg_handler_funcs.at(arg_handler);
@ -119,8 +119,8 @@ ArgHandlerFunc GetArgHandlerFunc(const std::string &arg_handler) {
ArgHandlerFunc GetOppArgHandlerFunc(const std::string &arg_handler) {
static const std::unordered_map<std::string, ArgHandlerFunc> opp_arg_handler_funcs = {
{"py_format_to_enum", FormatToEnum},
{"dtype_to_enum", DtypeToEnum},
{"str_to_enum", FormatToEnum},
{"dtype_to_type_id", DtypeToEnum},
};
if (opp_arg_handler_funcs.find(arg_handler) != opp_arg_handler_funcs.end()) {
return opp_arg_handler_funcs.at(arg_handler);

View File

@ -70,6 +70,10 @@ namespace abstract {
void RegPrimitiveFrontEval();
}
#endif
namespace ops {
void RegOpEnum(py::module *m);
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CCSRC_INCLUDE_COMMON_PYBIND_API_API_REGISTER_H_

View File

@ -156,6 +156,7 @@ void RegModule(py::module *m) {
mindspore::abstract::RegPrimitiveFrontEval();
#endif
mindspore::expander::RegPackExpanderPy(m);
mindspore::ops::RegOpEnum(m);
}
void RegModuleHelper(py::module *m) {

View File

@ -137,6 +137,11 @@ AnfNodePtr GetNodeAfterArgHandler(const AnfNodePtr &node, const ops::OpInputArg
return node;
}
const auto arg_handler_func = prim::GetPythonOps(op_arg.arg_handler_, parse::PYTHON_MOD_PRIMITIVE_ARG_HANDLER_MODULE);
if (arg_handler_func->isa<Primitive>()) {
auto arg_handler_fg = dyn_cast<Primitive>(arg_handler_func);
MS_EXCEPTION_IF_NULL(arg_handler_fg);
return fg->NewCNodeInOrder({NewValueNode(arg_handler_fg), node});
}
auto arg_handler_fg = dyn_cast<FuncGraph>(arg_handler_func);
MS_EXCEPTION_IF_NULL(arg_handler_fg);
arg_handler_fg->set_manager(fg->manager());

View File

@ -16,7 +16,7 @@
#include "plugin/device/cpu/kernel/grid_sampler_2d_cpu_kernel.h"
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
#include "mindspore/core/ops/auto_generate/gen_enum_def.h"
#include "mindspore/core/ops/op_enum.h"
namespace {
const size_t kDataSizeThreshold = 64 * 1024;
@ -119,7 +119,7 @@ void GridSampler2DCpuKernelMod::ComputeTask(const T *x_addr, const T *grid_addr,
auto x_ptr_NC = out_iter[kZero] * x_stride_[kZero];
auto output_ptr_NCDHW = out_iter[kZero] * output_stride_[kZero] + out_iter[kTwo] * output_stride_[kTwo] +
out_iter[kThree] * output_stride_[kThree];
if (interpolation_mode_ == static_cast<int64_t>(MsPyEnum::InterpolationMode::BILINEAR)) {
if (interpolation_mode_ == static_cast<int64_t>(ops::InterpolationMode::BILINEAR)) {
int64_t x_tnw = static_cast<int64_t>(std::floor(x));
int64_t y_tnw = static_cast<int64_t>(std::floor(y));
int64_t x_tne = x_tnw + 1;
@ -153,7 +153,7 @@ void GridSampler2DCpuKernelMod::ComputeTask(const T *x_addr, const T *grid_addr,
output_addr[output_ptr_NCDHW] += x_addr[x_index] * tse;
}
}
} else if (interpolation_mode_ == static_cast<int64_t>(MsPyEnum::InterpolationMode::NEAREST)) {
} else if (interpolation_mode_ == static_cast<int64_t>(ops::InterpolationMode::NEAREST)) {
int64_t x_nearest = static_cast<int64_t>(std::round(x));
int64_t y_nearest = static_cast<int64_t>(std::round(y));
for (size_t c = 0; c < out_c; c++, x_ptr_NC += x_stride_[kOne], output_ptr_NCDHW += output_stride_[kOne]) {
@ -221,9 +221,9 @@ void GridSampler2DCpuKernelMod::Call2Half(const float16 *x_data_addr, float16 *y
auto x_ptr_NC = y_iter[0] * x_stride[0];
auto y_ptr_NCHW = y_iter[0] * y_stride[0] + y_iter[2] * y_stride[2] + y_iter[3] * y_stride[3];
if (interpolation_mode == static_cast<int64_t>(MsPyEnum::InterpolationMode::BILINEAR)) {
if (interpolation_mode == static_cast<int64_t>(ops::InterpolationMode::BILINEAR)) {
BilinearHalf(x, y, x_data_addr, y_data_addr, y_c, x_dims, y_stride, x_stride, x_ptr_NC, y_ptr_NCHW);
} else if (interpolation_mode == static_cast<int64_t>(MsPyEnum::InterpolationMode::NEAREST)) {
} else if (interpolation_mode == static_cast<int64_t>(ops::InterpolationMode::NEAREST)) {
NearestHalf(x, y, x_data_addr, y_data_addr, y_c, x_dims, y_stride, x_stride, x_ptr_NC, y_ptr_NCHW);
}
}
@ -333,9 +333,9 @@ T GridSampler2DCpuKernelMod::GridSamplerComputeSourceIndex(T coord, int64_t size
} else {
coord = ((coord + 1.f) * size - 1) / num2;
}
if (padding_mode == static_cast<int64_t>(MsPyEnum::GridSamplerPaddingMode::BORDER)) {
if (padding_mode == static_cast<int64_t>(ops::GridSamplerPaddingMode::BORDER)) {
coord = std::min(static_cast<T>(size - 1), std::max(coord, static_cast<T>(0)));
} else if (padding_mode == static_cast<int64_t>(MsPyEnum::GridSamplerPaddingMode::REFLECTION)) {
} else if (padding_mode == static_cast<int64_t>(ops::GridSamplerPaddingMode::REFLECTION)) {
if (align_corners) {
coord = ReflectCoordinates(coord, 0, num2 * (size - 1));
} else {

View File

@ -15,7 +15,7 @@
*/
#include "plugin/device/cpu/kernel/grid_sampler_2d_grad_cpu_kernel.h"
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
#include "mindspore/core/ops/auto_generate/gen_enum_def.h"
#include "mindspore/core/ops/op_enum.h"
namespace {
const size_t kDataSizeThreshold = 64 * 1024;
@ -93,14 +93,14 @@ void GridSampler2DGradCpuKernelMod::ComputeTask(const std::vector<KernelTensor *
auto grid_data_addr = static_cast<T *>(inputs[kTwo]->device_ptr());
auto dx_data_addr = static_cast<T *>(outputs[kZero]->device_ptr());
auto dgrid_data_addr = static_cast<T *>(outputs[kOne]->device_ptr());
if (interpolation_mode_ == static_cast<int64_t>(MsPyEnum::InterpolationMode::BILINEAR)) {
if (interpolation_mode_ == static_cast<int64_t>(ops::InterpolationMode::BILINEAR)) {
interp = GridSamplerInterpolation::Bilinear;
} else if (interpolation_mode_ == static_cast<int64_t>(MsPyEnum::InterpolationMode::NEAREST)) {
} else if (interpolation_mode_ == static_cast<int64_t>(ops::InterpolationMode::NEAREST)) {
interp = GridSamplerInterpolation::Nearest;
}
if (padding_mode_ == static_cast<int64_t>(MsPyEnum::GridSamplerPaddingMode::ZEROS)) {
if (padding_mode_ == static_cast<int64_t>(ops::GridSamplerPaddingMode::ZEROS)) {
padding = GridSamplerPadding::Zeros;
} else if (padding_mode_ == static_cast<int64_t>(MsPyEnum::GridSamplerPaddingMode::BORDER)) {
} else if (padding_mode_ == static_cast<int64_t>(ops::GridSamplerPaddingMode::BORDER)) {
padding = GridSamplerPadding::Border;
} else {
padding = GridSamplerPadding::Reflection;

View File

@ -15,7 +15,7 @@
*/
#include "plugin/device/cpu/kernel/grid_sampler_3d_cpu_kernel.h"
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
#include "mindspore/core/ops/auto_generate/gen_enum_def.h"
#include "mindspore/core/ops/op_enum.h"
namespace {
const size_t kDataSizeThreshold = 64 * 1024;
@ -117,7 +117,7 @@ void GridSampler3DCpuKernelMod::ComputeTask(T *x_addr, T *grid_addr, T *output_a
auto x_ptr_NC = out_iter[kZero] * x_stride_[kZero];
auto output_ptr_NCDHW = out_iter[kZero] * output_stride_[kZero] + out_iter[kTwo] * output_stride_[kTwo] +
out_iter[kThree] * output_stride_[kThree] + out_iter[kFour] * output_stride_[kFour];
if (interpolation_mode_ == static_cast<int64_t>(MsPyEnum::InterpolationMode::BILINEAR)) {
if (interpolation_mode_ == static_cast<int64_t>(ops::InterpolationMode::BILINEAR)) {
int64_t x_tnw = static_cast<int64_t>(std::floor(x));
int64_t y_tnw = static_cast<int64_t>(std::floor(y));
int64_t z_tnw = static_cast<int64_t>(std::floor(z));
@ -176,7 +176,7 @@ void GridSampler3DCpuKernelMod::ComputeTask(T *x_addr, T *grid_addr, T *output_a
output_addr[output_ptr_NCDHW] += x_addr[x_index] * bse;
}
}
} else if (interpolation_mode_ == static_cast<int64_t>(MsPyEnum::InterpolationMode::NEAREST)) {
} else if (interpolation_mode_ == static_cast<int64_t>(ops::InterpolationMode::NEAREST)) {
int64_t x_nearest = static_cast<int64_t>(std::round(x));
int64_t y_nearest = static_cast<int64_t>(std::round(y));
int64_t z_nearest = static_cast<int64_t>(std::round(z));
@ -220,9 +220,9 @@ T GridSampler3DCpuKernelMod::grid_sampler_compute_source_index(T coord, int64_t
} else {
coord = ((coord + 1.f) * size - kOne) / kTwo;
}
if (padding_mode == static_cast<int64_t>(MsPyEnum::GridSamplerPaddingMode::BORDER)) {
if (padding_mode == static_cast<int64_t>(ops::GridSamplerPaddingMode::BORDER)) {
coord = std::min(static_cast<T>(static_cast<size_t>(size) - kOne), std::max(coord, static_cast<T>(kZero)));
} else if (padding_mode == static_cast<int64_t>(MsPyEnum::GridSamplerPaddingMode::REFLECTION)) {
} else if (padding_mode == static_cast<int64_t>(ops::GridSamplerPaddingMode::REFLECTION)) {
if (align_corners) {
coord = reflect_coordinates(coord, static_cast<int64_t>(kZero), kTwo * (size - kOne));
} else {

View File

@ -15,7 +15,7 @@
*/
#include "plugin/device/cpu/kernel/grid_sampler_3d_grad_cpu_kernel.h"
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
#include "mindspore/core/ops/auto_generate/gen_enum_def.h"
#include "mindspore/core/ops/op_enum.h"
namespace {
const size_t kDataSizeThreshold = 64 * 1024;
@ -214,7 +214,7 @@ void GridSampler3DGradCpuKernelMod::ComputeTask(T *grad_addr, T *x_addr, T *grid
x = grid_sampler_compute_source_index_set_grad(x, x_shape_[kFour], padding_mode, align_corners_, &gx_mult);
y = grid_sampler_compute_source_index_set_grad(y, x_shape_[kThree], padding_mode, align_corners_, &gy_mult);
z = grid_sampler_compute_source_index_set_grad(z, x_shape_[kTwo], padding_mode, align_corners_, &gz_mult);
if (interpolation_mode == static_cast<int64_t>(MsPyEnum::InterpolationMode::BILINEAR)) {
if (interpolation_mode == static_cast<int64_t>(ops::InterpolationMode::BILINEAR)) {
size_t grad_ptr_NCDHW =
n * grad_stride_[kZero] + d * grad_stride_[kTwo] + h * grad_stride_[kThree] + w * grad_stride_[kFour];
size_t dx_ptr_NC = n * dx_stride_[kZero], x_ptr_NC = n * x_stride_[kZero];
@ -223,7 +223,7 @@ void GridSampler3DGradCpuKernelMod::ComputeTask(T *grad_addr, T *x_addr, T *grid
std::vector<T> mult = {gx_mult, gy_mult, gz_mult};
std::vector<size_t> ptr = {grad_ptr_NCDHW, x_ptr_NC, dx_ptr_NC, dgrid_ptr_NDHW};
BilinearKernel<T>(addr, location, mult, ptr);
} else if (interpolation_mode == static_cast<int64_t>(MsPyEnum::InterpolationMode::NEAREST)) {
} else if (interpolation_mode == static_cast<int64_t>(ops::InterpolationMode::NEAREST)) {
int64_t x_nearest = static_cast<int64_t>(std::round(x));
int64_t y_nearest = static_cast<int64_t>(std::round(y));
int64_t z_nearest = static_cast<int64_t>(std::round(z));
@ -281,10 +281,10 @@ T GridSampler3DGradCpuKernelMod::grid_sampler_compute_source_index_set_grad(T co
*grad_x = static_cast<T>(size) / kTwo;
coord = ((coord + kOne) * size - kOne) / kTwo;
}
if (padding_mode == static_cast<int64_t>(MsPyEnum::GridSamplerPaddingMode::BORDER)) {
if (padding_mode == static_cast<int64_t>(ops::GridSamplerPaddingMode::BORDER)) {
coord = clip_coordinates_set_grad(coord, size, &grad_clip);
*grad_x = (*grad_x) * grad_clip;
} else if (padding_mode == static_cast<int64_t>(MsPyEnum::GridSamplerPaddingMode::REFLECTION)) {
} else if (padding_mode == static_cast<int64_t>(ops::GridSamplerPaddingMode::REFLECTION)) {
if (align_corners) {
coord = reflect_coordinates_set_grad(coord, 0, (size - 1) * static_cast<int64_t>(kTwo), &grad_refl);
} else {

View File

@ -48,7 +48,7 @@ int NLLLossCpuKernelMod::Resize(const std::vector<KernelTensor *> &inputs, const
return ret;
}
auto reduction = inputs[kReductionIdx]->GetValueWithCheck<int64_t>();
reduction_type_ = static_cast<MsPyEnum::Reduction>(reduction);
reduction_type_ = static_cast<Reduction>(reduction);
ignore_index_ = inputs[kIgnoreIndexIdx]->GetValueWithCheck<int64_t>();
auto logits_shape = inputs[kIndex0]->GetShapeVector();
@ -92,15 +92,15 @@ bool NLLLossCpuKernelMod::LaunchKernel(const std::vector<kernel::KernelTensor *>
float n_loss = -logits[index] * n_weight;
tmp_total_weight += n_weight;
total_loss += n_loss;
if (reduction_type_ == MsPyEnum::Reduction::NONE) {
if (reduction_type_ == Reduction::NONE) {
loss[i] = n_loss;
}
}
*total_weight = tmp_total_weight;
if (reduction_type_ == MsPyEnum::Reduction::SUM) {
if (reduction_type_ == Reduction::REDUCTION_SUM) {
*loss = total_loss;
} else if (reduction_type_ == MsPyEnum::Reduction::MEAN) {
} else if (reduction_type_ == Reduction::MEAN) {
*loss = total_loss / tmp_total_weight;
}
return true;

View File

@ -23,7 +23,7 @@
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/factory/ms_factory.h"
#include "nnacl/kernel/nllloss.h"
#include "mindspore/core/ops/auto_generate/gen_enum_def.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace kernel {
@ -55,7 +55,7 @@ class NLLLossCpuKernelMod : public NativeCpuKernelMod {
NLLLossFunc kernel_func_;
static std::vector<std::pair<KernelAttr, NLLLossFunc>> func_list_;
NLLLossStruct nllloss_param_{};
MsPyEnum::Reduction reduction_type_;
Reduction reduction_type_;
int64_t ignore_index_;
};
} // namespace kernel

View File

@ -49,7 +49,7 @@ int NLLLossGradCpuKernelMod::Resize(const std::vector<KernelTensor *> &inputs,
return ret;
}
auto reduction = inputs[kReductionIdx]->GetValueWithCheck<int64_t>();
reduction_type_ = static_cast<MsPyEnum::Reduction>(reduction);
reduction_type_ = static_cast<Reduction>(reduction);
ignore_index_ = inputs[kIgnoreIndexIdx]->GetValueWithCheck<int64_t>();
auto logits_shape = inputs[0]->GetShapeVector();
nllloss_param_.batch_ = LongToInt(logits_shape[0]);
@ -89,9 +89,9 @@ bool NLLLossGradCpuKernelMod::LaunchKernel(const std::vector<kernel::KernelTenso
}
int index = i * nllloss_param_.class_num_ + labels[i];
float n_weight = weight[labels[i]];
if (reduction_type_ == MsPyEnum::Reduction::SUM) {
if (reduction_type_ == Reduction::REDUCTION_SUM) {
logits_grad[index] = -loss_grad[0] * n_weight;
} else if (reduction_type_ == MsPyEnum::Reduction::MEAN) {
} else if (reduction_type_ == Reduction::MEAN) {
logits_grad[index] = -loss_grad[0] * n_weight / *total_weight;
} else {
logits_grad[index] = -loss_grad[i] * n_weight;

View File

@ -23,7 +23,7 @@
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/factory/ms_factory.h"
#include "nnacl/kernel/nllloss.h"
#include "mindspore/core/ops/auto_generate/gen_enum_def.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace kernel {
@ -57,7 +57,7 @@ class NLLLossGradCpuKernelMod : public NativeCpuKernelMod {
private:
NLLLossStruct nllloss_param_{};
MsPyEnum::Reduction reduction_type_;
Reduction reduction_type_;
int64_t ignore_index_;
};
} // namespace kernel

View File

@ -114,10 +114,10 @@ const std::vector<std::pair<KernelAttr, ResizeLinear1DCpuKernelMod::KernelRunFun
template <typename T>
ResizeLinear1DCpuKernelMod::CoordinateTransformationFunc<T>
ResizeLinear1DCpuKernelMod::ChooseCoordinateTransformationFunc(
MsPyEnum::CoordinateTransformationMode coordinate_transformation_mode) const {
const std::unordered_map<MsPyEnum::CoordinateTransformationMode, CoordinateTransformationFunc<T>> coordinate_map{
{MsPyEnum::CoordinateTransformationMode::ALIGN_CORNERS, AlignCornersFunc<T>()},
{MsPyEnum::CoordinateTransformationMode::HALF_PIXEL, HalfPixelFunc<T>()}};
CoordinateTransformMode coordinate_transformation_mode) const {
const std::unordered_map<CoordinateTransformMode, CoordinateTransformationFunc<T>> coordinate_map{
{CoordinateTransformMode::ALIGN_CORNERS, AlignCornersFunc<T>()},
{CoordinateTransformMode::HALF_PIXEL, HalfPixelFunc<T>()}};
return coordinate_map.at(coordinate_transformation_mode);
}
@ -162,9 +162,9 @@ int ResizeLinear1DCpuKernelMod::Resize(const std::vector<KernelTensor *> &inputs
out_width_ = LongToSize(output_shape[kIndex2]);
coordinate_transformation_mode_ =
static_cast<MsPyEnum::CoordinateTransformationMode>(inputs.at(kIndex2)->GetValueWithCheck<int64_t>());
if (coordinate_transformation_mode_ != MsPyEnum::CoordinateTransformationMode::ALIGN_CORNERS &&
coordinate_transformation_mode_ != MsPyEnum::CoordinateTransformationMode::HALF_PIXEL) {
static_cast<CoordinateTransformMode>(inputs.at(kIndex2)->GetValueWithCheck<int64_t>());
if (coordinate_transformation_mode_ != CoordinateTransformMode::ALIGN_CORNERS &&
coordinate_transformation_mode_ != CoordinateTransformMode::HALF_PIXEL) {
MS_LOG_EXCEPTION << "For '" << kernel_name_ << "', coordinate_transformation_mode not support now.";
}
SetWorkSpaceSize(inputs);

View File

@ -24,7 +24,7 @@
#include <vector>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/factory/ms_factory.h"
#include "ops/auto_generate/gen_enum_def.h"
#include "mindapi/base/types.h"
namespace mindspore::kernel {
constexpr auto kUnknown = "Unknown";
@ -60,7 +60,7 @@ class ResizeLinear1DCpuKernelMod : public NativeCpuKernelMod, public MatchKernel
template <typename T>
CoordinateTransformationFunc<T> ChooseCoordinateTransformationFunc(
MsPyEnum::CoordinateTransformationMode coordinate_transformation_mode) const;
CoordinateTransformMode coordinate_transformation_mode) const;
template <typename T>
void ComputeInterpolationCaches(const size_t out_size, const size_t in_size,
@ -73,8 +73,7 @@ class ResizeLinear1DCpuKernelMod : public NativeCpuKernelMod, public MatchKernel
size_t channel_{0};
size_t in_width_{0};
size_t out_width_{0};
MsPyEnum::CoordinateTransformationMode coordinate_transformation_mode_ =
MsPyEnum::CoordinateTransformationMode::ALIGN_CORNERS;
CoordinateTransformMode coordinate_transformation_mode_ = CoordinateTransformMode::ALIGN_CORNERS;
};
} // namespace mindspore::kernel

View File

@ -116,10 +116,10 @@ const std::vector<std::pair<KernelAttr, ResizeLinear1DGradCpuKernelMod::KernelRu
template <typename T>
ResizeLinear1DGradCpuKernelMod::CoordinateTransformationFunc<T>
ResizeLinear1DGradCpuKernelMod::ChooseCoordinateTransformationFunc(
MsPyEnum::CoordinateTransformationMode coordinate_transformation_mode) {
const std::unordered_map<MsPyEnum::CoordinateTransformationMode, CoordinateTransformationFunc<T>> coordinate_map{
{MsPyEnum::CoordinateTransformationMode::ALIGN_CORNERS, AlignCornersFunc<T>()},
{MsPyEnum::CoordinateTransformationMode::HALF_PIXEL, HalfPixelFunc<T>()}};
CoordinateTransformMode coordinate_transformation_mode) {
const std::unordered_map<CoordinateTransformMode, CoordinateTransformationFunc<T>> coordinate_map{
{CoordinateTransformMode::ALIGN_CORNERS, AlignCornersFunc<T>()},
{CoordinateTransformMode::HALF_PIXEL, HalfPixelFunc<T>()}};
return coordinate_map.at(coordinate_transformation_mode);
}
@ -164,9 +164,9 @@ int ResizeLinear1DGradCpuKernelMod::Resize(const std::vector<KernelTensor *> &in
input_width_ = LongToSize(shape_[kIndex2]);
coordinate_transformation_mode_ =
static_cast<MsPyEnum::CoordinateTransformationMode>(inputs.at(kIndex2)->GetValueWithCheck<int64_t>());
if (coordinate_transformation_mode_ != MsPyEnum::CoordinateTransformationMode::ALIGN_CORNERS &&
coordinate_transformation_mode_ != MsPyEnum::CoordinateTransformationMode::HALF_PIXEL) {
static_cast<CoordinateTransformMode>(inputs.at(kIndex2)->GetValueWithCheck<int64_t>());
if (coordinate_transformation_mode_ != CoordinateTransformMode::ALIGN_CORNERS &&
coordinate_transformation_mode_ != CoordinateTransformMode::HALF_PIXEL) {
MS_LOG_EXCEPTION << "For '" << kernel_name_ << "', coordinate_transformation_mode not support now.";
}
SetWorkSpaceSize(inputs);

View File

@ -24,7 +24,7 @@
#include <utility>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/factory/ms_factory.h"
#include "ops/auto_generate/gen_enum_def.h"
#include "mindapi/base/types.h"
namespace mindspore::kernel {
constexpr auto kUnknown = "Unknown";
@ -66,7 +66,7 @@ class ResizeLinear1DGradCpuKernelMod : public NativeCpuKernelMod,
template <typename T>
CoordinateTransformationFunc<T> ChooseCoordinateTransformationFunc(
MsPyEnum::CoordinateTransformationMode coordinate_transformation_mode);
CoordinateTransformMode coordinate_transformation_mode);
std::string kernel_type_{kUnknown};
bool align_corners_{false};
@ -76,8 +76,7 @@ class ResizeLinear1DGradCpuKernelMod : public NativeCpuKernelMod,
size_t channel_{0};
size_t input_width_{0};
size_t output_width_{0};
MsPyEnum::CoordinateTransformationMode coordinate_transformation_mode_ =
MsPyEnum::CoordinateTransformationMode::ALIGN_CORNERS;
CoordinateTransformMode coordinate_transformation_mode_ = CoordinateTransformMode::ALIGN_CORNERS;
};
} // namespace mindspore::kernel

View File

@ -19,12 +19,18 @@
#include <map>
#include <string>
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h"
#include "mindapi/base/types.h"
enum class ReductionMode { kNone, kMean, kSum };
static std::map<std::string, ReductionMode> kReductionModeMap{
{"none", ReductionMode::kNone}, {"mean", ReductionMode::kMean}, {"sum", ReductionMode::kSum}};
static std::map<int64_t, ReductionMode> kEnumReductionModeMap{
{static_cast<int64_t>(mindspore::Reduction::NONE), ReductionMode::kNone},
{static_cast<int64_t>(mindspore::Reduction::MEAN), ReductionMode::kMean},
{static_cast<int64_t>(mindspore::Reduction::REDUCTION_SUM), ReductionMode::kSum}};
template <typename T>
CUDA_LIB_EXPORT cudaError_t BinaryCrossEntropyLoss(const int &input_size, const ReductionMode &reduction,
const T *input_x, const T *input_y, const T *weight, T *loss,

View File

@ -55,21 +55,21 @@ enum class FFTNormMode {
by_root_n, // same as above, but sqrt the product
};
FFTNormMode GetNormModeFromString(const MsPyEnum::NormMode &norm_type, const bool is_inverse) {
if (norm_type == MsPyEnum::NormMode::FORWARD) {
FFTNormMode GetNormModeFromString(const ops::NormMode &norm_type, const bool is_inverse) {
if (norm_type == ops::NormMode::FORWARD) {
return is_inverse ? FFTNormMode::none : FFTNormMode::by_n;
}
if (norm_type == MsPyEnum::NormMode::BACKWARD) {
if (norm_type == ops::NormMode::BACKWARD) {
return is_inverse ? FFTNormMode::by_n : FFTNormMode::none;
}
if (norm_type == MsPyEnum::NormMode::ORTHO) {
if (norm_type == ops::NormMode::ORTHO) {
return FFTNormMode::by_root_n;
}
MS_LOG(ERROR) << "For 'FFTWithSize', the fft norm type " << norm_type << " is unsupported!";
return FFTNormMode::none;
}
double GetNormScale(const MsPyEnum::NormMode &norm_type, const bool is_inverse, const int n) {
double GetNormScale(const ops::NormMode &norm_type, const bool is_inverse, const int n) {
FFTNormMode norm_mode = GetNormModeFromString(norm_type, is_inverse);
if (norm_mode == FFTNormMode::none) {
return 1.0;
@ -115,7 +115,7 @@ bool FFTWithSizeGpuKernelMod::FFTVarietyInResize(const std::vector<KernelTensor
rank_ = inputs[kIndex1]->GetValueWithCheck<int64_t>();
is_inverse_ = inputs[kIndex2]->GetValueWithCheck<bool>();
is_real_ = inputs[kIndex3]->GetValueWithCheck<bool>();
norm_type_ = static_cast<MsPyEnum::NormMode>(inputs[kIndex4]->GetValueWithCheck<int64_t>());
norm_type_ = static_cast<ops::NormMode>(inputs[kIndex4]->GetValueWithCheck<int64_t>());
is_onesided_ = inputs[kIndex5]->GetValueWithCheck<bool>();
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);

View File

@ -26,7 +26,7 @@
#include <memory>
#include <algorithm>
#include <functional>
#include "ops/auto_generate/gen_enum_def.h"
#include "ops/op_enum.h"
#include "plugin/device/gpu/kernel/gpu_kernel.h"
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/fft_with_size_impl.cuh"
@ -89,7 +89,7 @@ class FFTWithSizeGpuKernelMod : public NativeGpuKernelMod {
int64_t rank_{1};
bool is_inverse_{false}; // 0: forward, 1: inverse
bool is_real_{false};
MsPyEnum::NormMode norm_type_{MsPyEnum::NormMode::BACKWARD}; // forward, backward, ortho
ops::NormMode norm_type_{ops::NormMode::BACKWARD}; // forward, backward, ortho
// is_onesided controls whether frequency is halved when signal is real, which means is_real_ is true.
// The default value is true. cufft does not support full freq with real signal. We use cast as a temporary solution.
bool is_onesided_{true};

View File

@ -26,7 +26,7 @@
#include <functional>
#include "mindspore/core/ops/ops_func_impl/grid_sampler_2d.h"
#include "mindspore/core/ops/ops_func_impl/grid_sampler_3d.h"
#include "mindspore/core/ops/auto_generate/gen_enum_def.h"
#include "mindspore/core/ops/op_enum.h"
#include "plugin/device/gpu/kernel/gpu_kernel.h"
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/grid_sampler_impl.cuh"

View File

@ -24,7 +24,7 @@
#include <algorithm>
#include "mindspore/core/ops/ops_func_impl/grid_sampler_2d_grad.h"
#include "mindspore/core/ops/ops_func_impl/grid_sampler_3d_grad.h"
#include "mindspore/core/ops/auto_generate/gen_enum_def.h"
#include "mindspore/core/ops/op_enum.h"
#include "plugin/device/gpu/kernel/gpu_kernel.h"
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/grid_sampler_grad_impl.cuh"

View File

@ -42,7 +42,7 @@ int NLLLossGpuKernelMod::Resize(const std::vector<KernelTensor *> &inputs, const
return ret;
}
auto reduction = inputs[kReductionIdx]->GetValueWithCheck<int64_t>();
reduction_ = static_cast<ReductionMode>(reduction);
reduction_ = kEnumReductionModeMap[reduction];
ignore_index_ = inputs[kIgnoreIndexIdx]->GetValueWithCheck<int64_t>();
auto logits_shape = inputs[kIndex0]->GetShapeVector();
label_size_ = logits_shape[0];

View File

@ -45,7 +45,7 @@ int NLLLossGradGpuKernelMod::Resize(const std::vector<KernelTensor *> &inputs,
return ret;
}
auto reduction = inputs[kReductionIdx]->GetValueWithCheck<int64_t>();
reduction_ = static_cast<ReductionMode>(reduction);
reduction_ = kEnumReductionModeMap[reduction];
ignore_index_ = inputs[kIgnoreIndexIdx]->GetValueWithCheck<int64_t>();
auto logits_shape = inputs[kIndex0]->GetShapeVector();

View File

@ -16,7 +16,7 @@
#include "plugin/device/gpu/kernel/nn/resize_linear_1d_gpu_kernel.h"
#include "mindspore/core/abstract/utils.h"
#include "ops/auto_generate/gen_enum_def.h"
#include "mindapi/base/types.h"
namespace {
constexpr const size_t kResizeLinear1DInputsNum = 3;
@ -58,10 +58,9 @@ int ResizeLinear1DGpuKernelMod::Resize(const std::vector<KernelTensor *> &inputs
out_width_ = output_shape_[kIndex2];
auto coordinate_transformation_mode = inputs.at(kIndex2)->GetValueWithCheck<int64_t>();
if (coordinate_transformation_mode == static_cast<int64_t>(MsPyEnum::CoordinateTransformationMode::ALIGN_CORNERS)) {
if (coordinate_transformation_mode == static_cast<int64_t>(CoordinateTransformMode::ALIGN_CORNERS)) {
mode_ = ResizeLinearCoordinateTransformationMode::ALIGN_CORNERS;
} else if (coordinate_transformation_mode ==
static_cast<int64_t>(MsPyEnum::CoordinateTransformationMode::HALF_PIXEL)) {
} else if (coordinate_transformation_mode == static_cast<int64_t>(CoordinateTransformMode::HALF_PIXEL)) {
mode_ = ResizeLinearCoordinateTransformationMode::HALF_PIXEL;
} else {
MS_LOG_EXCEPTION << "For '" << kernel_name_ << "', coordinate_transformation_mode not support now.";

View File

@ -17,7 +17,7 @@
#include "plugin/device/gpu/kernel/nn/resize_linear_1d_grad_gpu_kernel.h"
#include "mindspore/core/abstract/utils.h"
#include "ops/ops_func_impl/resize_linear_1d_grad.h"
#include "ops/auto_generate/gen_enum_def.h"
#include "mindapi/base/types.h"
namespace {
constexpr const size_t kResizeLinear1DGradInputsNum = 3;
@ -64,10 +64,9 @@ int ResizeLinear1DGradGpuKernelMod::Resize(const std::vector<KernelTensor *> &in
workspace_size_list_.push_back(work_space_size * sizeof(float));
auto coordinate_transformation_mode = inputs.at(kIndex2)->GetValueWithCheck<int64_t>();
if (coordinate_transformation_mode == static_cast<int64_t>(MsPyEnum::CoordinateTransformationMode::ALIGN_CORNERS)) {
if (coordinate_transformation_mode == static_cast<int64_t>(CoordinateTransformMode::ALIGN_CORNERS)) {
mode_ = ResizeLinearCoordinateTransformationMode::ALIGN_CORNERS;
} else if (coordinate_transformation_mode ==
static_cast<int64_t>(MsPyEnum::CoordinateTransformationMode::HALF_PIXEL)) {
} else if (coordinate_transformation_mode == static_cast<int64_t>(CoordinateTransformMode::HALF_PIXEL)) {
mode_ = ResizeLinearCoordinateTransformationMode::HALF_PIXEL;
} else {
MS_LOG_EXCEPTION << "For '" << kernel_name_ << "', coordinate_transformation_mode not support now.";

View File

@ -20,7 +20,6 @@
#include <string>
#include "mindspore/core/ops/nn_ops.h"
#include "mindspore/core/ops/math_ops.h"
#include "mindspore/core/ops/auto_generate/gen_enum_def.h"
#include "mindspore/core/ops/op_utils.h"
#include "include/backend/anf_runtime_algorithm.h"
#include "include/common/utils/anfalgo.h"

View File

@ -0,0 +1,73 @@
/**
* Copyright 2020-2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "mindspore/core/ops/op_enum.h"
#include "mindspore/core/ops/op_def.h"
#include "mindspore/core/mindapi/base/format.h"
#include "include/common/pybind_api/api_register.h"
namespace mindspore::ops {
void RegOpEnum(py::module *m) {
auto m_sub = m->def_submodule("op_enum", "submodule for op enum");
(void)m_sub.def("str_to_enum", &StringToEnumImpl, "string to enum value");
(void)py::enum_<OP_DTYPE>(*m, "OpDtype", py::arithmetic())
.value("DT_BEGIN", OP_DTYPE::DT_BEGIN)
.value("DT_BOOL", OP_DTYPE::DT_BOOL)
.value("DT_INT", OP_DTYPE::DT_INT)
.value("DT_FLOAT", OP_DTYPE::DT_FLOAT)
.value("DT_NUMBER", OP_DTYPE::DT_NUMBER)
.value("DT_TENSOR", OP_DTYPE::DT_TENSOR)
.value("DT_STR", OP_DTYPE::DT_STR)
.value("DT_ANY", OP_DTYPE::DT_ANY)
.value("DT_TUPLE_BOOL", OP_DTYPE::DT_TUPLE_BOOL)
.value("DT_TUPLE_INT", OP_DTYPE::DT_TUPLE_INT)
.value("DT_TUPLE_FLOAT", OP_DTYPE::DT_TUPLE_FLOAT)
.value("DT_TUPLE_NUMBER", OP_DTYPE::DT_TUPLE_NUMBER)
.value("DT_TUPLE_TENSOR", OP_DTYPE::DT_TUPLE_TENSOR)
.value("DT_TUPLE_STR", OP_DTYPE::DT_TUPLE_STR)
.value("DT_TUPLE_ANY", OP_DTYPE::DT_TUPLE_ANY)
.value("DT_LIST_BOOL", OP_DTYPE::DT_LIST_BOOL)
.value("DT_LIST_INT", OP_DTYPE::DT_LIST_INT)
.value("DT_LIST_FLOAT", OP_DTYPE::DT_LIST_FLOAT)
.value("DT_LIST_NUMBER", OP_DTYPE::DT_LIST_NUMBER)
.value("DT_LIST_TENSOR", OP_DTYPE::DT_LIST_TENSOR)
.value("DT_LIST_STR", OP_DTYPE::DT_LIST_STR)
.value("DT_LIST_ANY", OP_DTYPE::DT_LIST_ANY)
.value("DT_TYPE", OP_DTYPE::DT_TYPE)
.value("DT_END", OP_DTYPE::DT_END);
// There are currently some deficiencies in format, which will be filled in later.
(void)py::enum_<Format>(*m, "FormatEnum", py::arithmetic())
.value("DEFAULT_FORMAT", Format::DEFAULT_FORMAT)
.value("NCHW", Format::NCHW)
.value("NHWC", Format::NHWC)
.value("NHWC4", Format::NHWC4)
.value("HWKC", Format::HWKC)
.value("HWCK", Format::HWCK)
.value("KCHW", Format::KCHW)
.value("CKHW", Format::CKHW)
.value("KHWC", Format::KHWC)
.value("CHWK", Format::CHWK)
.value("HW", Format::HW)
.value("HW4", Format::HW4)
.value("NC", Format::NC)
.value("NC4", Format::NC4)
.value("NC4HW4", Format::NC4HW4)
.value("NCDHW", Format::NCDHW)
.value("NWC", Format::NWC)
.value("NCW", Format::NCW)
.value("NDHWC", Format::NDHWC)
.value("NC8HW8", Format::NC8HW8);
}
} // namespace mindspore::ops

View File

@ -238,7 +238,7 @@ class GEPadMod {
class GEReduction {
public:
static std::string ConvertEnumToString(int64_t id) {
static const std::vector<std::string> reductions = {"none", "mean", "sum", "add"};
static const std::vector<std::string> reductions = {"sum", "mean", "none"};
if (id < 0 || id >= static_cast<int64_t>(reductions.size())) {
MS_LOG(EXCEPTION) << "Invalid reduction " << id;
return "";

View File

@ -86,7 +86,7 @@ enum Format : int64_t {
NUM_OF_FORMAT
};
inline std::string FormatEnumToString(mindspore::Format format) {
inline const std::vector<std::string> &GetFormatNames() {
static std::vector<std::string> names = {
"NCHW",
"NHWC",
@ -148,6 +148,11 @@ inline std::string FormatEnumToString(mindspore::Format format) {
"NYUV_A",
"NCL",
};
return names;
}
inline std::string FormatEnumToString(mindspore::Format format) {
const auto &names = GetFormatNames();
if (format == mindspore::Format::DEFAULT_FORMAT) {
return "DefaultFormat";
}

View File

@ -16,8 +16,6 @@
#ifndef MINDSPORE_CORE_MINDAPI_BASE_TYPE_ID_H_
#define MINDSPORE_CORE_MINDAPI_BASE_TYPE_ID_H_
// TODO(chenfei): remove the TypeId by include the auto generated type id.
// #include "mindspore/core/ops/auto_generate/gen_enum_def.h"
namespace mindspore {
/// \brief TypeId defines data type identifiers.

View File

@ -0,0 +1,164 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/op_enum.h"
#include <algorithm>
#include <utility>
#include "mindapi/base/types.h"
#include "utils/check_convert_utils.h"
#include "mindapi/base/format.h"
namespace mindspore {
namespace ops {
namespace {
using StrToEnumMap = std::unordered_map<std::string, int64_t>;
class RegStringToEnumHelper {
public:
template <typename T>
std::string AddValues(T &&string_to_enum) {
for (const auto &kv : string_to_enum) {
if (string_to_enum_.find(kv.first) != string_to_enum_.end()) {
MS_LOG_EXCEPTION << kv.first << " has been registered!";
}
}
string_to_enum_.merge(std::move(string_to_enum));
return "";
}
const StrToEnumMap &GetValues() { return string_to_enum_; }
private:
StrToEnumMap string_to_enum_;
};
RegStringToEnumHelper reg_string_to_enum_helper;
#define REG_STRING_TO_ENUM(enum_type, ...) \
const auto op_enum_##enum_type = reg_string_to_enum_helper.AddValues(__VA_ARGS__);
// Convert to uppercase uniformly
inline std::string StrToUpper(const std::string &str) {
auto res = str;
for (auto &c : res) {
c = std::toupper(c);
}
return res;
}
// Format
inline std::unordered_map<std::string, int64_t> GetStringToFormatMap() {
const auto &names = GetFormatNames();
std::unordered_map<std::string, int64_t> map{{"DEFAULT_FORMAT", static_cast<int64_t>(Format::DEFAULT_FORMAT)}};
for (size_t i = 0; i < names.size(); ++i) {
map[StrToUpper(names[i])] = static_cast<int64_t>(i);
}
return map;
}
REG_STRING_TO_ENUM(format, GetStringToFormatMap())
// PadMode
StrToEnumMap StrToPadModeMap = {{"PAD", PadMode::PAD}, {"SAME", PadMode::SAME}, {"VALID", PadMode::VALID}};
REG_STRING_TO_ENUM(pad_mode, StrToPadModeMap)
// Reduction
StrToEnumMap StrToReductionMap = {
{"SUM", Reduction::REDUCTION_SUM}, {"MEAN", Reduction::MEAN}, {"NONE", Reduction::NONE}};
REG_STRING_TO_ENUM(reduction, StrToReductionMap)
// Activation
StrToEnumMap StrToActivationMap = {{"NO_ACTIVATION", ActivationType::NO_ACTIVATION},
{"RELU", ActivationType::RELU},
{"SIGMOID", ActivationType::SIGMOID},
{"RELU6", ActivationType::RELU6},
{"ELU", ActivationType::ELU},
{"LEAKY_RELU", ActivationType::LEAKY_RELU},
{"ABS", ActivationType::ABS},
{"RELU1", ActivationType::RELU1},
{"SOFTSIGN", ActivationType::SOFTSIGN},
{"SOFTPLUS", ActivationType::SOFTPLUS},
{"TANH", ActivationType::TANH},
{"SELU", ActivationType::SELU},
{"HSWISH", ActivationType::HSWISH},
{"HSIGMOID", ActivationType::HSIGMOID},
{"THRESHOLDRELU", ActivationType::THRESHOLDRELU},
{"LINEAR", ActivationType::LINEAR},
{"HARD_TANH", ActivationType::HARD_TANH},
{"SIGN", ActivationType::SIGN},
{"SWISH", ActivationType::SWISH},
{"GELU", ActivationType::GELU},
{"GLU", ActivationType::GLU},
{"UNKNOWN", ActivationType::UNKNOWN}};
REG_STRING_TO_ENUM(activation, StrToActivationMap)
// GateOrder
REG_STRING_TO_ENUM(gate_order, StrToEnumMap{{"RZH", GateOrderMode::RZH}, {"ZRH", GateOrderMode::ZRH}})
// CoordinateTransformationMode
StrToEnumMap StrToCoordinateTransformationModeMap = {{"ASYMMETRIC", CoordinateTransformMode::ASYMMETRIC},
{"ALIGN_CORNERS", CoordinateTransformMode::ALIGN_CORNERS},
{"HALF_PIXEL", CoordinateTransformMode::HALF_PIXEL},
{"CROP_AND_RESIZE", CoordinateTransformMode::CROP_AND_RESIZE}};
REG_STRING_TO_ENUM(coordinate_transformation_mode, StrToCoordinateTransformationModeMap)
// PaddingMode
StrToEnumMap StrToPaddingModeMap = {{"CONSTANT", PaddingMode::CONSTANT},
{"REFLECT", PaddingMode::REFLECT},
{"SYMMETRIC", PaddingMode::SYMMETRIC},
{"MODE_RESERVED", PaddingMode::MODE_RESERVED}};
REG_STRING_TO_ENUM(padding_mode, StrToPaddingModeMap)
// Direction
REG_STRING_TO_ENUM(direction, StrToEnumMap{{"UNIDIRECTIONAL", Direction::UNIDIRECTIONAL}})
// CellType
REG_STRING_TO_ENUM(cell_type, StrToEnumMap{{"LSTM", CellType::LSTM}})
// Group
REG_STRING_TO_ENUM(group, StrToEnumMap{{"SYNC_BN_GROUP0", Group::SYNC_BN_GROUP0}})
// InterpolationMode
REG_STRING_TO_ENUM(interpolation_mode,
StrToEnumMap{{"BILINEAR", InterpolationMode::BILINEAR}, {"NEAREST", InterpolationMode::NEAREST}})
// NormMode
StrToEnumMap StrToNormModeMap = {
{"BACKWARD", NormMode::BACKWARD}, {"FORWARD", NormMode::FORWARD}, {"ORTHO", NormMode::ORTHO}};
REG_STRING_TO_ENUM(norm_mode, StrToNormModeMap)
// GridSamplerPaddingMode
StrToEnumMap StrToGridSamplerPaddingMode = {{"ZEROS", GridSamplerPaddingMode::ZEROS},
{"BORDER", GridSamplerPaddingMode::BORDER},
{"REFLECTION", GridSamplerPaddingMode::REFLECTION}};
REG_STRING_TO_ENUM(grid_sampler_padding_mode, StrToGridSamplerPaddingMode)
// KVCacheAlignMode
REG_STRING_TO_ENUM(k_v_cache_align_mode,
StrToEnumMap{{"LEFT", KVCacheAlignMode::LEFT}, {"RIGHT", KVCacheAlignMode::RIGHT}})
} // namespace
int64_t StringToEnumImpl(const std::string &enum_string) {
const auto &string_to_enum_map = reg_string_to_enum_helper.GetValues();
const auto enum_val_iter = string_to_enum_map.find(StrToUpper(enum_string));
if (enum_val_iter == string_to_enum_map.end()) {
MS_LOG_EXCEPTION << "Can not find '" << enum_string << "', please add it";
}
return enum_val_iter->second;
}
} // namespace ops
} // namespace mindspore

View File

@ -0,0 +1,49 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_OP_ENUM_H_
#define MINDSPORE_CORE_OPS_OP_ENUM_H_
#include <string>
#include <vector>
#include <memory>
#include <unordered_map>
#include "mindapi/base/macros.h"
namespace mindspore {
namespace ops {
MIND_API int64_t StringToEnumImpl(const std::string &enum_string);
// Only the current mindspore/core/mindapi/base/types.h and other files do not have
// corresponding enumerations and then add new enumerations.
// The `enum` is used here instead of `enum class` because the current backend enum is
// represented by `int`. The `enum` is more convenient than `enum class` compare with int.
enum Direction : int64_t { UNIDIRECTIONAL = 0 };
enum CellType : int64_t { LSTM = 0 };
enum Group : int64_t { SYNC_BN_GROUP0 = 0 };
enum InterpolationMode : int64_t { BILINEAR = 0, NEAREST = 1 };
enum NormMode : int64_t { BACKWARD = 0, FORWARD = 1, ORTHO = 2 };
enum GridSamplerPaddingMode : int64_t { ZEROS = 0, BORDER = 1, REFLECTION = 2 };
enum KVCacheAlignMode : int64_t { RIGHT = 0, LEFT = 1 };
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_ENUM_H_

View File

@ -8,10 +8,10 @@ argmax:
default: -1
prim_init: True
output_type:
dtype: int
dtype: TypeId
default: mstype.int32
prim_init: True
arg_handler: dtype_to_enum
arg_handler: dtype_to_type_id
returns:
output:
dtype: tensor

View File

@ -8,10 +8,10 @@ argmin:
default: -1
prim_init: True
output_type:
dtype: int
dtype: TypeId
default: mstype.int32
prim_init: True
arg_handler: dtype_to_enum
arg_handler: dtype_to_type_id
returns:
output:
dtype: tensor

View File

@ -21,12 +21,12 @@ avg_pool_grad:
dtype: int
default: "'VALID'"
prim_init: True
arg_handler: pad_mode_to_enum
arg_handler: str_to_enum
data_format:
dtype: int
default: "'NCHW'"
prim_init: True
arg_handler: py_format_to_enum
arg_handler: str_to_enum
returns:
output:
dtype: tensor

View File

@ -17,12 +17,12 @@ avg_pool:
dtype: int
default: "'VALID'"
prim_init: True
arg_handler: pad_mode_to_enum
arg_handler: str_to_enum
data_format:
dtype: int
default: "'NCHW'"
prim_init: True
arg_handler: py_format_to_enum
arg_handler: str_to_enum
returns:
output:
dtype: tensor

View File

@ -28,7 +28,7 @@ batch_norm_grad_grad:
dtype: int
default: "'NCHW'"
prim_init: True
arg_handler: py_format_to_enum
arg_handler: str_to_enum
returns:
dx:
dtype: tensor

View File

@ -24,7 +24,7 @@ batch_norm_grad:
dtype: int
default: "'NCHW'"
prim_init: True
arg_handler: py_format_to_enum
arg_handler: str_to_enum
returns:
dx:
dtype: tensor

View File

@ -28,7 +28,7 @@ batch_norm_grad_with_activation:
dtype: int
default: "'NCHW'"
prim_init: True
arg_handler: py_format_to_enum
arg_handler: str_to_enum
returns:
dx:
dtype: tensor

View File

@ -28,7 +28,7 @@ batch_norm_grad_with_activation:
dtype: int
default: "'NCHW'"
prim_init: True
arg_handler: py_format_to_enum
arg_handler: str_to_enum
returns:
dx:
dtype: tensor

View File

@ -26,7 +26,7 @@ batch_norm:
dtype: int
default: "'NCHW'"
prim_init: True
arg_handler: py_format_to_enum
arg_handler: str_to_enum
returns:
output_x:
dtype: tensor

View File

@ -26,7 +26,7 @@ batch_norm_with_activation:
dtype: int
default: "'NCHW'"
prim_init: True
arg_handler: py_format_to_enum
arg_handler: str_to_enum
returns:
output_x:
dtype: tensor

View File

@ -28,7 +28,7 @@ batch_norm_with_add_and_activation:
dtype: int
default: "'NCHW'"
prim_init: True
arg_handler: py_format_to_enum
arg_handler: str_to_enum
returns:
output_x:
dtype: tensor

View File

@ -7,7 +7,7 @@ bias_add_grad:
dtype: int
default: "'NCHW'"
prim_init: True
arg_handler: py_format_to_enum
arg_handler: str_to_enum
returns:
output:
dtype: tensor

View File

@ -9,7 +9,7 @@ bias_add:
dtype: int
default: "'NCHW'"
prim_init: True
arg_handler: py_format_to_enum
arg_handler: str_to_enum
returns:
output:
dtype: tensor

View File

@ -6,8 +6,8 @@ eye:
m:
dtype: int
dtype:
dtype: int
arg_handler: dtype_to_enum
dtype: TypeId
arg_handler: dtype_to_type_id
returns:
output:
dtype: tensor

View File

@ -14,7 +14,7 @@ fft_with_size:
prim_init: True
norm:
dtype: int
arg_handler: norm_mode_to_enum
arg_handler: str_to_enum
default: "'backward'"
prim_init: True
onesided:

View File

@ -11,12 +11,12 @@ grid_sampler_2d_grad:
dtype: int
default: "'bilinear'"
prim_init: True
arg_handler: interpolation_mode_to_enum
arg_handler: str_to_enum
padding_mode:
dtype: int
default: "'zeros'"
prim_init: True
arg_handler: grid_sampler_padding_mode_to_enum
arg_handler: str_to_enum
align_corners:
dtype: bool
default: False

View File

@ -9,12 +9,12 @@ grid_sampler_2d:
dtype: int
default: "'bilinear'"
prim_init: True
arg_handler: interpolation_mode_to_enum
arg_handler: str_to_enum
padding_mode:
dtype: int
default: "'zeros'"
prim_init: True
arg_handler: grid_sampler_padding_mode_to_enum
arg_handler: str_to_enum
align_corners:
dtype: bool
default: False

View File

@ -11,12 +11,12 @@ grid_sampler_3d_grad:
dtype: int
default: "'bilinear'"
prim_init: True
arg_handler: interpolation_mode_to_enum
arg_handler: str_to_enum
padding_mode:
dtype: int
default: "'zeros'"
prim_init: True
arg_handler: grid_sampler_padding_mode_to_enum
arg_handler: str_to_enum
align_corners:
dtype: bool
default: False

View File

@ -9,18 +9,18 @@ grid_sampler_3d:
dtype: int
default: "'bilinear'"
prim_init: True
arg_handler: interpolation_mode_to_enum
arg_handler: str_to_enum
padding_mode:
dtype: int
default: "'zeros'"
prim_init: True
arg_handler: grid_sampler_padding_mode_to_enum
arg_handler: str_to_enum
align_corners:
dtype: bool
default: False
prim_init: True
returns:
output:
output:
dtype: tensor
class:
name: GridSampler3D

View File

@ -22,7 +22,7 @@ extract_image_patches:
dtype: int
default: "'VALID'"
prim_init: True
arg_handler: pad_mode_to_enum
arg_handler: str_to_enum
returns:
output:
dtype: tensor

View File

@ -19,7 +19,7 @@ prompt_k_v_cache:
dtype: int
default: 0
prim_init: True
arg_handler: k_v_cache_align_mode_to_enum
arg_handler: str_to_enum
labels:
side_effect_mem: True
returns:

View File

@ -15,7 +15,7 @@ nllloss_grad:
dtype: int
default: "'mean'"
prim_init: True
arg_handler: reduction_to_enum
arg_handler: str_to_enum
ignore_index:
dtype: int
default: -100

View File

@ -11,7 +11,7 @@ nllloss:
dtype: int
default: "'mean'"
prim_init: True
arg_handler: reduction_to_enum
arg_handler: str_to_enum
ignore_index:
dtype: int
default: -100

View File

@ -13,10 +13,10 @@ randperm_v2:
default: 0
prim_init: True
dtype:
dtype: int
dtype: TypeId
default: mstype.int64
prim_init: True
arg_handler: dtype_to_enum
arg_handler: dtype_to_type_id
returns:
output:
dtype: tensor

View File

@ -9,7 +9,7 @@ resize_linear_1d_grad:
dtype: int
default: "'align_corners'"
prim_init: True
arg_handler: coordinate_transformation_mode_to_enum
arg_handler: str_to_enum
returns:
output:
dtype: tensor

View File

@ -10,7 +10,7 @@ resize_linear_1d:
dtype: int
default: "'align_corners'"
prim_init: True
arg_handler: coordinate_transformation_mode_to_enum
arg_handler: str_to_enum
returns:
output:
dtype: tensor

View File

@ -22,7 +22,8 @@
namespace mindspore {
namespace ops {
namespace {
constexpr size_t kDefaultRank = 2;
// As kDefaultRank will have a confilict problem during lite compile, use kDefaultShapeSize.
constexpr size_t kDefaultShapeSize = 2;
constexpr size_t kRowIndex = 2;
constexpr size_t kColIndex = 1;
} // namespace
@ -32,7 +33,7 @@ void EigCheckShapeValid(const ShapeVector &input_shape) {
return;
}
if (input_shape.size() < kDefaultRank) {
if (input_shape.size() < kDefaultShapeSize) {
MS_EXCEPTION(ValueError) << "For Eig, x should be at lease 2"
<< ", but got a " << input_shape.size() << "-D Tensor.";
}
@ -52,7 +53,7 @@ void EigCheckShapeValid(const ShapeVector &input_shape) {
BaseShapePtr EigFuncImpl::InferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) const {
auto input_shape = input_args[kInputIndex0]->GetShape()->GetShapeVector();
std::vector<BaseShapePtr> shapes_list(kDefaultRank);
std::vector<BaseShapePtr> shapes_list(kDefaultShapeSize);
EigCheckShapeValid(input_shape);
/* infer eigen_value shape */

View File

@ -20,7 +20,6 @@
#include <unordered_map>
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "ops/auto_generate/gen_enum_def.h"
namespace mindspore {
namespace ops {

View File

@ -18,7 +18,7 @@
#include <algorithm>
#include <memory>
#include "abstract/dshape.h"
#include "ops/auto_generate/gen_enum_def.h"
#include "mindapi/base/types.h"
#include "ops/op_name.h"
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
@ -86,15 +86,15 @@ BaseShapePtr NLLLossFuncImpl::InferShape(const PrimitivePtr &primitive,
std::vector<abstract::BaseShapePtr>{loss_shape_ptr, total_weight_shape_ptr});
}
auto reduce_value_enum = static_cast<MsPyEnum::Reduction>(reduction_opt.value());
if ((reduce_value_enum == MsPyEnum::Reduction::SUM) || (reduce_value_enum == MsPyEnum::Reduction::MEAN)) {
auto reduce_value_enum = static_cast<Reduction>(reduction_opt.value());
if ((reduce_value_enum == Reduction::REDUCTION_SUM) || (reduce_value_enum == Reduction::MEAN)) {
// shape () means 0D tensor.
auto loss_shape_ptr = std::make_shared<abstract::TensorShape>(loss_shape);
return std::make_shared<abstract::TupleShape>(
std::vector<abstract::BaseShapePtr>{loss_shape_ptr, total_weight_shape_ptr});
}
if (reduce_value_enum == MsPyEnum::Reduction::NONE) {
if (reduce_value_enum == Reduction::NONE) {
if (logits_shape.size() == DIM_2) {
loss_shape.push_back(
std::max({logits_shape[kInputIndex0], labels_shape[kInputIndex0], abstract::Shape::kShapeDimAny}));

View File

@ -103,6 +103,9 @@ GVAR_DEF(PrimitivePtr, kPrimNPUAntiQuant, std::make_shared<Primitive>("AscendAnt
// Fusion Inference OP
GVAR_DEF(PrimitivePtr, kPrimFFN, std::make_shared<Primitive>("FFN"));
// ToEnum OP
GVAR_DEF(PrimitivePtr, kPrimStringToEnum, std::make_shared<Primitive>("StringToEnum"));
} // namespace prim
} // namespace mindspore

View File

@ -0,0 +1,65 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/string_to_enum.h"
#include <memory>
#include <string>
#include <vector>
#include "abstract/ops/primitive_infer_map.h"
#include "ir/dtype/tensor_type.h"
#include "ir/primitive.h"
#include "utils/log_adapter.h"
#include "utils/check_convert_utils.h"
#include "ops/other_ops.h"
#include "mindapi/src/helper.h"
#include "ops/op_enum.h"
namespace mindspore {
namespace ops {
MIND_API_OPERATOR_IMPL(StringToEnum, BaseOperator);
class MIND_API StringToEnumInfer : public abstract::OpInferBase {
public:
BaseShapePtr InferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) const override {
return std::make_shared<abstract::TensorShape>();
}
TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const override {
return kInt64;
}
ValuePtr InferValue(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const override {
MS_EXCEPTION_IF_NULL(primitive);
const auto &op_name = primitive->name();
(void)CheckAndConvertUtils::CheckInteger("StringToEnum infer", int64_t(input_args.size()), kEqual, 1, op_name);
MS_EXCEPTION_IF_NULL(input_args[0]);
const auto &input_value = input_args[0]->GetValue();
MS_EXCEPTION_IF_NULL(input_value);
if (!input_value->isa<StringImm>()) {
MS_LOG_EXCEPTION << "Currently, for " << op_name << ", input value must a string value";
}
const auto &enum_str = GetValue<std::string>(input_value);
const auto enum_int = StringToEnumImpl(enum_str);
return MakeValue(enum_int);
}
};
REGISTER_PRIMITIVE_OP_INFER_IMPL(StringToEnum, prim::kPrimStringToEnum, StringToEnumInfer, true);
} // namespace ops
} // namespace mindspore

View File

@ -0,0 +1,39 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_STRING_TO_ENUM_H_
#define MINDSPORE_CORE_OPS_STRING_TO_ENUM_H_
#include <string>
#include "mindapi/base/types.h"
#include "ops/base_operator.h"
namespace mindspore {
namespace ops {
constexpr auto kNameStringToEnum = "StringToEnum";
/// \brief Returns the enum value of the input string.
class MIND_API StringToEnum : public BaseOperator {
public:
MIND_API_BASE_MEMBER(StringToEnum);
/// \brief Constructor.
StringToEnum() : BaseOperator(kNameStringToEnum) { InitIOName({"x"}, {"output"}); }
/// \brief Init. Refer to the parameters of Python API @ref mindspore.ops.StringToEnum for the inputs.
void Init() const {}
};
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_STRING_TO_ENUM_H_

View File

@ -20,7 +20,6 @@
#include "test/ut/tools/optimizer/fusion/fusion_inout_test/conv_fusion_inout_test.h"
#include "plugin/device/cpu/kernel/nnacl/op_base.h"
#include "ops/auto_generate/gen_lite_ops.h"
#include "ops/auto_generate/gen_enum_def.h"
namespace mindspore {
class ConvBiasFusionInoutTest : public ConvFusionInoutTest {

View File

@ -25,7 +25,7 @@ from mindspore.ops.primitive import _primexpr
from mindspore.ops.function import _VmapGeneralRule
from mindspore.ops._vmap.vmap_base import vmap_rules_getters, vmap_general_preprocess, _raise_value_error, \
_bdim_at_front, _vmap_clone_prim, _bdim_at_any, _handle_broadcasting
from mindspore.ops.auto_generate.gen_enum_def import PyFormat
from mindspore._c_expression import FormatEnum as Format
@vmap_rules_getters.register(G.NLLLossGrad)
@ -306,7 +306,7 @@ def get_adaptive_avgpool2d_vmap_rule(prim, axis_size):
@vmap_rules_getters.register(G.BatchNormGradGrad)
def get_batchnorm_grad_grad_vmap_rule(prim, axis_size):
"""VmapRule for `BatchNormGradGrad` operation."""
NCHW = PyFormat.NCHW.value
NCHW = Format.NCHW
def vmap_rule(x_bdim, dy_bdim, scale_bdim, mean_bdim, variance_bdim, dout_dx_bdim,
dout_dscale_bdim, dout_dbias_bdim, is_training_bdim, epsilon_bdim, data_format_bdim):
@ -384,7 +384,7 @@ def get_batchnorm_grad_vmap_rule(prim, axis_size):
bn_min_dim = 3
bn_max_dim = 5
prim_name = prim.name
NHWC = PyFormat.NHWC.value
NHWC = Format.NHWC
def vmap_rule(grad_bdim, x_bdim, scale_bdim, rsv_1_bdim, rsv_2_bdim,
rsv_3_bdim, training_bdim, epsilon_bdim, format_bdim):

View File

@ -29,7 +29,7 @@ from mindspore.ops._vmap.vmap_base import vmap_rules_getters, vmap_general_prepr
_bdim_at_any, _bdim_at_front, _bdim_at_back, _handle_broadcasting, get_unary_grad_vmap_rule, _raise_value_error, \
_vmap_clone_prim, _get_reduce_batch_axis
from mindspore.ops.primitive import Primitive
from mindspore.ops.auto_generate.gen_enum_def import PyFormat
from mindspore._c_expression import FormatEnum as Format
@vmap_rules_getters.register(P.ApplyAdaMax)
@ -372,7 +372,7 @@ def get_bias_add_vmap_rule(prim, axis_size):
@constexpr
def get_channal_pos_in_x(d_format, n_dims):
if d_format == PyFormat.NHWC:
if d_format == Format.NHWC:
return n_dims
return 2
@ -424,7 +424,7 @@ def get_bias_add_grad_vmap_rule(prim, axis_size):
"""VmapRule for `BiasAddGrad` operation."""
@constexpr
def get_channal_pos(d_format, x_rank):
if d_format == PyFormat.NHWC:
if d_format == Format.NHWC:
return x_rank
return 2
@ -1249,7 +1249,7 @@ def get_batchnorm_vmap_rule(prim, axis_size):
bn_min_dim = 3
bn_max_dim = 5
prim_name = "BatchNorm"
NCHW = PyFormat.NCHW.value
NCHW = Format.NCHW
def vmap_rule(*inputs):
is_all_none, result = vmap_general_preprocess(prim, *inputs)

View File

@ -19,12 +19,11 @@ Primitive operator classes and operator functional.
A collection of operators to build neural networks or to compute functions.
"""
from . import gen_ops_def, gen_arg_handler, gen_enum_def, gen_arg_dtype_cast
from . import gen_ops_def, gen_arg_handler, gen_arg_dtype_cast
from .gen_ops_def import *
from .gen_arg_handler import *
from .gen_arg_dtype_cast import *
from .gen_enum_def import *
from ..operations.manually_defined.ops_def import *
from .gen_inner_ops_def import *

View File

@ -572,7 +572,7 @@ class BatchNorm(Primitive):
self.is_training = is_training
self.epsilon = epsilon
self.momentum = momentum
self.data_format = handler.py_format_to_enum(data_format)
self.data_format = handler.str_to_enum(data_format)
def __call__(self, *args):
return super().__call__(*args, self.is_training, self.epsilon,

View File

@ -20,7 +20,7 @@ import mindspore as ms
from mindspore import ops
from mindspore.common.tensor import Tensor
from mindspore.ops.operations._sequence_ops import TensorToScalar, TensorToTuple
from mindspore.ops.auto_generate.gen_enum_def import OpDtype
from mindspore._c_expression import OpDtype
tensor_to_tuple_ = TensorToTuple()
@ -67,54 +67,53 @@ def tuple_to_tensor(data):
def list_to_tensor(data):
return ops.tuple_to_array(list_to_tuple(data))
# There will be some problems in using OpDtype.xxx directly in GRAPH_MODE, so convert it to int.
# type
PY_DT_TYPE = OpDtype.PY_DT_TYPE.value
DT_TYPE_VAL = int(OpDtype.DT_TYPE)
# scalar
PY_DT_INT = OpDtype.PY_DT_INT.value
PY_DT_FLOAT = OpDtype.PY_DT_FLOAT.value
PY_DT_BOOL = OpDtype.PY_DT_BOOL.value
PY_DT_NUMBER = OpDtype.PY_DT_NUMBER.value
DT_INT_VAL = int(OpDtype.DT_INT)
DT_FLOAT_VAL = int(OpDtype.DT_FLOAT)
DT_BOOL_VAL = int(OpDtype.DT_BOOL)
DT_NUMBER_VAL = int(OpDtype.DT_NUMBER)
# tuple
PY_DT_TUPLE_BOOL = OpDtype.PY_DT_TUPLE_BOOL.value
PY_DT_TUPLE_INT = OpDtype.PY_DT_TUPLE_INT.value
PY_DT_TUPLE_FLOAT = OpDtype.PY_DT_TUPLE_FLOAT.value
PY_DT_TUPLE_NUMBER = OpDtype.PY_DT_TUPLE_NUMBER.value
PY_DT_TUPLE_TENSOR = OpDtype.PY_DT_TUPLE_TENSOR.value
PY_DT_TUPLE_STR = OpDtype.PY_DT_TUPLE_STR.value
PY_DT_TUPLE_ANY = OpDtype.PY_DT_TUPLE_ANY.value
DT_TUPLE_BOOL_VAL = int(OpDtype.DT_TUPLE_BOOL)
DT_TUPLE_INT_VAL = int(OpDtype.DT_TUPLE_INT)
DT_TUPLE_FLOAT_VAL = int(OpDtype.DT_TUPLE_FLOAT)
DT_TUPLE_NUMBER_VAL = int(OpDtype.DT_TUPLE_NUMBER)
DT_TUPLE_TENSOR_VAL = int(OpDtype.DT_TUPLE_TENSOR)
DT_TUPLE_STR_VAL = int(OpDtype.DT_TUPLE_STR)
DT_TUPLE_ANY_VAL = int(OpDtype.DT_TUPLE_ANY)
# list
PY_DT_LIST_BOOL = OpDtype.PY_DT_LIST_BOOL.value
PY_DT_LIST_INT = OpDtype.PY_DT_LIST_INT.value
PY_DT_LIST_FLOAT = OpDtype.PY_DT_LIST_FLOAT.value
PY_DT_LIST_NUMBER = OpDtype.PY_DT_LIST_NUMBER.value
PY_DT_LIST_TENSOR = OpDtype.PY_DT_LIST_TENSOR.value
PY_DT_LIST_STR = OpDtype.PY_DT_LIST_STR.value
PY_DT_LIST_ANY = OpDtype.PY_DT_LIST_ANY.value
DT_LIST_BOOL_VAL = int(OpDtype.DT_LIST_BOOL)
DT_LIST_INT_VAL = int(OpDtype.DT_LIST_INT)
DT_LIST_FLOAT_VAL = int(OpDtype.DT_LIST_FLOAT)
DT_LIST_NUMBER_VAL = int(OpDtype.DT_LIST_NUMBER)
DT_LIST_TENSOR_VAL = int(OpDtype.DT_LIST_TENSOR)
DT_LIST_STR_VAL = int(OpDtype.DT_LIST_STR)
DT_LIST_ANY_VAL = int(OpDtype.DT_LIST_ANY)
# tensor
PY_DT_TENSOR = OpDtype.PY_DT_TENSOR.value
DT_TENSOR_VAL = int(OpDtype.DT_TENSOR)
dtype_to_string = {
PY_DT_INT: "int",
PY_DT_FLOAT: "float",
PY_DT_BOOL: "bool",
PY_DT_NUMBER: "number",
PY_DT_TENSOR: "Tensor",
PY_DT_TUPLE_BOOL: "tuple of bool",
PY_DT_TUPLE_INT: "tuple of int",
PY_DT_TUPLE_FLOAT: "tuple of float",
PY_DT_TUPLE_NUMBER: "tuple of number",
PY_DT_TUPLE_TENSOR: "tuple of tensor",
PY_DT_TUPLE_STR: "tuple of string",
PY_DT_TUPLE_ANY: "tuple of Any",
PY_DT_LIST_BOOL: "list of bool",
PY_DT_LIST_INT: "list of int",
PY_DT_LIST_FLOAT: "list of float",
PY_DT_LIST_NUMBER: "list of number",
PY_DT_LIST_TENSOR: "list of Tensor",
PY_DT_LIST_STR: "list of string",
PY_DT_LIST_ANY: "list of Any"
DT_INT_VAL: "int",
DT_FLOAT_VAL: "float",
DT_BOOL_VAL: "bool",
DT_NUMBER_VAL: "number",
DT_TENSOR_VAL: "Tensor",
DT_TUPLE_BOOL_VAL: "tuple of bool",
DT_TUPLE_INT_VAL: "tuple of int",
DT_TUPLE_FLOAT_VAL: "tuple of float",
DT_TUPLE_NUMBER_VAL: "tuple of number",
DT_TUPLE_TENSOR_VAL: "tuple of tensor",
DT_TUPLE_STR_VAL: "tuple of string",
DT_TUPLE_ANY_VAL: "tuple of Any",
DT_LIST_BOOL_VAL: "list of bool",
DT_LIST_INT_VAL: "list of int",
DT_LIST_FLOAT_VAL: "list of float",
DT_LIST_NUMBER_VAL: "list of number",
DT_LIST_TENSOR_VAL: "list of Tensor",
DT_LIST_STR_VAL: "list of string",
DT_LIST_ANY_VAL: "list of Any"
}
@ -122,34 +121,35 @@ def is_tuple(type_id):
"""
Check type id is tuple.
"""
return type_id in (PY_DT_TUPLE_BOOL, PY_DT_TUPLE_INT, PY_DT_TUPLE_FLOAT, PY_DT_TUPLE_NUMBER,
PY_DT_TUPLE_TENSOR, PY_DT_TUPLE_STR, PY_DT_TUPLE_ANY)
return type_id in (DT_TUPLE_BOOL_VAL, DT_TUPLE_INT_VAL, DT_TUPLE_FLOAT_VAL, DT_TUPLE_NUMBER_VAL,
DT_TUPLE_TENSOR_VAL, DT_TUPLE_STR_VAL, DT_TUPLE_ANY_VAL)
def is_list(type_id):
"""
Check type id is list.
"""
return type_id in (PY_DT_LIST_BOOL, PY_DT_LIST_INT, PY_DT_LIST_FLOAT, PY_DT_LIST_NUMBER, PY_DT_LIST_TENSOR,
PY_DT_LIST_STR, PY_DT_LIST_ANY)
return type_id in (DT_LIST_BOOL_VAL, DT_LIST_INT_VAL, DT_LIST_FLOAT_VAL, DT_LIST_NUMBER_VAL,
DT_LIST_TENSOR_VAL,
DT_LIST_STR_VAL, DT_LIST_ANY_VAL)
def is_number(type_id):
"""
Check type id is number.
"""
return type_id in (PY_DT_INT, PY_DT_FLOAT, PY_DT_BOOL, PY_DT_NUMBER)
return type_id in (DT_INT_VAL, DT_FLOAT_VAL, DT_BOOL_VAL, DT_NUMBER_VAL)
def is_instance_of(data, type_id):
"""
Instead isinstance(obj, type).
"""
if type_id == PY_DT_INT:
if type_id == DT_INT_VAL:
return isinstance(data, int)
if type_id == PY_DT_FLOAT:
if type_id == DT_FLOAT_VAL:
return isinstance(data, float)
if type_id == PY_DT_BOOL:
if type_id == DT_BOOL_VAL:
return isinstance(data, bool)
if is_number(type_id):
return isinstance(data, (int, float, bool))
@ -157,7 +157,7 @@ def is_instance_of(data, type_id):
return isinstance(data, tuple)
if is_list(type_id):
return isinstance(data, list)
if type_id == PY_DT_TENSOR:
if type_id == DT_TENSOR_VAL:
return isinstance(data, Tensor)
return False
@ -192,7 +192,7 @@ def do_type_cast(data, dst_type):
"""Type conversion."""
if is_instance_of(data, dst_type):
return data
if dst_type == PY_DT_FLOAT:
if dst_type == DT_FLOAT_VAL:
if isinstance(data, int):
return int_to_float(data)
elif is_tuple(dst_type):
@ -202,7 +202,7 @@ def do_type_cast(data, dst_type):
return list_to_tuple(data)
if isinstance(data, Tensor):
return tensor_to_tuple(data)
elif dst_type == PY_DT_TENSOR:
elif dst_type == DT_TENSOR_VAL:
if isinstance(data, (int, float, bool)):
return scalar_to_tensor(data)
if isinstance(data, tuple):
@ -211,7 +211,7 @@ def do_type_cast(data, dst_type):
return list_to_tensor(data)
elif is_number(dst_type):
if isinstance(data, Tensor):
if dst_type == PY_DT_INT:
if dst_type == DT_INT_VAL:
data = ops.cast(data, ms.int64)
ret = TensorToScalar()(data)
return ret
@ -222,6 +222,11 @@ def type_it(data, src_type, dst_type):
"""
cast operator argument data type.
"""
if not isinstance(src_type, tuple):
src_type = int(src_type)
else:
src_type = tuple((int(t) for t in src_type))
dst_type = int(dst_type)
if not is_instance_in(data, src_type) and not is_instance_of(data, dst_type):
support_list = get_support_dtype_list(src_type, dst_type)
raise TypeError(f"For type conversion here, only support <{support_list}>, but got {type(data)}.")

View File

@ -14,10 +14,7 @@
# ============================================================================
"""Operator argument handle function."""
from mindspore.ops_generate.gen_ops_inner_prim import DtypeToEnum
ops_dtype_to_enum = DtypeToEnum()
from mindspore.ops_generate.gen_ops_inner_prim import DtypeToEnum, StringToEnum
def to_kernel_size(kernel_size):
"""
@ -131,9 +128,9 @@ def to_3d_paddings(pad):
return pad
raise ValueError(f"For arg 'pad', the value is invalid: {pad}.")
dtype_to_type_id = DtypeToEnum()
def dtype_to_enum(dtype):
"""
convert mindspore.dtype to enum.
"""
return ops_dtype_to_enum(dtype)
# string to enum
# A function for converting str type to enum type are written here,
# but the backend supports str input, and converting str input to enum input is not necessary.
str_to_enum = StringToEnum()

View File

@ -1,241 +0,0 @@
# #enum TypeId
# type_id:
# kTypeUnknown: 0
# #Meta types.
# kMetaTypeBegin: 0
# kMetaTypeType: 1
# kMetaTypeAny: 2
# kMetaTypeObject: 3
# kMetaTypeTypeType: 4
# kMetaTypeProblem: 5
# kMetaTypeExternal: 6
# kMetaTypeNone: 7
# kMetaTypeNull: 8
# kMetaTypeEllipsis: 9
# kMetaTypeEnd: 10
# #
# # Object types
# #
# kObjectTypeBegin: 10
# kObjectTypeNumber: 11
# kObjectTypeString: 12
# kObjectTypeList: 13
# kObjectTypeTuple: 14
# kObjectTypeSlice: 15
# kObjectTypeKeyword: 16
# kObjectTypeTensorType: 17
# kObjectTypeRowTensorType: 18
# kObjectTypeCOOTensorType: 19
# kObjectTypeUndeterminedType: 20
# kObjectTypeClass: 21
# kObjectTypeDictionary: 22
# kObjectTypeFunction: 23
# kObjectTypeJTagged: 24
# kObjectTypeSymbolicKeyType: 25
# kObjectTypeEnvType: 26
# kObjectTypeRefKey: 27
# kObjectTypeRef: 28
# kObjectTypeEnd: 29
# #
# # Number Types
# #
# kNumberTypeBegin: 29
# kNumberTypeBool: 30
# kNumberTypeInt: 31
# kNumberTypeInt8: 32
# kNumberTypeInt16: 33
# kNumberTypeInt32: 34
# kNumberTypeInt64: 35
# kNumberTypeUInt: 36
# kNumberTypeUInt8: 37
# kNumberTypeUInt16: 38
# kNumberTypeUInt32: 39
# kNumberTypeUInt64: 40
# kNumberTypeFloat: 41
# kNumberTypeFloat16: 42
# kNumberTypeFloat32: 43
# kNumberTypeFloat64: 44
# kNumberTypeBFloat16: 45
# kNumberTypeDouble: 46
# kNumberTypeComplex: 47
# kNumberTypeComplex64: 48
# kNumberTypeComplex128: 49
# kNumberTypeInt4: 50
# kNumberTypeGLUInt: 51
# kNumberTypeEnd: 52
# #
# # Monad Types
# #
# kMonadTypeBegin: 52
# kObjectTypeMonad: 53
# kObjectTypeUMonad: 54
# kObjectTypeIOMonad: 55
# kMonadTypeEnd: 56
# #
# # Sparse Types
# #
# kSparseTypeBegin: 56
# kObjectTypeCSRTensorType: 57
# kObjectTypeSparseTensorType: 58
# kObjectTypeMapTensorType: 59
# kSparseTypeEnd: 60
# #enum OpDtype
op_dtype:
PY_DT_BEGIN: 0
PY_DT_BOOL: 1
PY_DT_INT: 2
PY_DT_FLOAT: 3
PY_DT_NUMBER: 4
PY_DT_TENSOR: 5
PY_DT_STR: 6
PY_DT_ANY: 7
PY_DT_TUPLE_BOOL: 8
PY_DT_TUPLE_INT: 9
PY_DT_TUPLE_FLOAT: 10
PY_DT_TUPLE_NUMBER: 11
PY_DT_TUPLE_TENSOR: 12
PY_DT_TUPLE_STR: 13
PY_DT_TUPLE_ANY: 14
PY_DT_LIST_BOOL: 15
PY_DT_LIST_INT: 16
PY_DT_LIST_FLOAT: 17
PY_DT_LIST_NUMBER: 18
PY_DT_LIST_TENSOR: 19
PY_DT_LIST_STR: 20
PY_DT_LIST_ANY: 21
PY_DT_TYPE: 22
PY_DT_END: 23
# enum Format
# the value must be consistent with enum defined in `mindspore/core/mindapi/base/format.h`
py_format:
NCHW: 0
NHWC: 1
NHWC4: 2
HWKC: 3
HWCK: 4
KCHW: 5
CKHW: 6
KHWC: 7
CHWK: 8
HW: 9
HW4: 10
NC: 11
NC4: 12
NC4HW4: 13
NCDHW: 14
NWC: 15
NCW: 16
NDHWC: 17
NC8HW8: 18
FRACTAL_NZ: 19
ND: 20
NC1HWC0: 21
FRACTAL_Z: 22
NC1C0HWPAD: 23
NHWC1C0: 24
FSR_NCHW: 25
FRACTAL_DECONV: 26
C1HWNC0: 27
FRACTAL_DECONV_TRANSPOSE: 28
FRACTAL_DECONV_SP_STRIDE_TRANS: 29
NC1HWC0_C04: 30
FRACTAL_Z_C04: 31
CHWN: 32
FRACTAL_DECONV_SP_STRIDE8_TRANS: 33
HWCN: 34
NC1KHKWHWC0: 35
BN_WEIGHT: 36
FILTER_HWCK: 37
LOOKUP_LOOKUPS: 38
LOOKUP_KEYS: 39
LOOKUP_VALUE: 40
LOOKUP_OUTPUT: 41
LOOKUP_HITS: 42
C1HWNCoC0: 43
MD: 44
FRACTAL_ZZ: 45
DHWCN: 46
NDC1HWC0: 47
FRACTAL_Z_3D: 48
CN: 49
DHWNC: 50
FRACTAL_Z_3D_TRANSPOSE: 51
FRACTAL_ZN_LSTM: 52
FRACTAL_Z_G: 53
ND_RNN_BIAS: 54
FRACTAL_ZN_RNN: 55
NYUV: 56
NYUV_A: 57
NCL: 58
# enum PadMode
pad_mode:
PAD: 0
SAME: 1
VALID: 2
# enum Reduction
reduction:
NONE: 0
MEAN: 1
SUM: 2
ADD: 3
# enum Direction
direction:
UNIDIRECTIONAL: 0
# enum Activation
activation:
TANH: 0
# enum GateOrder
gate_order:
RZH: 0
ZRH: 1
# enum CellType
cell_type:
LSTM: 0
# enum Group
group:
SYNC_BN_GROUP0: 0
# enum InterpolationMode
interpolation_mode:
BILINEAR: 0
NEAREST: 1
# enum NormMode
norm_mode:
BACKWARD: 0
FORWARD: 1
ORTHO: 2
# enum GridSamplerPaddingMode
grid_sampler_padding_mode:
ZEROS: 0
BORDER: 1
REFLECTION: 2
# enum PadFillingMode
pad_filling_mode:
CONSTANT: 0
REFLECT: 1
EDGE: 2
CIRCULAR: 3
SYMMETRIC: 4
# enum CoordinateTransformationMode
coordinate_transformation_mode:
ALIGN_CORNERS: 0
HALF_PIXEL: 1
# enum KVCacheAlignMode
k_v_cache_align_mode:
RIGHT: 0
LEFT: 1

View File

@ -20,7 +20,8 @@ import re
import shutil
import pathlib
import gen_utils
from gen_utils import py_licence_str, cc_license_str, check_change_and_replace_file, merge_files, safe_load_yaml
from gen_utils import (py_licence_str, cc_license_str, check_change_and_replace_file, merge_files,
safe_load_yaml, convert_dtype_str)
from pyboost_utils import get_pyboost_name, is_pyboost_enable, AclnnUtils, get_dtypes
import template
from template import CppTemplate
@ -353,6 +354,12 @@ def {func_name}({', '.join(arg for arg in func_args)}):
return gen_py
def get_dtype(arg_info):
dtype = arg_info.get('dtype')
# Currently, TypeId is represented by int
if dtype == 'TypeId':
dtype = 'int'
return dtype
def process_args(args):
"""
@ -365,7 +372,7 @@ def process_args(args):
init_args_with_default = []
args_handlers = {}
for arg_name, arg_info in args.items():
dtype = arg_info.get('dtype')
dtype = get_dtype(arg_info)
default_value = arg_info.get('default')
has_default = 'default' in arg_info.keys()
is_optional = arg_info.get('default') == "None" if has_default else False
@ -621,7 +628,7 @@ namespace mindspore::ops {
if not is_prim_init:
continue
dtype = arg_info.get('dtype')
dtype = get_dtype(arg_info)
if dtype == "str":
dtype = "std::string"
if dtype == "tuple[int]":
@ -671,8 +678,8 @@ std::unordered_map<std::string, OpDefPtr> gOpDefTable = {{"""
for i, (arg_name, arg_info) in enumerate(args.items()):
args_dict[arg_name] = i
cc_index_str += f"""{{"{arg_name}", {i}}},\n"""
dtype = arg_info.get('dtype')
cc_dtype_str = 'DT_' + dtype.replace('[', '_').replace(']', '').upper()
dtype = get_dtype(arg_info)
cc_dtype_str = convert_dtype_str(dtype)
is_prim_init = 1 if arg_info.get('prim_init') else 0
arg_handler = arg_info.get('arg_handler')
@ -723,7 +730,7 @@ from mindspore.common._decorator import deprecated
from mindspore.ops._primitive_cache import _get_cache_prim
from mindspore.ops.auto_generate.gen_arg_dtype_cast import type_it
from mindspore.ops.auto_generate.gen_arg_handler import *
from mindspore.ops.auto_generate.gen_enum_def import OpDtype
from mindspore._c_expression import OpDtype
from mindspore.common._stub_tensor import _convert_stub
"""
@ -891,67 +898,10 @@ def generate_aclnn_reg_file(work_path, yaml_str):
check_change_and_replace_file(register_file, tmp_register_file)
eum_py_header = f"""
\"\"\"Operator argument enum definition.\"\"\"
from enum import Enum
"""
eum_cc_header = f"""
#ifndef MINDSPORE_CORE_OPS_GEN_ENUM_DEF_
#define MINDSPORE_CORE_OPS_GEN_ENUM_DEF_
#include <cstdint>
namespace mindspore::MsPyEnum {{
"""
eum_cc_end = f"""}} // namespace mindspore::MsPyEnum
#endif // MINDSPORE_CORE_OPS_GEN_ENUM_DEF_
"""
def generate_enum_code(yaml_data):
def generate_arg_handler_files(work_path):
"""
Generate python and c++ enum definition
Generate arg handler files.
"""
gen_eum_py_func = ''
gen_eum_py_def = eum_py_header
gen_eum_cc_def = eum_cc_header
for enum_name, enum_data in yaml_data.items():
class_name = ''.join(word.capitalize() for word in enum_name.split('_'))
gen_eum_py_func += f"""\n
def {enum_name}_to_enum({enum_name}_str):
\"""
convert {enum_name} string to enum.
\"""
if not isinstance({enum_name}_str, str):
raise TypeError(f"The {enum_name} should be string, but got {{{enum_name}_str}}")
{enum_name}_str = {enum_name}_str.upper()\n"""
gen_eum_py_def += f"""\n\nclass {class_name}(Enum):\n"""
gen_eum_cc_def += f"""enum {class_name} : int64_t {{\n"""
for enum_key, enum_value in enum_data.items():
gen_eum_py_func += f""" if {enum_name}_str == "{enum_key}":
return {enum_value}\n"""
gen_eum_py_def += f""" {enum_key} = {enum_value}\n"""
gen_eum_cc_def += f""" {enum_key} = {enum_value},\n"""
gen_eum_py_func += f""" raise ValueError(f"Invalid {class_name}: {{{enum_name}_str}}")\n"""
gen_eum_cc_def += f"""}};\n\n"""
gen_eum_cc_def += eum_cc_end
return gen_eum_py_func, gen_eum_py_def, gen_eum_cc_def
def generate_enum_files(work_path):
"""
Generate python function and c++ definition from enum yaml.
"""
enum_yaml_path = os.path.join(work_path, 'mindspore/python/mindspore/ops_generate/enum.yaml')
yaml_str = safe_load_yaml(enum_yaml_path)
py_enum_func, py_enum_def, cc_enum_def = generate_enum_code(yaml_str)
dst_dir = os.path.join(work_path, 'mindspore/python/mindspore/ops/auto_generate')
src_arg_handler_path = os.path.join(work_path, 'mindspore/python/mindspore/ops_generate/arg_handler.py')
dst_arg_handler_path = os.path.join(dst_dir, 'gen_arg_handler.py')
@ -959,8 +909,6 @@ def generate_enum_files(work_path):
if not os.path.exists(dst_dir):
os.makedirs(dst_dir)
shutil.copy(src_arg_handler_path, tmp_dst_arg_handler_path)
with open(tmp_dst_arg_handler_path, 'a') as py_file:
py_file.write(py_enum_func)
check_change_and_replace_file(dst_arg_handler_path, tmp_dst_arg_handler_path)
src_arg_dtype_cast_path = os.path.join(work_path, 'mindspore/python/mindspore/ops_generate/arg_dtype_cast.py')
@ -969,18 +917,6 @@ def generate_enum_files(work_path):
shutil.copy(src_arg_dtype_cast_path, tmp_arg_dtype_cast_path)
check_change_and_replace_file(dst_arg_dtype_cast_path, tmp_arg_dtype_cast_path)
enum_def_py_path = os.path.join(work_path, 'mindspore/python/mindspore/ops/auto_generate/gen_enum_def.py')
tmp_enum_def_py_path = os.path.join(work_path, 'mindspore/python/mindspore/ops/auto_generate/tmp_gen_enum_def.py')
with open(tmp_enum_def_py_path, 'w') as cc_file:
cc_file.write(py_licence_str + py_enum_def)
check_change_and_replace_file(enum_def_py_path, tmp_enum_def_py_path)
enum_def_cc_path = os.path.join(work_path, 'mindspore/core/ops/auto_generate/gen_enum_def.h')
tmp_enum_def_cc_path = os.path.join(work_path, 'mindspore/core/ops/auto_generate/tmp_gen_enum_def.h')
with open(tmp_enum_def_cc_path, 'w') as cc_file:
cc_file.write(cc_license_str + cc_enum_def)
check_change_and_replace_file(enum_def_cc_path, tmp_enum_def_cc_path)
def main():
current_path = os.path.dirname(os.path.abspath(__file__))
@ -1005,8 +941,8 @@ def main():
cc_path = os.path.join(work_path, 'mindspore/core/ops/auto_generate/')
pathlib.Path(cc_path).mkdir(parents=True, exist_ok=True)
# generate enum code from enum.yaml
generate_enum_files(work_path)
# generate arg_handler files
generate_arg_handler_files(work_path)
# generate ops python files
generate_ops_py_files(work_path, safe_load_yaml(ops_yaml_path), safe_load_yaml(doc_yaml_path), "gen")

View File

@ -16,6 +16,7 @@
from mindspore.ops.primitive import Primitive, prim_attr_register
from mindspore._c_expression import typing
from mindspore._c_expression import op_enum
class DtypeToEnum(Primitive):
@ -41,3 +42,28 @@ class DtypeToEnum(Primitive):
if not isinstance(dtype, typing.Type):
raise TypeError(f"For dtype_to_enum function, the input should be mindpsore dtype, but got {dtype}.")
return typing.type_to_type_id(dtype)
class StringToEnum(Primitive):
r"""
Convert string to enum.
Inputs:
- **enum_str** (str) - The str data.
Outputs:
An integer.
Supported Platforms:
``CPU``
"""
@prim_attr_register
def __init__(self):
"""Initialize"""
def __call__(self, enum_str):
"""Run in PyNative mode"""
if not isinstance(enum_str, str):
raise TypeError(f"For StringToEnum op, the input should be a str, but got {type(enum_str)}.")
return op_enum.str_to_enum(enum_str)

View File

@ -53,30 +53,36 @@ cc_license_str = f"""/**
* limitations under the License.
*/"""
def convert_dtype_str(dtype_str):
"""
Convert dtype str to expression in ops file
"""
return 'DT_' + dtype_str.replace('[', '_').replace(']', '').upper()
def get_type_str(type_str):
"""
Get the unified type str for operator arg dtype.
"""
# add more type here
type_kind_dict = {
'int': 'OpDtype.PY_DT_INT',
'float': 'OpDtype.PY_DT_FLOAT',
'bool': 'OpDtype.PY_DT_BOOL',
'number': 'OpDtype.PY_DT_NUMBER',
'tuple[int]': 'OpDtype.PY_DT_TUPLE_ANY',
'tuple[float]': 'OpDtype.PY_DT_TUPLE_ANY',
'tuple[bool]': 'OpDtype.PY_DT_TUPLE_ANY',
'tuple[tensor]': 'OpDtype.PY_DT_TUPLE_ANY',
'list[int]': 'OpDtype.PY_DT_LIST_ANY',
'list[float]': 'OpDtype.PY_DT_LIST_ANY',
'list[bool]': 'OpDtype.PY_DT_LIST_ANY',
'list[tensor]': 'OpDtype.PY_DT_LIST_ANY',
'tensor': 'OpDtype.PY_DT_TENSOR',
'type': 'OpDtype.PY_DT_TYPE',
type_kind_set = {
'int',
'float',
'bool',
'number',
'tuple[int]',
'tuple[float]',
'tuple[bool]',
'tuple[tensor]',
'list[int]',
'list[float]',
'list[bool]',
'list[tensor]',
'tensor',
'type',
}
if type_str in type_kind_dict:
return type_kind_dict[type_str]
if type_str in type_kind_set:
return "OpDtype." + convert_dtype_str(type_str)
raise TypeError(f"""Unsupported type {type_str} for args.""")
@ -158,10 +164,10 @@ def get_assign_str_by_type_it(arg_info, arg_name, dtype):
type_cast_tuple = tuple(ct.strip() for ct in type_cast.split(","))
assign_str += f'type_it({arg_name}, '
if len(type_cast_tuple) == 1:
assign_str += get_type_str(type_cast_tuple[0]) + '.value, '
assign_str += get_type_str(type_cast_tuple[0]) + ', '
else:
assign_str += '(' + ', '.join(get_type_str(ct) + '.value' for ct in type_cast_tuple) + '), '
assign_str += get_type_str(dtype) + '.value)'
assign_str += '(' + ', '.join(get_type_str(ct) for ct in type_cast_tuple) + '), '
assign_str += get_type_str(dtype) + ')'
else:
assign_str = arg_name
return assign_str

View File

@ -16,7 +16,7 @@
#include "ops/test_value_utils.h"
#include "ops/auto_generate/gen_ops_name.h"
#include "ops/auto_generate/gen_enum_def.h"
#include "ops/op_enum.h"
#include "ops/test_ops.h"
#include "ops/ops_func_impl/fft_with_size.h"
@ -42,7 +42,7 @@ TEST_P(TestFFT, dyn_shape) {
auto signal_ndim = param.signal_ndim == -1 ? Any->ToAbstract() : CreatePyInt(param.signal_ndim)->ToAbstract();
auto inverse = CreateScalar(param.inverse)->ToAbstract();
auto real = CreateScalar(param.real)->ToAbstract();
auto norm = CreateScalar(static_cast<int64_t>(MsPyEnum::NormMode::BACKWARD))->ToAbstract();
auto norm = CreateScalar(static_cast<int64_t>(ops::NormMode::BACKWARD))->ToAbstract();
auto onesided = CreateScalar(param.onesided)->ToAbstract();
auto signal_sizes = param.signal_sizes->ToAbstract();

View File

@ -29,6 +29,7 @@
#include "ops/test_value_utils.h"
#include "utils/ms_context.h"
#include "utils/tensor_construct_utils.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
@ -99,31 +100,238 @@ TEST_P(TestNLLLoss, dyn_shape) {
INSTANTIATE_TEST_CASE_P(
TestNLLLossGroup, TestNLLLoss,
testing::Values(
NLLLossParams{{-1, 3}, kFloat32, {-1}, kInt32, {3}, kFloat32, 0, true, {{-1}, {}}, {kFloat32, kFloat32}},
NLLLossParams{{-1, 3}, kFloat32, {-1}, kInt32, {3}, kFloat32, 1, true, {{}, {}}, {kFloat32, kFloat32}},
NLLLossParams{{-1, 3}, kFloat32, {-1}, kInt32, {3}, kFloat32, 2, true, {{}, {}}, {kFloat32, kFloat32}},
NLLLossParams{{-1, 3}, kFloat32, {-2}, kInt32, {-2}, kFloat32, 0, true, {{-1}, {}}, {kFloat32, kFloat32}},
NLLLossParams{{-1, 3}, kFloat32, {-2}, kInt32, {-2}, kFloat32, 1, true, {{}, {}}, {kFloat32, kFloat32}},
NLLLossParams{{-1, 3}, kFloat32, {-2}, kInt32, {-2}, kFloat32, 2, true, {{}, {}}, {kFloat32, kFloat32}},
NLLLossParams{{-2}, kFloat32, {-2}, kInt32, {-2}, kFloat32, 0, true, {{-1}, {}}, {kFloat32, kFloat32}},
NLLLossParams{{-2}, kFloat32, {-2}, kInt32, {-2}, kFloat32, 1, true, {{}, {}}, {kFloat32, kFloat32}},
NLLLossParams{{-2}, kFloat32, {-2}, kInt32, {-2}, kFloat32, 2, true, {{}, {}}, {kFloat32, kFloat32}},
NLLLossParams{{2, 3}, kFloat32, {2}, kInt32, {3}, kFloat32, 0, true, {{2}, {}}, {kFloat32, kFloat32}},
NLLLossParams{{2, 3}, kFloat32, {2}, kInt32, {3}, kFloat32, 1, true, {{}, {}}, {kFloat32, kFloat32}},
NLLLossParams{{2, 3}, kFloat32, {2}, kInt32, {3}, kFloat32, 2, true, {{}, {}}, {kFloat32, kFloat32}},
NLLLossParams{{2, 3}, kFloat32, {2}, kInt32, {3}, kFloat32, -1, true, {{-1}, {}}, {kFloat32, kFloat32}},
NLLLossParams{{-1, -1}, kFloat32, {-1}, kInt32, {-1}, kFloat32, 0, true, {{-1}, {}}, {kFloat32, kFloat32}},
NLLLossParams{{-1, -1}, kFloat32, {-1}, kInt32, {-1}, kFloat32, 1, true, {{}, {}}, {kFloat32, kFloat32}},
NLLLossParams{{-1, -1}, kFloat32, {-1}, kInt32, {-1}, kFloat32, 2, true, {{}, {}}, {kFloat32, kFloat32}},
NLLLossParams{{-1, 3},
kFloat32,
{-1},
kInt32,
{3},
kFloat32,
static_cast<int64_t>(Reduction::NONE),
true,
{{-1}, {}},
{kFloat32, kFloat32}},
NLLLossParams{{-1, 3},
kFloat32,
{-1},
kInt32,
{3},
kFloat32,
static_cast<int64_t>(Reduction::MEAN),
true,
{{}, {}},
{kFloat32, kFloat32}},
NLLLossParams{{-1, 3},
kFloat32,
{-1},
kInt32,
{3},
kFloat32,
static_cast<int64_t>(Reduction::REDUCTION_SUM),
true,
{{}, {}},
{kFloat32, kFloat32}},
NLLLossParams{{-1, 3},
kFloat32,
{-2},
kInt32,
{-2},
kFloat32,
static_cast<int64_t>(Reduction::NONE),
true,
{{-1}, {}},
{kFloat32, kFloat32}},
NLLLossParams{{-1, 3},
kFloat32,
{-2},
kInt32,
{-2},
kFloat32,
static_cast<int64_t>(Reduction::MEAN),
true,
{{}, {}},
{kFloat32, kFloat32}},
NLLLossParams{{-1, 3},
kFloat32,
{-2},
kInt32,
{-2},
kFloat32,
static_cast<int64_t>(Reduction::REDUCTION_SUM),
true,
{{}, {}},
{kFloat32, kFloat32}},
NLLLossParams{{-2},
kFloat32,
{-2},
kInt32,
{-2},
kFloat32,
static_cast<int64_t>(Reduction::NONE),
true,
{{-1}, {}},
{kFloat32, kFloat32}},
NLLLossParams{{-2},
kFloat32,
{-2},
kInt32,
{-2},
kFloat32,
static_cast<int64_t>(Reduction::MEAN),
true,
{{}, {}},
{kFloat32, kFloat32}},
NLLLossParams{{-2},
kFloat32,
{-2},
kInt32,
{-2},
kFloat32,
static_cast<int64_t>(Reduction::REDUCTION_SUM),
true,
{{}, {}},
{kFloat32, kFloat32}},
NLLLossParams{{2, 3},
kFloat32,
{2},
kInt32,
{3},
kFloat32,
static_cast<int64_t>(Reduction::NONE),
true,
{{2}, {}},
{kFloat32, kFloat32}},
NLLLossParams{{2, 3},
kFloat32,
{2},
kInt32,
{3},
kFloat32,
static_cast<int64_t>(Reduction::MEAN),
true,
{{}, {}},
{kFloat32, kFloat32}},
NLLLossParams{{2, 3},
kFloat32,
{2},
kInt32,
{3},
kFloat32,
static_cast<int64_t>(Reduction::REDUCTION_SUM),
true,
{{}, {}},
{kFloat32, kFloat32}},
NLLLossParams{{2, 3}, kFloat32, {2}, kInt32, {3}, kFloat32, -1, true, {{-1}, {}}, {kFloat32, kFloat32}},
NLLLossParams{{-1, -1},
kFloat32,
{-1},
kInt32,
{-1},
kFloat32,
static_cast<int64_t>(Reduction::NONE),
true,
{{-1}, {}},
{kFloat32, kFloat32}},
NLLLossParams{{-1, -1},
kFloat32,
{-1},
kInt32,
{-1},
kFloat32,
static_cast<int64_t>(Reduction::MEAN),
true,
{{}, {}},
{kFloat32, kFloat32}},
NLLLossParams{{-1, -1},
kFloat32,
{-1},
kInt32,
{-1},
kFloat32,
static_cast<int64_t>(Reduction::REDUCTION_SUM),
true,
{{}, {}},
{kFloat32, kFloat32}},
NLLLossParams{{-1, -1}, kFloat32, {-1}, kInt32, {-1}, kFloat32, -1, true, {{-1}, {}}, {kFloat32, kFloat32}},
NLLLossParams{{-2}, kFloat32, {-1}, kInt32, {-1}, kFloat32, 0, true, {{-1}, {}}, {kFloat32, kFloat32}},
NLLLossParams{{-2}, kFloat32, {-1}, kInt32, {-1}, kFloat32, 1, true, {{}, {}}, {kFloat32, kFloat32}},
NLLLossParams{{-2}, kFloat32, {-1}, kInt32, {-1}, kFloat32, 2, true, {{}, {}}, {kFloat32, kFloat32}},
NLLLossParams{{-2},
kFloat32,
{-1},
kInt32,
{-1},
kFloat32,
static_cast<int64_t>(Reduction::NONE),
true,
{{-1}, {}},
{kFloat32, kFloat32}},
NLLLossParams{{-2},
kFloat32,
{-1},
kInt32,
{-1},
kFloat32,
static_cast<int64_t>(Reduction::MEAN),
true,
{{}, {}},
{kFloat32, kFloat32}},
NLLLossParams{{-2},
kFloat32,
{-1},
kInt32,
{-1},
kFloat32,
static_cast<int64_t>(Reduction::REDUCTION_SUM),
true,
{{}, {}},
{kFloat32, kFloat32}},
NLLLossParams{{-2}, kFloat32, {-1}, kInt32, {-1}, kFloat32, -1, true, {{-1}, {}}, {kFloat32, kFloat32}},
NLLLossParams{{2, 3, 4}, kFloat32, {2}, kInt32, {3}, kFloat32, 0, false, {{1}, {}}, {kFloat32, kFloat32}},
NLLLossParams{{2, 3}, kFloat32, {2, 3}, kInt32, {3}, kFloat32, 0, false, {{1}, {}}, {kFloat32, kFloat32}},
NLLLossParams{{2, 3}, kFloat32, {2}, kInt32, {2, 3}, kFloat32, 0, false, {{1}, {}}, {kFloat32, kFloat32}},
NLLLossParams{{2, 3}, kFloat32, {2}, kInt32, {2}, kFloat32, 0, false, {{1}, {}}, {kFloat32, kFloat32}},
NLLLossParams{{2, 3}, kFloat32, {3}, kInt32, {2}, kFloat32, 0, false, {{1}, {}}, {kFloat32, kFloat32}}));
NLLLossParams{{2, 3, 4},
kFloat32,
{2},
kInt32,
{3},
kFloat32,
static_cast<int64_t>(Reduction::NONE),
false,
{{1}, {}},
{kFloat32, kFloat32}},
NLLLossParams{{2, 3},
kFloat32,
{2, 3},
kInt32,
{3},
kFloat32,
static_cast<int64_t>(Reduction::NONE),
false,
{{1}, {}},
{kFloat32, kFloat32}},
NLLLossParams{{2, 3},
kFloat32,
{2},
kInt32,
{2, 3},
kFloat32,
static_cast<int64_t>(Reduction::NONE),
false,
{{1}, {}},
{kFloat32, kFloat32}},
NLLLossParams{{2, 3},
kFloat32,
{2},
kInt32,
{2},
kFloat32,
static_cast<int64_t>(Reduction::NONE),
false,
{{1}, {}},
{kFloat32, kFloat32}},
NLLLossParams{{2, 3},
kFloat32,
{3},
kInt32,
{2},
kFloat32,
static_cast<int64_t>(Reduction::NONE),
false,
{{1}, {}},
{kFloat32, kFloat32}}));
} // namespace ops
} // namespace mindspore