refractor some python enums in yaml to C++
This commit is contained in:
parent
d8301d1dba
commit
60e4c1183f
|
@ -119,7 +119,6 @@ mindspore/core/ops/auto_generate/gen_lite_ops.h
|
|||
mindspore/core/ops/auto_generate/gen_lite_ops.cc
|
||||
mindspore/core/ops/auto_generate/gen_ops_name.h
|
||||
mindspore/core/ops/auto_generate/gen_ops_primitive.h
|
||||
mindspore/core/ops/auto_generate/gen_enum_def.h
|
||||
mindspore/core/ops/auto_generate/gen_ops_def.cc
|
||||
mindspore/core/ops/auto_generate/gen_ops_def.h
|
||||
mindspore/python/mindspore/ops_generate/ops.yaml
|
||||
|
@ -129,7 +128,6 @@ mindspore/python/mindspore/ops_generate/inner_ops_doc.yaml
|
|||
mindspore/python/mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py
|
||||
mindspore/python/mindspore/ops/auto_generate/gen_arg_handler.py
|
||||
mindspore/python/mindspore/ops/auto_generate/gen_arg_dtype_cast.py
|
||||
mindspore/python/mindspore/ops/auto_generate/gen_enum_def.py
|
||||
mindspore/python/mindspore/ops/auto_generate/gen_inner_ops_def.py
|
||||
mindspore/python/mindspore/ops/auto_generate/gen_ops_def.py
|
||||
mindspore/python/mindspore/ops/auto_generate/gen_pyboost_func.py
|
||||
|
|
|
@ -107,8 +107,8 @@ using ArgHandlerFunc = std::function<ValuePtr(const ValuePtr &)>;
|
|||
|
||||
ArgHandlerFunc GetArgHandlerFunc(const std::string &arg_handler) {
|
||||
static const std::unordered_map<std::string, ArgHandlerFunc> arg_handler_funcs = {
|
||||
{"py_format_to_enum", EnumToFormat},
|
||||
{"dtype_to_enum", EnumToDtype},
|
||||
{"str_to_enum", EnumToFormat},
|
||||
{"dtype_to_type_id", EnumToDtype},
|
||||
};
|
||||
if (arg_handler_funcs.find(arg_handler) != arg_handler_funcs.end()) {
|
||||
return arg_handler_funcs.at(arg_handler);
|
||||
|
@ -119,8 +119,8 @@ ArgHandlerFunc GetArgHandlerFunc(const std::string &arg_handler) {
|
|||
|
||||
ArgHandlerFunc GetOppArgHandlerFunc(const std::string &arg_handler) {
|
||||
static const std::unordered_map<std::string, ArgHandlerFunc> opp_arg_handler_funcs = {
|
||||
{"py_format_to_enum", FormatToEnum},
|
||||
{"dtype_to_enum", DtypeToEnum},
|
||||
{"str_to_enum", FormatToEnum},
|
||||
{"dtype_to_type_id", DtypeToEnum},
|
||||
};
|
||||
if (opp_arg_handler_funcs.find(arg_handler) != opp_arg_handler_funcs.end()) {
|
||||
return opp_arg_handler_funcs.at(arg_handler);
|
||||
|
|
|
@ -70,6 +70,10 @@ namespace abstract {
|
|||
void RegPrimitiveFrontEval();
|
||||
}
|
||||
#endif
|
||||
|
||||
namespace ops {
|
||||
void RegOpEnum(py::module *m);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_INCLUDE_COMMON_PYBIND_API_API_REGISTER_H_
|
||||
|
|
|
@ -156,6 +156,7 @@ void RegModule(py::module *m) {
|
|||
mindspore::abstract::RegPrimitiveFrontEval();
|
||||
#endif
|
||||
mindspore::expander::RegPackExpanderPy(m);
|
||||
mindspore::ops::RegOpEnum(m);
|
||||
}
|
||||
|
||||
void RegModuleHelper(py::module *m) {
|
||||
|
|
|
@ -137,6 +137,11 @@ AnfNodePtr GetNodeAfterArgHandler(const AnfNodePtr &node, const ops::OpInputArg
|
|||
return node;
|
||||
}
|
||||
const auto arg_handler_func = prim::GetPythonOps(op_arg.arg_handler_, parse::PYTHON_MOD_PRIMITIVE_ARG_HANDLER_MODULE);
|
||||
if (arg_handler_func->isa<Primitive>()) {
|
||||
auto arg_handler_fg = dyn_cast<Primitive>(arg_handler_func);
|
||||
MS_EXCEPTION_IF_NULL(arg_handler_fg);
|
||||
return fg->NewCNodeInOrder({NewValueNode(arg_handler_fg), node});
|
||||
}
|
||||
auto arg_handler_fg = dyn_cast<FuncGraph>(arg_handler_func);
|
||||
MS_EXCEPTION_IF_NULL(arg_handler_fg);
|
||||
arg_handler_fg->set_manager(fg->manager());
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
|
||||
#include "plugin/device/cpu/kernel/grid_sampler_2d_cpu_kernel.h"
|
||||
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
|
||||
#include "mindspore/core/ops/auto_generate/gen_enum_def.h"
|
||||
#include "mindspore/core/ops/op_enum.h"
|
||||
|
||||
namespace {
|
||||
const size_t kDataSizeThreshold = 64 * 1024;
|
||||
|
@ -119,7 +119,7 @@ void GridSampler2DCpuKernelMod::ComputeTask(const T *x_addr, const T *grid_addr,
|
|||
auto x_ptr_NC = out_iter[kZero] * x_stride_[kZero];
|
||||
auto output_ptr_NCDHW = out_iter[kZero] * output_stride_[kZero] + out_iter[kTwo] * output_stride_[kTwo] +
|
||||
out_iter[kThree] * output_stride_[kThree];
|
||||
if (interpolation_mode_ == static_cast<int64_t>(MsPyEnum::InterpolationMode::BILINEAR)) {
|
||||
if (interpolation_mode_ == static_cast<int64_t>(ops::InterpolationMode::BILINEAR)) {
|
||||
int64_t x_tnw = static_cast<int64_t>(std::floor(x));
|
||||
int64_t y_tnw = static_cast<int64_t>(std::floor(y));
|
||||
int64_t x_tne = x_tnw + 1;
|
||||
|
@ -153,7 +153,7 @@ void GridSampler2DCpuKernelMod::ComputeTask(const T *x_addr, const T *grid_addr,
|
|||
output_addr[output_ptr_NCDHW] += x_addr[x_index] * tse;
|
||||
}
|
||||
}
|
||||
} else if (interpolation_mode_ == static_cast<int64_t>(MsPyEnum::InterpolationMode::NEAREST)) {
|
||||
} else if (interpolation_mode_ == static_cast<int64_t>(ops::InterpolationMode::NEAREST)) {
|
||||
int64_t x_nearest = static_cast<int64_t>(std::round(x));
|
||||
int64_t y_nearest = static_cast<int64_t>(std::round(y));
|
||||
for (size_t c = 0; c < out_c; c++, x_ptr_NC += x_stride_[kOne], output_ptr_NCDHW += output_stride_[kOne]) {
|
||||
|
@ -221,9 +221,9 @@ void GridSampler2DCpuKernelMod::Call2Half(const float16 *x_data_addr, float16 *y
|
|||
auto x_ptr_NC = y_iter[0] * x_stride[0];
|
||||
auto y_ptr_NCHW = y_iter[0] * y_stride[0] + y_iter[2] * y_stride[2] + y_iter[3] * y_stride[3];
|
||||
|
||||
if (interpolation_mode == static_cast<int64_t>(MsPyEnum::InterpolationMode::BILINEAR)) {
|
||||
if (interpolation_mode == static_cast<int64_t>(ops::InterpolationMode::BILINEAR)) {
|
||||
BilinearHalf(x, y, x_data_addr, y_data_addr, y_c, x_dims, y_stride, x_stride, x_ptr_NC, y_ptr_NCHW);
|
||||
} else if (interpolation_mode == static_cast<int64_t>(MsPyEnum::InterpolationMode::NEAREST)) {
|
||||
} else if (interpolation_mode == static_cast<int64_t>(ops::InterpolationMode::NEAREST)) {
|
||||
NearestHalf(x, y, x_data_addr, y_data_addr, y_c, x_dims, y_stride, x_stride, x_ptr_NC, y_ptr_NCHW);
|
||||
}
|
||||
}
|
||||
|
@ -333,9 +333,9 @@ T GridSampler2DCpuKernelMod::GridSamplerComputeSourceIndex(T coord, int64_t size
|
|||
} else {
|
||||
coord = ((coord + 1.f) * size - 1) / num2;
|
||||
}
|
||||
if (padding_mode == static_cast<int64_t>(MsPyEnum::GridSamplerPaddingMode::BORDER)) {
|
||||
if (padding_mode == static_cast<int64_t>(ops::GridSamplerPaddingMode::BORDER)) {
|
||||
coord = std::min(static_cast<T>(size - 1), std::max(coord, static_cast<T>(0)));
|
||||
} else if (padding_mode == static_cast<int64_t>(MsPyEnum::GridSamplerPaddingMode::REFLECTION)) {
|
||||
} else if (padding_mode == static_cast<int64_t>(ops::GridSamplerPaddingMode::REFLECTION)) {
|
||||
if (align_corners) {
|
||||
coord = ReflectCoordinates(coord, 0, num2 * (size - 1));
|
||||
} else {
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
*/
|
||||
#include "plugin/device/cpu/kernel/grid_sampler_2d_grad_cpu_kernel.h"
|
||||
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
|
||||
#include "mindspore/core/ops/auto_generate/gen_enum_def.h"
|
||||
#include "mindspore/core/ops/op_enum.h"
|
||||
|
||||
namespace {
|
||||
const size_t kDataSizeThreshold = 64 * 1024;
|
||||
|
@ -93,14 +93,14 @@ void GridSampler2DGradCpuKernelMod::ComputeTask(const std::vector<KernelTensor *
|
|||
auto grid_data_addr = static_cast<T *>(inputs[kTwo]->device_ptr());
|
||||
auto dx_data_addr = static_cast<T *>(outputs[kZero]->device_ptr());
|
||||
auto dgrid_data_addr = static_cast<T *>(outputs[kOne]->device_ptr());
|
||||
if (interpolation_mode_ == static_cast<int64_t>(MsPyEnum::InterpolationMode::BILINEAR)) {
|
||||
if (interpolation_mode_ == static_cast<int64_t>(ops::InterpolationMode::BILINEAR)) {
|
||||
interp = GridSamplerInterpolation::Bilinear;
|
||||
} else if (interpolation_mode_ == static_cast<int64_t>(MsPyEnum::InterpolationMode::NEAREST)) {
|
||||
} else if (interpolation_mode_ == static_cast<int64_t>(ops::InterpolationMode::NEAREST)) {
|
||||
interp = GridSamplerInterpolation::Nearest;
|
||||
}
|
||||
if (padding_mode_ == static_cast<int64_t>(MsPyEnum::GridSamplerPaddingMode::ZEROS)) {
|
||||
if (padding_mode_ == static_cast<int64_t>(ops::GridSamplerPaddingMode::ZEROS)) {
|
||||
padding = GridSamplerPadding::Zeros;
|
||||
} else if (padding_mode_ == static_cast<int64_t>(MsPyEnum::GridSamplerPaddingMode::BORDER)) {
|
||||
} else if (padding_mode_ == static_cast<int64_t>(ops::GridSamplerPaddingMode::BORDER)) {
|
||||
padding = GridSamplerPadding::Border;
|
||||
} else {
|
||||
padding = GridSamplerPadding::Reflection;
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
*/
|
||||
#include "plugin/device/cpu/kernel/grid_sampler_3d_cpu_kernel.h"
|
||||
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
|
||||
#include "mindspore/core/ops/auto_generate/gen_enum_def.h"
|
||||
#include "mindspore/core/ops/op_enum.h"
|
||||
|
||||
namespace {
|
||||
const size_t kDataSizeThreshold = 64 * 1024;
|
||||
|
@ -117,7 +117,7 @@ void GridSampler3DCpuKernelMod::ComputeTask(T *x_addr, T *grid_addr, T *output_a
|
|||
auto x_ptr_NC = out_iter[kZero] * x_stride_[kZero];
|
||||
auto output_ptr_NCDHW = out_iter[kZero] * output_stride_[kZero] + out_iter[kTwo] * output_stride_[kTwo] +
|
||||
out_iter[kThree] * output_stride_[kThree] + out_iter[kFour] * output_stride_[kFour];
|
||||
if (interpolation_mode_ == static_cast<int64_t>(MsPyEnum::InterpolationMode::BILINEAR)) {
|
||||
if (interpolation_mode_ == static_cast<int64_t>(ops::InterpolationMode::BILINEAR)) {
|
||||
int64_t x_tnw = static_cast<int64_t>(std::floor(x));
|
||||
int64_t y_tnw = static_cast<int64_t>(std::floor(y));
|
||||
int64_t z_tnw = static_cast<int64_t>(std::floor(z));
|
||||
|
@ -176,7 +176,7 @@ void GridSampler3DCpuKernelMod::ComputeTask(T *x_addr, T *grid_addr, T *output_a
|
|||
output_addr[output_ptr_NCDHW] += x_addr[x_index] * bse;
|
||||
}
|
||||
}
|
||||
} else if (interpolation_mode_ == static_cast<int64_t>(MsPyEnum::InterpolationMode::NEAREST)) {
|
||||
} else if (interpolation_mode_ == static_cast<int64_t>(ops::InterpolationMode::NEAREST)) {
|
||||
int64_t x_nearest = static_cast<int64_t>(std::round(x));
|
||||
int64_t y_nearest = static_cast<int64_t>(std::round(y));
|
||||
int64_t z_nearest = static_cast<int64_t>(std::round(z));
|
||||
|
@ -220,9 +220,9 @@ T GridSampler3DCpuKernelMod::grid_sampler_compute_source_index(T coord, int64_t
|
|||
} else {
|
||||
coord = ((coord + 1.f) * size - kOne) / kTwo;
|
||||
}
|
||||
if (padding_mode == static_cast<int64_t>(MsPyEnum::GridSamplerPaddingMode::BORDER)) {
|
||||
if (padding_mode == static_cast<int64_t>(ops::GridSamplerPaddingMode::BORDER)) {
|
||||
coord = std::min(static_cast<T>(static_cast<size_t>(size) - kOne), std::max(coord, static_cast<T>(kZero)));
|
||||
} else if (padding_mode == static_cast<int64_t>(MsPyEnum::GridSamplerPaddingMode::REFLECTION)) {
|
||||
} else if (padding_mode == static_cast<int64_t>(ops::GridSamplerPaddingMode::REFLECTION)) {
|
||||
if (align_corners) {
|
||||
coord = reflect_coordinates(coord, static_cast<int64_t>(kZero), kTwo * (size - kOne));
|
||||
} else {
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
*/
|
||||
#include "plugin/device/cpu/kernel/grid_sampler_3d_grad_cpu_kernel.h"
|
||||
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
|
||||
#include "mindspore/core/ops/auto_generate/gen_enum_def.h"
|
||||
#include "mindspore/core/ops/op_enum.h"
|
||||
|
||||
namespace {
|
||||
const size_t kDataSizeThreshold = 64 * 1024;
|
||||
|
@ -214,7 +214,7 @@ void GridSampler3DGradCpuKernelMod::ComputeTask(T *grad_addr, T *x_addr, T *grid
|
|||
x = grid_sampler_compute_source_index_set_grad(x, x_shape_[kFour], padding_mode, align_corners_, &gx_mult);
|
||||
y = grid_sampler_compute_source_index_set_grad(y, x_shape_[kThree], padding_mode, align_corners_, &gy_mult);
|
||||
z = grid_sampler_compute_source_index_set_grad(z, x_shape_[kTwo], padding_mode, align_corners_, &gz_mult);
|
||||
if (interpolation_mode == static_cast<int64_t>(MsPyEnum::InterpolationMode::BILINEAR)) {
|
||||
if (interpolation_mode == static_cast<int64_t>(ops::InterpolationMode::BILINEAR)) {
|
||||
size_t grad_ptr_NCDHW =
|
||||
n * grad_stride_[kZero] + d * grad_stride_[kTwo] + h * grad_stride_[kThree] + w * grad_stride_[kFour];
|
||||
size_t dx_ptr_NC = n * dx_stride_[kZero], x_ptr_NC = n * x_stride_[kZero];
|
||||
|
@ -223,7 +223,7 @@ void GridSampler3DGradCpuKernelMod::ComputeTask(T *grad_addr, T *x_addr, T *grid
|
|||
std::vector<T> mult = {gx_mult, gy_mult, gz_mult};
|
||||
std::vector<size_t> ptr = {grad_ptr_NCDHW, x_ptr_NC, dx_ptr_NC, dgrid_ptr_NDHW};
|
||||
BilinearKernel<T>(addr, location, mult, ptr);
|
||||
} else if (interpolation_mode == static_cast<int64_t>(MsPyEnum::InterpolationMode::NEAREST)) {
|
||||
} else if (interpolation_mode == static_cast<int64_t>(ops::InterpolationMode::NEAREST)) {
|
||||
int64_t x_nearest = static_cast<int64_t>(std::round(x));
|
||||
int64_t y_nearest = static_cast<int64_t>(std::round(y));
|
||||
int64_t z_nearest = static_cast<int64_t>(std::round(z));
|
||||
|
@ -281,10 +281,10 @@ T GridSampler3DGradCpuKernelMod::grid_sampler_compute_source_index_set_grad(T co
|
|||
*grad_x = static_cast<T>(size) / kTwo;
|
||||
coord = ((coord + kOne) * size - kOne) / kTwo;
|
||||
}
|
||||
if (padding_mode == static_cast<int64_t>(MsPyEnum::GridSamplerPaddingMode::BORDER)) {
|
||||
if (padding_mode == static_cast<int64_t>(ops::GridSamplerPaddingMode::BORDER)) {
|
||||
coord = clip_coordinates_set_grad(coord, size, &grad_clip);
|
||||
*grad_x = (*grad_x) * grad_clip;
|
||||
} else if (padding_mode == static_cast<int64_t>(MsPyEnum::GridSamplerPaddingMode::REFLECTION)) {
|
||||
} else if (padding_mode == static_cast<int64_t>(ops::GridSamplerPaddingMode::REFLECTION)) {
|
||||
if (align_corners) {
|
||||
coord = reflect_coordinates_set_grad(coord, 0, (size - 1) * static_cast<int64_t>(kTwo), &grad_refl);
|
||||
} else {
|
||||
|
|
|
@ -48,7 +48,7 @@ int NLLLossCpuKernelMod::Resize(const std::vector<KernelTensor *> &inputs, const
|
|||
return ret;
|
||||
}
|
||||
auto reduction = inputs[kReductionIdx]->GetValueWithCheck<int64_t>();
|
||||
reduction_type_ = static_cast<MsPyEnum::Reduction>(reduction);
|
||||
reduction_type_ = static_cast<Reduction>(reduction);
|
||||
ignore_index_ = inputs[kIgnoreIndexIdx]->GetValueWithCheck<int64_t>();
|
||||
|
||||
auto logits_shape = inputs[kIndex0]->GetShapeVector();
|
||||
|
@ -92,15 +92,15 @@ bool NLLLossCpuKernelMod::LaunchKernel(const std::vector<kernel::KernelTensor *>
|
|||
float n_loss = -logits[index] * n_weight;
|
||||
tmp_total_weight += n_weight;
|
||||
total_loss += n_loss;
|
||||
if (reduction_type_ == MsPyEnum::Reduction::NONE) {
|
||||
if (reduction_type_ == Reduction::NONE) {
|
||||
loss[i] = n_loss;
|
||||
}
|
||||
}
|
||||
|
||||
*total_weight = tmp_total_weight;
|
||||
if (reduction_type_ == MsPyEnum::Reduction::SUM) {
|
||||
if (reduction_type_ == Reduction::REDUCTION_SUM) {
|
||||
*loss = total_loss;
|
||||
} else if (reduction_type_ == MsPyEnum::Reduction::MEAN) {
|
||||
} else if (reduction_type_ == Reduction::MEAN) {
|
||||
*loss = total_loss / tmp_total_weight;
|
||||
}
|
||||
return true;
|
||||
|
|
|
@ -23,7 +23,7 @@
|
|||
#include "plugin/device/cpu/kernel/cpu_kernel.h"
|
||||
#include "plugin/factory/ms_factory.h"
|
||||
#include "nnacl/kernel/nllloss.h"
|
||||
#include "mindspore/core/ops/auto_generate/gen_enum_def.h"
|
||||
#include "mindapi/base/types.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
|
@ -55,7 +55,7 @@ class NLLLossCpuKernelMod : public NativeCpuKernelMod {
|
|||
NLLLossFunc kernel_func_;
|
||||
static std::vector<std::pair<KernelAttr, NLLLossFunc>> func_list_;
|
||||
NLLLossStruct nllloss_param_{};
|
||||
MsPyEnum::Reduction reduction_type_;
|
||||
Reduction reduction_type_;
|
||||
int64_t ignore_index_;
|
||||
};
|
||||
} // namespace kernel
|
||||
|
|
|
@ -49,7 +49,7 @@ int NLLLossGradCpuKernelMod::Resize(const std::vector<KernelTensor *> &inputs,
|
|||
return ret;
|
||||
}
|
||||
auto reduction = inputs[kReductionIdx]->GetValueWithCheck<int64_t>();
|
||||
reduction_type_ = static_cast<MsPyEnum::Reduction>(reduction);
|
||||
reduction_type_ = static_cast<Reduction>(reduction);
|
||||
ignore_index_ = inputs[kIgnoreIndexIdx]->GetValueWithCheck<int64_t>();
|
||||
auto logits_shape = inputs[0]->GetShapeVector();
|
||||
nllloss_param_.batch_ = LongToInt(logits_shape[0]);
|
||||
|
@ -89,9 +89,9 @@ bool NLLLossGradCpuKernelMod::LaunchKernel(const std::vector<kernel::KernelTenso
|
|||
}
|
||||
int index = i * nllloss_param_.class_num_ + labels[i];
|
||||
float n_weight = weight[labels[i]];
|
||||
if (reduction_type_ == MsPyEnum::Reduction::SUM) {
|
||||
if (reduction_type_ == Reduction::REDUCTION_SUM) {
|
||||
logits_grad[index] = -loss_grad[0] * n_weight;
|
||||
} else if (reduction_type_ == MsPyEnum::Reduction::MEAN) {
|
||||
} else if (reduction_type_ == Reduction::MEAN) {
|
||||
logits_grad[index] = -loss_grad[0] * n_weight / *total_weight;
|
||||
} else {
|
||||
logits_grad[index] = -loss_grad[i] * n_weight;
|
||||
|
|
|
@ -23,7 +23,7 @@
|
|||
#include "plugin/device/cpu/kernel/cpu_kernel.h"
|
||||
#include "plugin/factory/ms_factory.h"
|
||||
#include "nnacl/kernel/nllloss.h"
|
||||
#include "mindspore/core/ops/auto_generate/gen_enum_def.h"
|
||||
#include "mindapi/base/types.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
|
@ -57,7 +57,7 @@ class NLLLossGradCpuKernelMod : public NativeCpuKernelMod {
|
|||
|
||||
private:
|
||||
NLLLossStruct nllloss_param_{};
|
||||
MsPyEnum::Reduction reduction_type_;
|
||||
Reduction reduction_type_;
|
||||
int64_t ignore_index_;
|
||||
};
|
||||
} // namespace kernel
|
||||
|
|
|
@ -114,10 +114,10 @@ const std::vector<std::pair<KernelAttr, ResizeLinear1DCpuKernelMod::KernelRunFun
|
|||
template <typename T>
|
||||
ResizeLinear1DCpuKernelMod::CoordinateTransformationFunc<T>
|
||||
ResizeLinear1DCpuKernelMod::ChooseCoordinateTransformationFunc(
|
||||
MsPyEnum::CoordinateTransformationMode coordinate_transformation_mode) const {
|
||||
const std::unordered_map<MsPyEnum::CoordinateTransformationMode, CoordinateTransformationFunc<T>> coordinate_map{
|
||||
{MsPyEnum::CoordinateTransformationMode::ALIGN_CORNERS, AlignCornersFunc<T>()},
|
||||
{MsPyEnum::CoordinateTransformationMode::HALF_PIXEL, HalfPixelFunc<T>()}};
|
||||
CoordinateTransformMode coordinate_transformation_mode) const {
|
||||
const std::unordered_map<CoordinateTransformMode, CoordinateTransformationFunc<T>> coordinate_map{
|
||||
{CoordinateTransformMode::ALIGN_CORNERS, AlignCornersFunc<T>()},
|
||||
{CoordinateTransformMode::HALF_PIXEL, HalfPixelFunc<T>()}};
|
||||
return coordinate_map.at(coordinate_transformation_mode);
|
||||
}
|
||||
|
||||
|
@ -162,9 +162,9 @@ int ResizeLinear1DCpuKernelMod::Resize(const std::vector<KernelTensor *> &inputs
|
|||
out_width_ = LongToSize(output_shape[kIndex2]);
|
||||
|
||||
coordinate_transformation_mode_ =
|
||||
static_cast<MsPyEnum::CoordinateTransformationMode>(inputs.at(kIndex2)->GetValueWithCheck<int64_t>());
|
||||
if (coordinate_transformation_mode_ != MsPyEnum::CoordinateTransformationMode::ALIGN_CORNERS &&
|
||||
coordinate_transformation_mode_ != MsPyEnum::CoordinateTransformationMode::HALF_PIXEL) {
|
||||
static_cast<CoordinateTransformMode>(inputs.at(kIndex2)->GetValueWithCheck<int64_t>());
|
||||
if (coordinate_transformation_mode_ != CoordinateTransformMode::ALIGN_CORNERS &&
|
||||
coordinate_transformation_mode_ != CoordinateTransformMode::HALF_PIXEL) {
|
||||
MS_LOG_EXCEPTION << "For '" << kernel_name_ << "', coordinate_transformation_mode not support now.";
|
||||
}
|
||||
SetWorkSpaceSize(inputs);
|
||||
|
|
|
@ -24,7 +24,7 @@
|
|||
#include <vector>
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel.h"
|
||||
#include "plugin/factory/ms_factory.h"
|
||||
#include "ops/auto_generate/gen_enum_def.h"
|
||||
#include "mindapi/base/types.h"
|
||||
|
||||
namespace mindspore::kernel {
|
||||
constexpr auto kUnknown = "Unknown";
|
||||
|
@ -60,7 +60,7 @@ class ResizeLinear1DCpuKernelMod : public NativeCpuKernelMod, public MatchKernel
|
|||
|
||||
template <typename T>
|
||||
CoordinateTransformationFunc<T> ChooseCoordinateTransformationFunc(
|
||||
MsPyEnum::CoordinateTransformationMode coordinate_transformation_mode) const;
|
||||
CoordinateTransformMode coordinate_transformation_mode) const;
|
||||
|
||||
template <typename T>
|
||||
void ComputeInterpolationCaches(const size_t out_size, const size_t in_size,
|
||||
|
@ -73,8 +73,7 @@ class ResizeLinear1DCpuKernelMod : public NativeCpuKernelMod, public MatchKernel
|
|||
size_t channel_{0};
|
||||
size_t in_width_{0};
|
||||
size_t out_width_{0};
|
||||
MsPyEnum::CoordinateTransformationMode coordinate_transformation_mode_ =
|
||||
MsPyEnum::CoordinateTransformationMode::ALIGN_CORNERS;
|
||||
CoordinateTransformMode coordinate_transformation_mode_ = CoordinateTransformMode::ALIGN_CORNERS;
|
||||
};
|
||||
} // namespace mindspore::kernel
|
||||
|
||||
|
|
|
@ -116,10 +116,10 @@ const std::vector<std::pair<KernelAttr, ResizeLinear1DGradCpuKernelMod::KernelRu
|
|||
template <typename T>
|
||||
ResizeLinear1DGradCpuKernelMod::CoordinateTransformationFunc<T>
|
||||
ResizeLinear1DGradCpuKernelMod::ChooseCoordinateTransformationFunc(
|
||||
MsPyEnum::CoordinateTransformationMode coordinate_transformation_mode) {
|
||||
const std::unordered_map<MsPyEnum::CoordinateTransformationMode, CoordinateTransformationFunc<T>> coordinate_map{
|
||||
{MsPyEnum::CoordinateTransformationMode::ALIGN_CORNERS, AlignCornersFunc<T>()},
|
||||
{MsPyEnum::CoordinateTransformationMode::HALF_PIXEL, HalfPixelFunc<T>()}};
|
||||
CoordinateTransformMode coordinate_transformation_mode) {
|
||||
const std::unordered_map<CoordinateTransformMode, CoordinateTransformationFunc<T>> coordinate_map{
|
||||
{CoordinateTransformMode::ALIGN_CORNERS, AlignCornersFunc<T>()},
|
||||
{CoordinateTransformMode::HALF_PIXEL, HalfPixelFunc<T>()}};
|
||||
return coordinate_map.at(coordinate_transformation_mode);
|
||||
}
|
||||
|
||||
|
@ -164,9 +164,9 @@ int ResizeLinear1DGradCpuKernelMod::Resize(const std::vector<KernelTensor *> &in
|
|||
input_width_ = LongToSize(shape_[kIndex2]);
|
||||
|
||||
coordinate_transformation_mode_ =
|
||||
static_cast<MsPyEnum::CoordinateTransformationMode>(inputs.at(kIndex2)->GetValueWithCheck<int64_t>());
|
||||
if (coordinate_transformation_mode_ != MsPyEnum::CoordinateTransformationMode::ALIGN_CORNERS &&
|
||||
coordinate_transformation_mode_ != MsPyEnum::CoordinateTransformationMode::HALF_PIXEL) {
|
||||
static_cast<CoordinateTransformMode>(inputs.at(kIndex2)->GetValueWithCheck<int64_t>());
|
||||
if (coordinate_transformation_mode_ != CoordinateTransformMode::ALIGN_CORNERS &&
|
||||
coordinate_transformation_mode_ != CoordinateTransformMode::HALF_PIXEL) {
|
||||
MS_LOG_EXCEPTION << "For '" << kernel_name_ << "', coordinate_transformation_mode not support now.";
|
||||
}
|
||||
SetWorkSpaceSize(inputs);
|
||||
|
|
|
@ -24,7 +24,7 @@
|
|||
#include <utility>
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel.h"
|
||||
#include "plugin/factory/ms_factory.h"
|
||||
#include "ops/auto_generate/gen_enum_def.h"
|
||||
#include "mindapi/base/types.h"
|
||||
|
||||
namespace mindspore::kernel {
|
||||
constexpr auto kUnknown = "Unknown";
|
||||
|
@ -66,7 +66,7 @@ class ResizeLinear1DGradCpuKernelMod : public NativeCpuKernelMod,
|
|||
|
||||
template <typename T>
|
||||
CoordinateTransformationFunc<T> ChooseCoordinateTransformationFunc(
|
||||
MsPyEnum::CoordinateTransformationMode coordinate_transformation_mode);
|
||||
CoordinateTransformMode coordinate_transformation_mode);
|
||||
|
||||
std::string kernel_type_{kUnknown};
|
||||
bool align_corners_{false};
|
||||
|
@ -76,8 +76,7 @@ class ResizeLinear1DGradCpuKernelMod : public NativeCpuKernelMod,
|
|||
size_t channel_{0};
|
||||
size_t input_width_{0};
|
||||
size_t output_width_{0};
|
||||
MsPyEnum::CoordinateTransformationMode coordinate_transformation_mode_ =
|
||||
MsPyEnum::CoordinateTransformationMode::ALIGN_CORNERS;
|
||||
CoordinateTransformMode coordinate_transformation_mode_ = CoordinateTransformMode::ALIGN_CORNERS;
|
||||
};
|
||||
} // namespace mindspore::kernel
|
||||
|
||||
|
|
|
@ -19,12 +19,18 @@
|
|||
#include <map>
|
||||
#include <string>
|
||||
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h"
|
||||
#include "mindapi/base/types.h"
|
||||
|
||||
enum class ReductionMode { kNone, kMean, kSum };
|
||||
|
||||
static std::map<std::string, ReductionMode> kReductionModeMap{
|
||||
{"none", ReductionMode::kNone}, {"mean", ReductionMode::kMean}, {"sum", ReductionMode::kSum}};
|
||||
|
||||
static std::map<int64_t, ReductionMode> kEnumReductionModeMap{
|
||||
{static_cast<int64_t>(mindspore::Reduction::NONE), ReductionMode::kNone},
|
||||
{static_cast<int64_t>(mindspore::Reduction::MEAN), ReductionMode::kMean},
|
||||
{static_cast<int64_t>(mindspore::Reduction::REDUCTION_SUM), ReductionMode::kSum}};
|
||||
|
||||
template <typename T>
|
||||
CUDA_LIB_EXPORT cudaError_t BinaryCrossEntropyLoss(const int &input_size, const ReductionMode &reduction,
|
||||
const T *input_x, const T *input_y, const T *weight, T *loss,
|
||||
|
|
|
@ -55,21 +55,21 @@ enum class FFTNormMode {
|
|||
by_root_n, // same as above, but sqrt the product
|
||||
};
|
||||
|
||||
FFTNormMode GetNormModeFromString(const MsPyEnum::NormMode &norm_type, const bool is_inverse) {
|
||||
if (norm_type == MsPyEnum::NormMode::FORWARD) {
|
||||
FFTNormMode GetNormModeFromString(const ops::NormMode &norm_type, const bool is_inverse) {
|
||||
if (norm_type == ops::NormMode::FORWARD) {
|
||||
return is_inverse ? FFTNormMode::none : FFTNormMode::by_n;
|
||||
}
|
||||
if (norm_type == MsPyEnum::NormMode::BACKWARD) {
|
||||
if (norm_type == ops::NormMode::BACKWARD) {
|
||||
return is_inverse ? FFTNormMode::by_n : FFTNormMode::none;
|
||||
}
|
||||
if (norm_type == MsPyEnum::NormMode::ORTHO) {
|
||||
if (norm_type == ops::NormMode::ORTHO) {
|
||||
return FFTNormMode::by_root_n;
|
||||
}
|
||||
MS_LOG(ERROR) << "For 'FFTWithSize', the fft norm type " << norm_type << " is unsupported!";
|
||||
return FFTNormMode::none;
|
||||
}
|
||||
|
||||
double GetNormScale(const MsPyEnum::NormMode &norm_type, const bool is_inverse, const int n) {
|
||||
double GetNormScale(const ops::NormMode &norm_type, const bool is_inverse, const int n) {
|
||||
FFTNormMode norm_mode = GetNormModeFromString(norm_type, is_inverse);
|
||||
if (norm_mode == FFTNormMode::none) {
|
||||
return 1.0;
|
||||
|
@ -115,7 +115,7 @@ bool FFTWithSizeGpuKernelMod::FFTVarietyInResize(const std::vector<KernelTensor
|
|||
rank_ = inputs[kIndex1]->GetValueWithCheck<int64_t>();
|
||||
is_inverse_ = inputs[kIndex2]->GetValueWithCheck<bool>();
|
||||
is_real_ = inputs[kIndex3]->GetValueWithCheck<bool>();
|
||||
norm_type_ = static_cast<MsPyEnum::NormMode>(inputs[kIndex4]->GetValueWithCheck<int64_t>());
|
||||
norm_type_ = static_cast<ops::NormMode>(inputs[kIndex4]->GetValueWithCheck<int64_t>());
|
||||
is_onesided_ = inputs[kIndex5]->GetValueWithCheck<bool>();
|
||||
|
||||
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
|
||||
|
|
|
@ -26,7 +26,7 @@
|
|||
#include <memory>
|
||||
#include <algorithm>
|
||||
#include <functional>
|
||||
#include "ops/auto_generate/gen_enum_def.h"
|
||||
#include "ops/op_enum.h"
|
||||
#include "plugin/device/gpu/kernel/gpu_kernel.h"
|
||||
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
|
||||
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/fft_with_size_impl.cuh"
|
||||
|
@ -89,7 +89,7 @@ class FFTWithSizeGpuKernelMod : public NativeGpuKernelMod {
|
|||
int64_t rank_{1};
|
||||
bool is_inverse_{false}; // 0: forward, 1: inverse
|
||||
bool is_real_{false};
|
||||
MsPyEnum::NormMode norm_type_{MsPyEnum::NormMode::BACKWARD}; // forward, backward, ortho
|
||||
ops::NormMode norm_type_{ops::NormMode::BACKWARD}; // forward, backward, ortho
|
||||
// is_onesided controls whether frequency is halved when signal is real, which means is_real_ is true.
|
||||
// The default value is true. cufft does not support full freq with real signal. We use cast as a temporary solution.
|
||||
bool is_onesided_{true};
|
||||
|
|
|
@ -26,7 +26,7 @@
|
|||
#include <functional>
|
||||
#include "mindspore/core/ops/ops_func_impl/grid_sampler_2d.h"
|
||||
#include "mindspore/core/ops/ops_func_impl/grid_sampler_3d.h"
|
||||
#include "mindspore/core/ops/auto_generate/gen_enum_def.h"
|
||||
#include "mindspore/core/ops/op_enum.h"
|
||||
#include "plugin/device/gpu/kernel/gpu_kernel.h"
|
||||
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
|
||||
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/grid_sampler_impl.cuh"
|
||||
|
|
|
@ -24,7 +24,7 @@
|
|||
#include <algorithm>
|
||||
#include "mindspore/core/ops/ops_func_impl/grid_sampler_2d_grad.h"
|
||||
#include "mindspore/core/ops/ops_func_impl/grid_sampler_3d_grad.h"
|
||||
#include "mindspore/core/ops/auto_generate/gen_enum_def.h"
|
||||
#include "mindspore/core/ops/op_enum.h"
|
||||
#include "plugin/device/gpu/kernel/gpu_kernel.h"
|
||||
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
|
||||
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/grid_sampler_grad_impl.cuh"
|
||||
|
|
|
@ -42,7 +42,7 @@ int NLLLossGpuKernelMod::Resize(const std::vector<KernelTensor *> &inputs, const
|
|||
return ret;
|
||||
}
|
||||
auto reduction = inputs[kReductionIdx]->GetValueWithCheck<int64_t>();
|
||||
reduction_ = static_cast<ReductionMode>(reduction);
|
||||
reduction_ = kEnumReductionModeMap[reduction];
|
||||
ignore_index_ = inputs[kIgnoreIndexIdx]->GetValueWithCheck<int64_t>();
|
||||
auto logits_shape = inputs[kIndex0]->GetShapeVector();
|
||||
label_size_ = logits_shape[0];
|
||||
|
|
|
@ -45,7 +45,7 @@ int NLLLossGradGpuKernelMod::Resize(const std::vector<KernelTensor *> &inputs,
|
|||
return ret;
|
||||
}
|
||||
auto reduction = inputs[kReductionIdx]->GetValueWithCheck<int64_t>();
|
||||
reduction_ = static_cast<ReductionMode>(reduction);
|
||||
reduction_ = kEnumReductionModeMap[reduction];
|
||||
ignore_index_ = inputs[kIgnoreIndexIdx]->GetValueWithCheck<int64_t>();
|
||||
|
||||
auto logits_shape = inputs[kIndex0]->GetShapeVector();
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
|
||||
#include "plugin/device/gpu/kernel/nn/resize_linear_1d_gpu_kernel.h"
|
||||
#include "mindspore/core/abstract/utils.h"
|
||||
#include "ops/auto_generate/gen_enum_def.h"
|
||||
#include "mindapi/base/types.h"
|
||||
|
||||
namespace {
|
||||
constexpr const size_t kResizeLinear1DInputsNum = 3;
|
||||
|
@ -58,10 +58,9 @@ int ResizeLinear1DGpuKernelMod::Resize(const std::vector<KernelTensor *> &inputs
|
|||
out_width_ = output_shape_[kIndex2];
|
||||
|
||||
auto coordinate_transformation_mode = inputs.at(kIndex2)->GetValueWithCheck<int64_t>();
|
||||
if (coordinate_transformation_mode == static_cast<int64_t>(MsPyEnum::CoordinateTransformationMode::ALIGN_CORNERS)) {
|
||||
if (coordinate_transformation_mode == static_cast<int64_t>(CoordinateTransformMode::ALIGN_CORNERS)) {
|
||||
mode_ = ResizeLinearCoordinateTransformationMode::ALIGN_CORNERS;
|
||||
} else if (coordinate_transformation_mode ==
|
||||
static_cast<int64_t>(MsPyEnum::CoordinateTransformationMode::HALF_PIXEL)) {
|
||||
} else if (coordinate_transformation_mode == static_cast<int64_t>(CoordinateTransformMode::HALF_PIXEL)) {
|
||||
mode_ = ResizeLinearCoordinateTransformationMode::HALF_PIXEL;
|
||||
} else {
|
||||
MS_LOG_EXCEPTION << "For '" << kernel_name_ << "', coordinate_transformation_mode not support now.";
|
||||
|
|
|
@ -17,7 +17,7 @@
|
|||
#include "plugin/device/gpu/kernel/nn/resize_linear_1d_grad_gpu_kernel.h"
|
||||
#include "mindspore/core/abstract/utils.h"
|
||||
#include "ops/ops_func_impl/resize_linear_1d_grad.h"
|
||||
#include "ops/auto_generate/gen_enum_def.h"
|
||||
#include "mindapi/base/types.h"
|
||||
|
||||
namespace {
|
||||
constexpr const size_t kResizeLinear1DGradInputsNum = 3;
|
||||
|
@ -64,10 +64,9 @@ int ResizeLinear1DGradGpuKernelMod::Resize(const std::vector<KernelTensor *> &in
|
|||
workspace_size_list_.push_back(work_space_size * sizeof(float));
|
||||
|
||||
auto coordinate_transformation_mode = inputs.at(kIndex2)->GetValueWithCheck<int64_t>();
|
||||
if (coordinate_transformation_mode == static_cast<int64_t>(MsPyEnum::CoordinateTransformationMode::ALIGN_CORNERS)) {
|
||||
if (coordinate_transformation_mode == static_cast<int64_t>(CoordinateTransformMode::ALIGN_CORNERS)) {
|
||||
mode_ = ResizeLinearCoordinateTransformationMode::ALIGN_CORNERS;
|
||||
} else if (coordinate_transformation_mode ==
|
||||
static_cast<int64_t>(MsPyEnum::CoordinateTransformationMode::HALF_PIXEL)) {
|
||||
} else if (coordinate_transformation_mode == static_cast<int64_t>(CoordinateTransformMode::HALF_PIXEL)) {
|
||||
mode_ = ResizeLinearCoordinateTransformationMode::HALF_PIXEL;
|
||||
} else {
|
||||
MS_LOG_EXCEPTION << "For '" << kernel_name_ << "', coordinate_transformation_mode not support now.";
|
||||
|
|
|
@ -20,7 +20,6 @@
|
|||
#include <string>
|
||||
#include "mindspore/core/ops/nn_ops.h"
|
||||
#include "mindspore/core/ops/math_ops.h"
|
||||
#include "mindspore/core/ops/auto_generate/gen_enum_def.h"
|
||||
#include "mindspore/core/ops/op_utils.h"
|
||||
#include "include/backend/anf_runtime_algorithm.h"
|
||||
#include "include/common/utils/anfalgo.h"
|
||||
|
|
|
@ -0,0 +1,73 @@
|
|||
/**
|
||||
* Copyright 2020-2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "mindspore/core/ops/op_enum.h"
|
||||
#include "mindspore/core/ops/op_def.h"
|
||||
#include "mindspore/core/mindapi/base/format.h"
|
||||
#include "include/common/pybind_api/api_register.h"
|
||||
|
||||
namespace mindspore::ops {
|
||||
void RegOpEnum(py::module *m) {
|
||||
auto m_sub = m->def_submodule("op_enum", "submodule for op enum");
|
||||
(void)m_sub.def("str_to_enum", &StringToEnumImpl, "string to enum value");
|
||||
(void)py::enum_<OP_DTYPE>(*m, "OpDtype", py::arithmetic())
|
||||
.value("DT_BEGIN", OP_DTYPE::DT_BEGIN)
|
||||
.value("DT_BOOL", OP_DTYPE::DT_BOOL)
|
||||
.value("DT_INT", OP_DTYPE::DT_INT)
|
||||
.value("DT_FLOAT", OP_DTYPE::DT_FLOAT)
|
||||
.value("DT_NUMBER", OP_DTYPE::DT_NUMBER)
|
||||
.value("DT_TENSOR", OP_DTYPE::DT_TENSOR)
|
||||
.value("DT_STR", OP_DTYPE::DT_STR)
|
||||
.value("DT_ANY", OP_DTYPE::DT_ANY)
|
||||
.value("DT_TUPLE_BOOL", OP_DTYPE::DT_TUPLE_BOOL)
|
||||
.value("DT_TUPLE_INT", OP_DTYPE::DT_TUPLE_INT)
|
||||
.value("DT_TUPLE_FLOAT", OP_DTYPE::DT_TUPLE_FLOAT)
|
||||
.value("DT_TUPLE_NUMBER", OP_DTYPE::DT_TUPLE_NUMBER)
|
||||
.value("DT_TUPLE_TENSOR", OP_DTYPE::DT_TUPLE_TENSOR)
|
||||
.value("DT_TUPLE_STR", OP_DTYPE::DT_TUPLE_STR)
|
||||
.value("DT_TUPLE_ANY", OP_DTYPE::DT_TUPLE_ANY)
|
||||
.value("DT_LIST_BOOL", OP_DTYPE::DT_LIST_BOOL)
|
||||
.value("DT_LIST_INT", OP_DTYPE::DT_LIST_INT)
|
||||
.value("DT_LIST_FLOAT", OP_DTYPE::DT_LIST_FLOAT)
|
||||
.value("DT_LIST_NUMBER", OP_DTYPE::DT_LIST_NUMBER)
|
||||
.value("DT_LIST_TENSOR", OP_DTYPE::DT_LIST_TENSOR)
|
||||
.value("DT_LIST_STR", OP_DTYPE::DT_LIST_STR)
|
||||
.value("DT_LIST_ANY", OP_DTYPE::DT_LIST_ANY)
|
||||
.value("DT_TYPE", OP_DTYPE::DT_TYPE)
|
||||
.value("DT_END", OP_DTYPE::DT_END);
|
||||
// There are currently some deficiencies in format, which will be filled in later.
|
||||
(void)py::enum_<Format>(*m, "FormatEnum", py::arithmetic())
|
||||
.value("DEFAULT_FORMAT", Format::DEFAULT_FORMAT)
|
||||
.value("NCHW", Format::NCHW)
|
||||
.value("NHWC", Format::NHWC)
|
||||
.value("NHWC4", Format::NHWC4)
|
||||
.value("HWKC", Format::HWKC)
|
||||
.value("HWCK", Format::HWCK)
|
||||
.value("KCHW", Format::KCHW)
|
||||
.value("CKHW", Format::CKHW)
|
||||
.value("KHWC", Format::KHWC)
|
||||
.value("CHWK", Format::CHWK)
|
||||
.value("HW", Format::HW)
|
||||
.value("HW4", Format::HW4)
|
||||
.value("NC", Format::NC)
|
||||
.value("NC4", Format::NC4)
|
||||
.value("NC4HW4", Format::NC4HW4)
|
||||
.value("NCDHW", Format::NCDHW)
|
||||
.value("NWC", Format::NWC)
|
||||
.value("NCW", Format::NCW)
|
||||
.value("NDHWC", Format::NDHWC)
|
||||
.value("NC8HW8", Format::NC8HW8);
|
||||
}
|
||||
} // namespace mindspore::ops
|
|
@ -238,7 +238,7 @@ class GEPadMod {
|
|||
class GEReduction {
|
||||
public:
|
||||
static std::string ConvertEnumToString(int64_t id) {
|
||||
static const std::vector<std::string> reductions = {"none", "mean", "sum", "add"};
|
||||
static const std::vector<std::string> reductions = {"sum", "mean", "none"};
|
||||
if (id < 0 || id >= static_cast<int64_t>(reductions.size())) {
|
||||
MS_LOG(EXCEPTION) << "Invalid reduction " << id;
|
||||
return "";
|
||||
|
|
|
@ -86,7 +86,7 @@ enum Format : int64_t {
|
|||
NUM_OF_FORMAT
|
||||
};
|
||||
|
||||
inline std::string FormatEnumToString(mindspore::Format format) {
|
||||
inline const std::vector<std::string> &GetFormatNames() {
|
||||
static std::vector<std::string> names = {
|
||||
"NCHW",
|
||||
"NHWC",
|
||||
|
@ -148,6 +148,11 @@ inline std::string FormatEnumToString(mindspore::Format format) {
|
|||
"NYUV_A",
|
||||
"NCL",
|
||||
};
|
||||
return names;
|
||||
}
|
||||
|
||||
inline std::string FormatEnumToString(mindspore::Format format) {
|
||||
const auto &names = GetFormatNames();
|
||||
if (format == mindspore::Format::DEFAULT_FORMAT) {
|
||||
return "DefaultFormat";
|
||||
}
|
||||
|
|
|
@ -16,8 +16,6 @@
|
|||
|
||||
#ifndef MINDSPORE_CORE_MINDAPI_BASE_TYPE_ID_H_
|
||||
#define MINDSPORE_CORE_MINDAPI_BASE_TYPE_ID_H_
|
||||
// TODO(chenfei): remove the TypeId by include the auto generated type id.
|
||||
// #include "mindspore/core/ops/auto_generate/gen_enum_def.h"
|
||||
|
||||
namespace mindspore {
|
||||
/// \brief TypeId defines data type identifiers.
|
||||
|
|
|
@ -0,0 +1,164 @@
|
|||
/**
|
||||
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "ops/op_enum.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <utility>
|
||||
|
||||
#include "mindapi/base/types.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "mindapi/base/format.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
|
||||
namespace {
|
||||
using StrToEnumMap = std::unordered_map<std::string, int64_t>;
|
||||
|
||||
class RegStringToEnumHelper {
|
||||
public:
|
||||
template <typename T>
|
||||
std::string AddValues(T &&string_to_enum) {
|
||||
for (const auto &kv : string_to_enum) {
|
||||
if (string_to_enum_.find(kv.first) != string_to_enum_.end()) {
|
||||
MS_LOG_EXCEPTION << kv.first << " has been registered!";
|
||||
}
|
||||
}
|
||||
string_to_enum_.merge(std::move(string_to_enum));
|
||||
return "";
|
||||
}
|
||||
|
||||
const StrToEnumMap &GetValues() { return string_to_enum_; }
|
||||
|
||||
private:
|
||||
StrToEnumMap string_to_enum_;
|
||||
};
|
||||
RegStringToEnumHelper reg_string_to_enum_helper;
|
||||
|
||||
#define REG_STRING_TO_ENUM(enum_type, ...) \
|
||||
const auto op_enum_##enum_type = reg_string_to_enum_helper.AddValues(__VA_ARGS__);
|
||||
|
||||
// Convert to uppercase uniformly
|
||||
inline std::string StrToUpper(const std::string &str) {
|
||||
auto res = str;
|
||||
for (auto &c : res) {
|
||||
c = std::toupper(c);
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
// Format
|
||||
inline std::unordered_map<std::string, int64_t> GetStringToFormatMap() {
|
||||
const auto &names = GetFormatNames();
|
||||
std::unordered_map<std::string, int64_t> map{{"DEFAULT_FORMAT", static_cast<int64_t>(Format::DEFAULT_FORMAT)}};
|
||||
for (size_t i = 0; i < names.size(); ++i) {
|
||||
map[StrToUpper(names[i])] = static_cast<int64_t>(i);
|
||||
}
|
||||
return map;
|
||||
}
|
||||
REG_STRING_TO_ENUM(format, GetStringToFormatMap())
|
||||
|
||||
// PadMode
|
||||
StrToEnumMap StrToPadModeMap = {{"PAD", PadMode::PAD}, {"SAME", PadMode::SAME}, {"VALID", PadMode::VALID}};
|
||||
REG_STRING_TO_ENUM(pad_mode, StrToPadModeMap)
|
||||
|
||||
// Reduction
|
||||
StrToEnumMap StrToReductionMap = {
|
||||
{"SUM", Reduction::REDUCTION_SUM}, {"MEAN", Reduction::MEAN}, {"NONE", Reduction::NONE}};
|
||||
REG_STRING_TO_ENUM(reduction, StrToReductionMap)
|
||||
|
||||
// Activation
|
||||
StrToEnumMap StrToActivationMap = {{"NO_ACTIVATION", ActivationType::NO_ACTIVATION},
|
||||
{"RELU", ActivationType::RELU},
|
||||
{"SIGMOID", ActivationType::SIGMOID},
|
||||
{"RELU6", ActivationType::RELU6},
|
||||
{"ELU", ActivationType::ELU},
|
||||
{"LEAKY_RELU", ActivationType::LEAKY_RELU},
|
||||
{"ABS", ActivationType::ABS},
|
||||
{"RELU1", ActivationType::RELU1},
|
||||
{"SOFTSIGN", ActivationType::SOFTSIGN},
|
||||
{"SOFTPLUS", ActivationType::SOFTPLUS},
|
||||
{"TANH", ActivationType::TANH},
|
||||
{"SELU", ActivationType::SELU},
|
||||
{"HSWISH", ActivationType::HSWISH},
|
||||
{"HSIGMOID", ActivationType::HSIGMOID},
|
||||
{"THRESHOLDRELU", ActivationType::THRESHOLDRELU},
|
||||
{"LINEAR", ActivationType::LINEAR},
|
||||
{"HARD_TANH", ActivationType::HARD_TANH},
|
||||
{"SIGN", ActivationType::SIGN},
|
||||
{"SWISH", ActivationType::SWISH},
|
||||
{"GELU", ActivationType::GELU},
|
||||
{"GLU", ActivationType::GLU},
|
||||
{"UNKNOWN", ActivationType::UNKNOWN}};
|
||||
REG_STRING_TO_ENUM(activation, StrToActivationMap)
|
||||
|
||||
// GateOrder
|
||||
REG_STRING_TO_ENUM(gate_order, StrToEnumMap{{"RZH", GateOrderMode::RZH}, {"ZRH", GateOrderMode::ZRH}})
|
||||
|
||||
// CoordinateTransformationMode
|
||||
StrToEnumMap StrToCoordinateTransformationModeMap = {{"ASYMMETRIC", CoordinateTransformMode::ASYMMETRIC},
|
||||
{"ALIGN_CORNERS", CoordinateTransformMode::ALIGN_CORNERS},
|
||||
{"HALF_PIXEL", CoordinateTransformMode::HALF_PIXEL},
|
||||
{"CROP_AND_RESIZE", CoordinateTransformMode::CROP_AND_RESIZE}};
|
||||
REG_STRING_TO_ENUM(coordinate_transformation_mode, StrToCoordinateTransformationModeMap)
|
||||
|
||||
// PaddingMode
|
||||
StrToEnumMap StrToPaddingModeMap = {{"CONSTANT", PaddingMode::CONSTANT},
|
||||
{"REFLECT", PaddingMode::REFLECT},
|
||||
{"SYMMETRIC", PaddingMode::SYMMETRIC},
|
||||
{"MODE_RESERVED", PaddingMode::MODE_RESERVED}};
|
||||
REG_STRING_TO_ENUM(padding_mode, StrToPaddingModeMap)
|
||||
|
||||
// Direction
|
||||
REG_STRING_TO_ENUM(direction, StrToEnumMap{{"UNIDIRECTIONAL", Direction::UNIDIRECTIONAL}})
|
||||
|
||||
// CellType
|
||||
REG_STRING_TO_ENUM(cell_type, StrToEnumMap{{"LSTM", CellType::LSTM}})
|
||||
|
||||
// Group
|
||||
REG_STRING_TO_ENUM(group, StrToEnumMap{{"SYNC_BN_GROUP0", Group::SYNC_BN_GROUP0}})
|
||||
|
||||
// InterpolationMode
|
||||
REG_STRING_TO_ENUM(interpolation_mode,
|
||||
StrToEnumMap{{"BILINEAR", InterpolationMode::BILINEAR}, {"NEAREST", InterpolationMode::NEAREST}})
|
||||
|
||||
// NormMode
|
||||
StrToEnumMap StrToNormModeMap = {
|
||||
{"BACKWARD", NormMode::BACKWARD}, {"FORWARD", NormMode::FORWARD}, {"ORTHO", NormMode::ORTHO}};
|
||||
REG_STRING_TO_ENUM(norm_mode, StrToNormModeMap)
|
||||
|
||||
// GridSamplerPaddingMode
|
||||
StrToEnumMap StrToGridSamplerPaddingMode = {{"ZEROS", GridSamplerPaddingMode::ZEROS},
|
||||
{"BORDER", GridSamplerPaddingMode::BORDER},
|
||||
{"REFLECTION", GridSamplerPaddingMode::REFLECTION}};
|
||||
REG_STRING_TO_ENUM(grid_sampler_padding_mode, StrToGridSamplerPaddingMode)
|
||||
|
||||
// KVCacheAlignMode
|
||||
REG_STRING_TO_ENUM(k_v_cache_align_mode,
|
||||
StrToEnumMap{{"LEFT", KVCacheAlignMode::LEFT}, {"RIGHT", KVCacheAlignMode::RIGHT}})
|
||||
|
||||
} // namespace
|
||||
|
||||
int64_t StringToEnumImpl(const std::string &enum_string) {
|
||||
const auto &string_to_enum_map = reg_string_to_enum_helper.GetValues();
|
||||
const auto enum_val_iter = string_to_enum_map.find(StrToUpper(enum_string));
|
||||
if (enum_val_iter == string_to_enum_map.end()) {
|
||||
MS_LOG_EXCEPTION << "Can not find '" << enum_string << "', please add it";
|
||||
}
|
||||
return enum_val_iter->second;
|
||||
}
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,49 @@
|
|||
/**
|
||||
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_OPS_OP_ENUM_H_
|
||||
#define MINDSPORE_CORE_OPS_OP_ENUM_H_
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "mindapi/base/macros.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
MIND_API int64_t StringToEnumImpl(const std::string &enum_string);
|
||||
|
||||
// Only the current mindspore/core/mindapi/base/types.h and other files do not have
|
||||
// corresponding enumerations and then add new enumerations.
|
||||
// The `enum` is used here instead of `enum class` because the current backend enum is
|
||||
// represented by `int`. The `enum` is more convenient than `enum class` compare with int.
|
||||
enum Direction : int64_t { UNIDIRECTIONAL = 0 };
|
||||
|
||||
enum CellType : int64_t { LSTM = 0 };
|
||||
|
||||
enum Group : int64_t { SYNC_BN_GROUP0 = 0 };
|
||||
|
||||
enum InterpolationMode : int64_t { BILINEAR = 0, NEAREST = 1 };
|
||||
|
||||
enum NormMode : int64_t { BACKWARD = 0, FORWARD = 1, ORTHO = 2 };
|
||||
|
||||
enum GridSamplerPaddingMode : int64_t { ZEROS = 0, BORDER = 1, REFLECTION = 2 };
|
||||
|
||||
enum KVCacheAlignMode : int64_t { RIGHT = 0, LEFT = 1 };
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CORE_OPS_ENUM_H_
|
|
@ -8,10 +8,10 @@ argmax:
|
|||
default: -1
|
||||
prim_init: True
|
||||
output_type:
|
||||
dtype: int
|
||||
dtype: TypeId
|
||||
default: mstype.int32
|
||||
prim_init: True
|
||||
arg_handler: dtype_to_enum
|
||||
arg_handler: dtype_to_type_id
|
||||
returns:
|
||||
output:
|
||||
dtype: tensor
|
||||
|
|
|
@ -8,10 +8,10 @@ argmin:
|
|||
default: -1
|
||||
prim_init: True
|
||||
output_type:
|
||||
dtype: int
|
||||
dtype: TypeId
|
||||
default: mstype.int32
|
||||
prim_init: True
|
||||
arg_handler: dtype_to_enum
|
||||
arg_handler: dtype_to_type_id
|
||||
returns:
|
||||
output:
|
||||
dtype: tensor
|
||||
|
|
|
@ -21,12 +21,12 @@ avg_pool_grad:
|
|||
dtype: int
|
||||
default: "'VALID'"
|
||||
prim_init: True
|
||||
arg_handler: pad_mode_to_enum
|
||||
arg_handler: str_to_enum
|
||||
data_format:
|
||||
dtype: int
|
||||
default: "'NCHW'"
|
||||
prim_init: True
|
||||
arg_handler: py_format_to_enum
|
||||
arg_handler: str_to_enum
|
||||
returns:
|
||||
output:
|
||||
dtype: tensor
|
||||
|
|
|
@ -17,12 +17,12 @@ avg_pool:
|
|||
dtype: int
|
||||
default: "'VALID'"
|
||||
prim_init: True
|
||||
arg_handler: pad_mode_to_enum
|
||||
arg_handler: str_to_enum
|
||||
data_format:
|
||||
dtype: int
|
||||
default: "'NCHW'"
|
||||
prim_init: True
|
||||
arg_handler: py_format_to_enum
|
||||
arg_handler: str_to_enum
|
||||
returns:
|
||||
output:
|
||||
dtype: tensor
|
||||
|
|
|
@ -28,7 +28,7 @@ batch_norm_grad_grad:
|
|||
dtype: int
|
||||
default: "'NCHW'"
|
||||
prim_init: True
|
||||
arg_handler: py_format_to_enum
|
||||
arg_handler: str_to_enum
|
||||
returns:
|
||||
dx:
|
||||
dtype: tensor
|
||||
|
|
|
@ -24,7 +24,7 @@ batch_norm_grad:
|
|||
dtype: int
|
||||
default: "'NCHW'"
|
||||
prim_init: True
|
||||
arg_handler: py_format_to_enum
|
||||
arg_handler: str_to_enum
|
||||
returns:
|
||||
dx:
|
||||
dtype: tensor
|
||||
|
|
|
@ -28,7 +28,7 @@ batch_norm_grad_with_activation:
|
|||
dtype: int
|
||||
default: "'NCHW'"
|
||||
prim_init: True
|
||||
arg_handler: py_format_to_enum
|
||||
arg_handler: str_to_enum
|
||||
returns:
|
||||
dx:
|
||||
dtype: tensor
|
||||
|
|
|
@ -28,7 +28,7 @@ batch_norm_grad_with_activation:
|
|||
dtype: int
|
||||
default: "'NCHW'"
|
||||
prim_init: True
|
||||
arg_handler: py_format_to_enum
|
||||
arg_handler: str_to_enum
|
||||
returns:
|
||||
dx:
|
||||
dtype: tensor
|
||||
|
|
|
@ -26,7 +26,7 @@ batch_norm:
|
|||
dtype: int
|
||||
default: "'NCHW'"
|
||||
prim_init: True
|
||||
arg_handler: py_format_to_enum
|
||||
arg_handler: str_to_enum
|
||||
returns:
|
||||
output_x:
|
||||
dtype: tensor
|
||||
|
|
|
@ -26,7 +26,7 @@ batch_norm_with_activation:
|
|||
dtype: int
|
||||
default: "'NCHW'"
|
||||
prim_init: True
|
||||
arg_handler: py_format_to_enum
|
||||
arg_handler: str_to_enum
|
||||
returns:
|
||||
output_x:
|
||||
dtype: tensor
|
||||
|
|
|
@ -28,7 +28,7 @@ batch_norm_with_add_and_activation:
|
|||
dtype: int
|
||||
default: "'NCHW'"
|
||||
prim_init: True
|
||||
arg_handler: py_format_to_enum
|
||||
arg_handler: str_to_enum
|
||||
returns:
|
||||
output_x:
|
||||
dtype: tensor
|
||||
|
|
|
@ -7,7 +7,7 @@ bias_add_grad:
|
|||
dtype: int
|
||||
default: "'NCHW'"
|
||||
prim_init: True
|
||||
arg_handler: py_format_to_enum
|
||||
arg_handler: str_to_enum
|
||||
returns:
|
||||
output:
|
||||
dtype: tensor
|
||||
|
|
|
@ -9,7 +9,7 @@ bias_add:
|
|||
dtype: int
|
||||
default: "'NCHW'"
|
||||
prim_init: True
|
||||
arg_handler: py_format_to_enum
|
||||
arg_handler: str_to_enum
|
||||
returns:
|
||||
output:
|
||||
dtype: tensor
|
||||
|
|
|
@ -6,8 +6,8 @@ eye:
|
|||
m:
|
||||
dtype: int
|
||||
dtype:
|
||||
dtype: int
|
||||
arg_handler: dtype_to_enum
|
||||
dtype: TypeId
|
||||
arg_handler: dtype_to_type_id
|
||||
returns:
|
||||
output:
|
||||
dtype: tensor
|
||||
|
|
|
@ -14,7 +14,7 @@ fft_with_size:
|
|||
prim_init: True
|
||||
norm:
|
||||
dtype: int
|
||||
arg_handler: norm_mode_to_enum
|
||||
arg_handler: str_to_enum
|
||||
default: "'backward'"
|
||||
prim_init: True
|
||||
onesided:
|
||||
|
|
|
@ -11,12 +11,12 @@ grid_sampler_2d_grad:
|
|||
dtype: int
|
||||
default: "'bilinear'"
|
||||
prim_init: True
|
||||
arg_handler: interpolation_mode_to_enum
|
||||
arg_handler: str_to_enum
|
||||
padding_mode:
|
||||
dtype: int
|
||||
default: "'zeros'"
|
||||
prim_init: True
|
||||
arg_handler: grid_sampler_padding_mode_to_enum
|
||||
arg_handler: str_to_enum
|
||||
align_corners:
|
||||
dtype: bool
|
||||
default: False
|
||||
|
|
|
@ -9,12 +9,12 @@ grid_sampler_2d:
|
|||
dtype: int
|
||||
default: "'bilinear'"
|
||||
prim_init: True
|
||||
arg_handler: interpolation_mode_to_enum
|
||||
arg_handler: str_to_enum
|
||||
padding_mode:
|
||||
dtype: int
|
||||
default: "'zeros'"
|
||||
prim_init: True
|
||||
arg_handler: grid_sampler_padding_mode_to_enum
|
||||
arg_handler: str_to_enum
|
||||
align_corners:
|
||||
dtype: bool
|
||||
default: False
|
||||
|
|
|
@ -11,12 +11,12 @@ grid_sampler_3d_grad:
|
|||
dtype: int
|
||||
default: "'bilinear'"
|
||||
prim_init: True
|
||||
arg_handler: interpolation_mode_to_enum
|
||||
arg_handler: str_to_enum
|
||||
padding_mode:
|
||||
dtype: int
|
||||
default: "'zeros'"
|
||||
prim_init: True
|
||||
arg_handler: grid_sampler_padding_mode_to_enum
|
||||
arg_handler: str_to_enum
|
||||
align_corners:
|
||||
dtype: bool
|
||||
default: False
|
||||
|
|
|
@ -9,12 +9,12 @@ grid_sampler_3d:
|
|||
dtype: int
|
||||
default: "'bilinear'"
|
||||
prim_init: True
|
||||
arg_handler: interpolation_mode_to_enum
|
||||
arg_handler: str_to_enum
|
||||
padding_mode:
|
||||
dtype: int
|
||||
default: "'zeros'"
|
||||
prim_init: True
|
||||
arg_handler: grid_sampler_padding_mode_to_enum
|
||||
arg_handler: str_to_enum
|
||||
align_corners:
|
||||
dtype: bool
|
||||
default: False
|
||||
|
|
|
@ -22,7 +22,7 @@ extract_image_patches:
|
|||
dtype: int
|
||||
default: "'VALID'"
|
||||
prim_init: True
|
||||
arg_handler: pad_mode_to_enum
|
||||
arg_handler: str_to_enum
|
||||
returns:
|
||||
output:
|
||||
dtype: tensor
|
||||
|
|
|
@ -19,7 +19,7 @@ prompt_k_v_cache:
|
|||
dtype: int
|
||||
default: 0
|
||||
prim_init: True
|
||||
arg_handler: k_v_cache_align_mode_to_enum
|
||||
arg_handler: str_to_enum
|
||||
labels:
|
||||
side_effect_mem: True
|
||||
returns:
|
||||
|
|
|
@ -15,7 +15,7 @@ nllloss_grad:
|
|||
dtype: int
|
||||
default: "'mean'"
|
||||
prim_init: True
|
||||
arg_handler: reduction_to_enum
|
||||
arg_handler: str_to_enum
|
||||
ignore_index:
|
||||
dtype: int
|
||||
default: -100
|
||||
|
|
|
@ -11,7 +11,7 @@ nllloss:
|
|||
dtype: int
|
||||
default: "'mean'"
|
||||
prim_init: True
|
||||
arg_handler: reduction_to_enum
|
||||
arg_handler: str_to_enum
|
||||
ignore_index:
|
||||
dtype: int
|
||||
default: -100
|
||||
|
|
|
@ -13,10 +13,10 @@ randperm_v2:
|
|||
default: 0
|
||||
prim_init: True
|
||||
dtype:
|
||||
dtype: int
|
||||
dtype: TypeId
|
||||
default: mstype.int64
|
||||
prim_init: True
|
||||
arg_handler: dtype_to_enum
|
||||
arg_handler: dtype_to_type_id
|
||||
returns:
|
||||
output:
|
||||
dtype: tensor
|
||||
|
|
|
@ -9,7 +9,7 @@ resize_linear_1d_grad:
|
|||
dtype: int
|
||||
default: "'align_corners'"
|
||||
prim_init: True
|
||||
arg_handler: coordinate_transformation_mode_to_enum
|
||||
arg_handler: str_to_enum
|
||||
returns:
|
||||
output:
|
||||
dtype: tensor
|
||||
|
|
|
@ -10,7 +10,7 @@ resize_linear_1d:
|
|||
dtype: int
|
||||
default: "'align_corners'"
|
||||
prim_init: True
|
||||
arg_handler: coordinate_transformation_mode_to_enum
|
||||
arg_handler: str_to_enum
|
||||
returns:
|
||||
output:
|
||||
dtype: tensor
|
||||
|
|
|
@ -22,7 +22,8 @@
|
|||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
constexpr size_t kDefaultRank = 2;
|
||||
// As kDefaultRank will have a confilict problem during lite compile, use kDefaultShapeSize.
|
||||
constexpr size_t kDefaultShapeSize = 2;
|
||||
constexpr size_t kRowIndex = 2;
|
||||
constexpr size_t kColIndex = 1;
|
||||
} // namespace
|
||||
|
@ -32,7 +33,7 @@ void EigCheckShapeValid(const ShapeVector &input_shape) {
|
|||
return;
|
||||
}
|
||||
|
||||
if (input_shape.size() < kDefaultRank) {
|
||||
if (input_shape.size() < kDefaultShapeSize) {
|
||||
MS_EXCEPTION(ValueError) << "For Eig, x should be at lease 2"
|
||||
<< ", but got a " << input_shape.size() << "-D Tensor.";
|
||||
}
|
||||
|
@ -52,7 +53,7 @@ void EigCheckShapeValid(const ShapeVector &input_shape) {
|
|||
BaseShapePtr EigFuncImpl::InferShape(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) const {
|
||||
auto input_shape = input_args[kInputIndex0]->GetShape()->GetShapeVector();
|
||||
std::vector<BaseShapePtr> shapes_list(kDefaultRank);
|
||||
std::vector<BaseShapePtr> shapes_list(kDefaultShapeSize);
|
||||
EigCheckShapeValid(input_shape);
|
||||
|
||||
/* infer eigen_value shape */
|
||||
|
|
|
@ -20,7 +20,6 @@
|
|||
#include <unordered_map>
|
||||
#include "ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "ops/auto_generate/gen_enum_def.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
|
|
|
@ -18,7 +18,7 @@
|
|||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include "abstract/dshape.h"
|
||||
#include "ops/auto_generate/gen_enum_def.h"
|
||||
#include "mindapi/base/types.h"
|
||||
#include "ops/op_name.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
|
@ -86,15 +86,15 @@ BaseShapePtr NLLLossFuncImpl::InferShape(const PrimitivePtr &primitive,
|
|||
std::vector<abstract::BaseShapePtr>{loss_shape_ptr, total_weight_shape_ptr});
|
||||
}
|
||||
|
||||
auto reduce_value_enum = static_cast<MsPyEnum::Reduction>(reduction_opt.value());
|
||||
if ((reduce_value_enum == MsPyEnum::Reduction::SUM) || (reduce_value_enum == MsPyEnum::Reduction::MEAN)) {
|
||||
auto reduce_value_enum = static_cast<Reduction>(reduction_opt.value());
|
||||
if ((reduce_value_enum == Reduction::REDUCTION_SUM) || (reduce_value_enum == Reduction::MEAN)) {
|
||||
// shape () means 0D tensor.
|
||||
auto loss_shape_ptr = std::make_shared<abstract::TensorShape>(loss_shape);
|
||||
return std::make_shared<abstract::TupleShape>(
|
||||
std::vector<abstract::BaseShapePtr>{loss_shape_ptr, total_weight_shape_ptr});
|
||||
}
|
||||
|
||||
if (reduce_value_enum == MsPyEnum::Reduction::NONE) {
|
||||
if (reduce_value_enum == Reduction::NONE) {
|
||||
if (logits_shape.size() == DIM_2) {
|
||||
loss_shape.push_back(
|
||||
std::max({logits_shape[kInputIndex0], labels_shape[kInputIndex0], abstract::Shape::kShapeDimAny}));
|
||||
|
|
|
@ -103,6 +103,9 @@ GVAR_DEF(PrimitivePtr, kPrimNPUAntiQuant, std::make_shared<Primitive>("AscendAnt
|
|||
|
||||
// Fusion Inference OP
|
||||
GVAR_DEF(PrimitivePtr, kPrimFFN, std::make_shared<Primitive>("FFN"));
|
||||
|
||||
// ToEnum OP
|
||||
GVAR_DEF(PrimitivePtr, kPrimStringToEnum, std::make_shared<Primitive>("StringToEnum"));
|
||||
} // namespace prim
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -0,0 +1,65 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ops/string_to_enum.h"
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "abstract/ops/primitive_infer_map.h"
|
||||
#include "ir/dtype/tensor_type.h"
|
||||
#include "ir/primitive.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "ops/other_ops.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
#include "ops/op_enum.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
|
||||
MIND_API_OPERATOR_IMPL(StringToEnum, BaseOperator);
|
||||
class MIND_API StringToEnumInfer : public abstract::OpInferBase {
|
||||
public:
|
||||
BaseShapePtr InferShape(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) const override {
|
||||
return std::make_shared<abstract::TensorShape>();
|
||||
}
|
||||
|
||||
TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const override {
|
||||
return kInt64;
|
||||
}
|
||||
|
||||
ValuePtr InferValue(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const override {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
const auto &op_name = primitive->name();
|
||||
(void)CheckAndConvertUtils::CheckInteger("StringToEnum infer", int64_t(input_args.size()), kEqual, 1, op_name);
|
||||
MS_EXCEPTION_IF_NULL(input_args[0]);
|
||||
const auto &input_value = input_args[0]->GetValue();
|
||||
MS_EXCEPTION_IF_NULL(input_value);
|
||||
if (!input_value->isa<StringImm>()) {
|
||||
MS_LOG_EXCEPTION << "Currently, for " << op_name << ", input value must a string value";
|
||||
}
|
||||
const auto &enum_str = GetValue<std::string>(input_value);
|
||||
const auto enum_int = StringToEnumImpl(enum_str);
|
||||
return MakeValue(enum_int);
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_PRIMITIVE_OP_INFER_IMPL(StringToEnum, prim::kPrimStringToEnum, StringToEnumInfer, true);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,39 @@
|
|||
/**
|
||||
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_OPS_STRING_TO_ENUM_H_
|
||||
#define MINDSPORE_CORE_OPS_STRING_TO_ENUM_H_
|
||||
#include <string>
|
||||
|
||||
#include "mindapi/base/types.h"
|
||||
#include "ops/base_operator.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameStringToEnum = "StringToEnum";
|
||||
|
||||
/// \brief Returns the enum value of the input string.
|
||||
class MIND_API StringToEnum : public BaseOperator {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(StringToEnum);
|
||||
/// \brief Constructor.
|
||||
StringToEnum() : BaseOperator(kNameStringToEnum) { InitIOName({"x"}, {"output"}); }
|
||||
/// \brief Init. Refer to the parameters of Python API @ref mindspore.ops.StringToEnum for the inputs.
|
||||
void Init() const {}
|
||||
};
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CORE_OPS_STRING_TO_ENUM_H_
|
|
@ -20,7 +20,6 @@
|
|||
#include "test/ut/tools/optimizer/fusion/fusion_inout_test/conv_fusion_inout_test.h"
|
||||
#include "plugin/device/cpu/kernel/nnacl/op_base.h"
|
||||
#include "ops/auto_generate/gen_lite_ops.h"
|
||||
#include "ops/auto_generate/gen_enum_def.h"
|
||||
|
||||
namespace mindspore {
|
||||
class ConvBiasFusionInoutTest : public ConvFusionInoutTest {
|
||||
|
|
|
@ -25,7 +25,7 @@ from mindspore.ops.primitive import _primexpr
|
|||
from mindspore.ops.function import _VmapGeneralRule
|
||||
from mindspore.ops._vmap.vmap_base import vmap_rules_getters, vmap_general_preprocess, _raise_value_error, \
|
||||
_bdim_at_front, _vmap_clone_prim, _bdim_at_any, _handle_broadcasting
|
||||
from mindspore.ops.auto_generate.gen_enum_def import PyFormat
|
||||
from mindspore._c_expression import FormatEnum as Format
|
||||
|
||||
|
||||
@vmap_rules_getters.register(G.NLLLossGrad)
|
||||
|
@ -306,7 +306,7 @@ def get_adaptive_avgpool2d_vmap_rule(prim, axis_size):
|
|||
@vmap_rules_getters.register(G.BatchNormGradGrad)
|
||||
def get_batchnorm_grad_grad_vmap_rule(prim, axis_size):
|
||||
"""VmapRule for `BatchNormGradGrad` operation."""
|
||||
NCHW = PyFormat.NCHW.value
|
||||
NCHW = Format.NCHW
|
||||
|
||||
def vmap_rule(x_bdim, dy_bdim, scale_bdim, mean_bdim, variance_bdim, dout_dx_bdim,
|
||||
dout_dscale_bdim, dout_dbias_bdim, is_training_bdim, epsilon_bdim, data_format_bdim):
|
||||
|
@ -384,7 +384,7 @@ def get_batchnorm_grad_vmap_rule(prim, axis_size):
|
|||
bn_min_dim = 3
|
||||
bn_max_dim = 5
|
||||
prim_name = prim.name
|
||||
NHWC = PyFormat.NHWC.value
|
||||
NHWC = Format.NHWC
|
||||
|
||||
def vmap_rule(grad_bdim, x_bdim, scale_bdim, rsv_1_bdim, rsv_2_bdim,
|
||||
rsv_3_bdim, training_bdim, epsilon_bdim, format_bdim):
|
||||
|
|
|
@ -29,7 +29,7 @@ from mindspore.ops._vmap.vmap_base import vmap_rules_getters, vmap_general_prepr
|
|||
_bdim_at_any, _bdim_at_front, _bdim_at_back, _handle_broadcasting, get_unary_grad_vmap_rule, _raise_value_error, \
|
||||
_vmap_clone_prim, _get_reduce_batch_axis
|
||||
from mindspore.ops.primitive import Primitive
|
||||
from mindspore.ops.auto_generate.gen_enum_def import PyFormat
|
||||
from mindspore._c_expression import FormatEnum as Format
|
||||
|
||||
|
||||
@vmap_rules_getters.register(P.ApplyAdaMax)
|
||||
|
@ -372,7 +372,7 @@ def get_bias_add_vmap_rule(prim, axis_size):
|
|||
|
||||
@constexpr
|
||||
def get_channal_pos_in_x(d_format, n_dims):
|
||||
if d_format == PyFormat.NHWC:
|
||||
if d_format == Format.NHWC:
|
||||
return n_dims
|
||||
return 2
|
||||
|
||||
|
@ -424,7 +424,7 @@ def get_bias_add_grad_vmap_rule(prim, axis_size):
|
|||
"""VmapRule for `BiasAddGrad` operation."""
|
||||
@constexpr
|
||||
def get_channal_pos(d_format, x_rank):
|
||||
if d_format == PyFormat.NHWC:
|
||||
if d_format == Format.NHWC:
|
||||
return x_rank
|
||||
return 2
|
||||
|
||||
|
@ -1249,7 +1249,7 @@ def get_batchnorm_vmap_rule(prim, axis_size):
|
|||
bn_min_dim = 3
|
||||
bn_max_dim = 5
|
||||
prim_name = "BatchNorm"
|
||||
NCHW = PyFormat.NCHW.value
|
||||
NCHW = Format.NCHW
|
||||
|
||||
def vmap_rule(*inputs):
|
||||
is_all_none, result = vmap_general_preprocess(prim, *inputs)
|
||||
|
|
|
@ -19,12 +19,11 @@ Primitive operator classes and operator functional.
|
|||
A collection of operators to build neural networks or to compute functions.
|
||||
"""
|
||||
|
||||
from . import gen_ops_def, gen_arg_handler, gen_enum_def, gen_arg_dtype_cast
|
||||
from . import gen_ops_def, gen_arg_handler, gen_arg_dtype_cast
|
||||
|
||||
from .gen_ops_def import *
|
||||
from .gen_arg_handler import *
|
||||
from .gen_arg_dtype_cast import *
|
||||
from .gen_enum_def import *
|
||||
from ..operations.manually_defined.ops_def import *
|
||||
from .gen_inner_ops_def import *
|
||||
|
||||
|
|
|
@ -572,7 +572,7 @@ class BatchNorm(Primitive):
|
|||
self.is_training = is_training
|
||||
self.epsilon = epsilon
|
||||
self.momentum = momentum
|
||||
self.data_format = handler.py_format_to_enum(data_format)
|
||||
self.data_format = handler.str_to_enum(data_format)
|
||||
|
||||
def __call__(self, *args):
|
||||
return super().__call__(*args, self.is_training, self.epsilon,
|
||||
|
|
|
@ -20,7 +20,7 @@ import mindspore as ms
|
|||
from mindspore import ops
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.ops.operations._sequence_ops import TensorToScalar, TensorToTuple
|
||||
from mindspore.ops.auto_generate.gen_enum_def import OpDtype
|
||||
from mindspore._c_expression import OpDtype
|
||||
|
||||
tensor_to_tuple_ = TensorToTuple()
|
||||
|
||||
|
@ -67,54 +67,53 @@ def tuple_to_tensor(data):
|
|||
def list_to_tensor(data):
|
||||
return ops.tuple_to_array(list_to_tuple(data))
|
||||
|
||||
# There will be some problems in using OpDtype.xxx directly in GRAPH_MODE, so convert it to int.
|
||||
# type
|
||||
PY_DT_TYPE = OpDtype.PY_DT_TYPE.value
|
||||
|
||||
DT_TYPE_VAL = int(OpDtype.DT_TYPE)
|
||||
# scalar
|
||||
PY_DT_INT = OpDtype.PY_DT_INT.value
|
||||
PY_DT_FLOAT = OpDtype.PY_DT_FLOAT.value
|
||||
PY_DT_BOOL = OpDtype.PY_DT_BOOL.value
|
||||
PY_DT_NUMBER = OpDtype.PY_DT_NUMBER.value
|
||||
DT_INT_VAL = int(OpDtype.DT_INT)
|
||||
DT_FLOAT_VAL = int(OpDtype.DT_FLOAT)
|
||||
DT_BOOL_VAL = int(OpDtype.DT_BOOL)
|
||||
DT_NUMBER_VAL = int(OpDtype.DT_NUMBER)
|
||||
# tuple
|
||||
PY_DT_TUPLE_BOOL = OpDtype.PY_DT_TUPLE_BOOL.value
|
||||
PY_DT_TUPLE_INT = OpDtype.PY_DT_TUPLE_INT.value
|
||||
PY_DT_TUPLE_FLOAT = OpDtype.PY_DT_TUPLE_FLOAT.value
|
||||
PY_DT_TUPLE_NUMBER = OpDtype.PY_DT_TUPLE_NUMBER.value
|
||||
PY_DT_TUPLE_TENSOR = OpDtype.PY_DT_TUPLE_TENSOR.value
|
||||
PY_DT_TUPLE_STR = OpDtype.PY_DT_TUPLE_STR.value
|
||||
PY_DT_TUPLE_ANY = OpDtype.PY_DT_TUPLE_ANY.value
|
||||
DT_TUPLE_BOOL_VAL = int(OpDtype.DT_TUPLE_BOOL)
|
||||
DT_TUPLE_INT_VAL = int(OpDtype.DT_TUPLE_INT)
|
||||
DT_TUPLE_FLOAT_VAL = int(OpDtype.DT_TUPLE_FLOAT)
|
||||
DT_TUPLE_NUMBER_VAL = int(OpDtype.DT_TUPLE_NUMBER)
|
||||
DT_TUPLE_TENSOR_VAL = int(OpDtype.DT_TUPLE_TENSOR)
|
||||
DT_TUPLE_STR_VAL = int(OpDtype.DT_TUPLE_STR)
|
||||
DT_TUPLE_ANY_VAL = int(OpDtype.DT_TUPLE_ANY)
|
||||
# list
|
||||
PY_DT_LIST_BOOL = OpDtype.PY_DT_LIST_BOOL.value
|
||||
PY_DT_LIST_INT = OpDtype.PY_DT_LIST_INT.value
|
||||
PY_DT_LIST_FLOAT = OpDtype.PY_DT_LIST_FLOAT.value
|
||||
PY_DT_LIST_NUMBER = OpDtype.PY_DT_LIST_NUMBER.value
|
||||
PY_DT_LIST_TENSOR = OpDtype.PY_DT_LIST_TENSOR.value
|
||||
PY_DT_LIST_STR = OpDtype.PY_DT_LIST_STR.value
|
||||
PY_DT_LIST_ANY = OpDtype.PY_DT_LIST_ANY.value
|
||||
DT_LIST_BOOL_VAL = int(OpDtype.DT_LIST_BOOL)
|
||||
DT_LIST_INT_VAL = int(OpDtype.DT_LIST_INT)
|
||||
DT_LIST_FLOAT_VAL = int(OpDtype.DT_LIST_FLOAT)
|
||||
DT_LIST_NUMBER_VAL = int(OpDtype.DT_LIST_NUMBER)
|
||||
DT_LIST_TENSOR_VAL = int(OpDtype.DT_LIST_TENSOR)
|
||||
DT_LIST_STR_VAL = int(OpDtype.DT_LIST_STR)
|
||||
DT_LIST_ANY_VAL = int(OpDtype.DT_LIST_ANY)
|
||||
# tensor
|
||||
PY_DT_TENSOR = OpDtype.PY_DT_TENSOR.value
|
||||
|
||||
DT_TENSOR_VAL = int(OpDtype.DT_TENSOR)
|
||||
|
||||
dtype_to_string = {
|
||||
PY_DT_INT: "int",
|
||||
PY_DT_FLOAT: "float",
|
||||
PY_DT_BOOL: "bool",
|
||||
PY_DT_NUMBER: "number",
|
||||
PY_DT_TENSOR: "Tensor",
|
||||
PY_DT_TUPLE_BOOL: "tuple of bool",
|
||||
PY_DT_TUPLE_INT: "tuple of int",
|
||||
PY_DT_TUPLE_FLOAT: "tuple of float",
|
||||
PY_DT_TUPLE_NUMBER: "tuple of number",
|
||||
PY_DT_TUPLE_TENSOR: "tuple of tensor",
|
||||
PY_DT_TUPLE_STR: "tuple of string",
|
||||
PY_DT_TUPLE_ANY: "tuple of Any",
|
||||
PY_DT_LIST_BOOL: "list of bool",
|
||||
PY_DT_LIST_INT: "list of int",
|
||||
PY_DT_LIST_FLOAT: "list of float",
|
||||
PY_DT_LIST_NUMBER: "list of number",
|
||||
PY_DT_LIST_TENSOR: "list of Tensor",
|
||||
PY_DT_LIST_STR: "list of string",
|
||||
PY_DT_LIST_ANY: "list of Any"
|
||||
DT_INT_VAL: "int",
|
||||
DT_FLOAT_VAL: "float",
|
||||
DT_BOOL_VAL: "bool",
|
||||
DT_NUMBER_VAL: "number",
|
||||
DT_TENSOR_VAL: "Tensor",
|
||||
DT_TUPLE_BOOL_VAL: "tuple of bool",
|
||||
DT_TUPLE_INT_VAL: "tuple of int",
|
||||
DT_TUPLE_FLOAT_VAL: "tuple of float",
|
||||
DT_TUPLE_NUMBER_VAL: "tuple of number",
|
||||
DT_TUPLE_TENSOR_VAL: "tuple of tensor",
|
||||
DT_TUPLE_STR_VAL: "tuple of string",
|
||||
DT_TUPLE_ANY_VAL: "tuple of Any",
|
||||
DT_LIST_BOOL_VAL: "list of bool",
|
||||
DT_LIST_INT_VAL: "list of int",
|
||||
DT_LIST_FLOAT_VAL: "list of float",
|
||||
DT_LIST_NUMBER_VAL: "list of number",
|
||||
DT_LIST_TENSOR_VAL: "list of Tensor",
|
||||
DT_LIST_STR_VAL: "list of string",
|
||||
DT_LIST_ANY_VAL: "list of Any"
|
||||
}
|
||||
|
||||
|
||||
|
@ -122,34 +121,35 @@ def is_tuple(type_id):
|
|||
"""
|
||||
Check type id is tuple.
|
||||
"""
|
||||
return type_id in (PY_DT_TUPLE_BOOL, PY_DT_TUPLE_INT, PY_DT_TUPLE_FLOAT, PY_DT_TUPLE_NUMBER,
|
||||
PY_DT_TUPLE_TENSOR, PY_DT_TUPLE_STR, PY_DT_TUPLE_ANY)
|
||||
return type_id in (DT_TUPLE_BOOL_VAL, DT_TUPLE_INT_VAL, DT_TUPLE_FLOAT_VAL, DT_TUPLE_NUMBER_VAL,
|
||||
DT_TUPLE_TENSOR_VAL, DT_TUPLE_STR_VAL, DT_TUPLE_ANY_VAL)
|
||||
|
||||
|
||||
def is_list(type_id):
|
||||
"""
|
||||
Check type id is list.
|
||||
"""
|
||||
return type_id in (PY_DT_LIST_BOOL, PY_DT_LIST_INT, PY_DT_LIST_FLOAT, PY_DT_LIST_NUMBER, PY_DT_LIST_TENSOR,
|
||||
PY_DT_LIST_STR, PY_DT_LIST_ANY)
|
||||
return type_id in (DT_LIST_BOOL_VAL, DT_LIST_INT_VAL, DT_LIST_FLOAT_VAL, DT_LIST_NUMBER_VAL,
|
||||
DT_LIST_TENSOR_VAL,
|
||||
DT_LIST_STR_VAL, DT_LIST_ANY_VAL)
|
||||
|
||||
|
||||
def is_number(type_id):
|
||||
"""
|
||||
Check type id is number.
|
||||
"""
|
||||
return type_id in (PY_DT_INT, PY_DT_FLOAT, PY_DT_BOOL, PY_DT_NUMBER)
|
||||
return type_id in (DT_INT_VAL, DT_FLOAT_VAL, DT_BOOL_VAL, DT_NUMBER_VAL)
|
||||
|
||||
|
||||
def is_instance_of(data, type_id):
|
||||
"""
|
||||
Instead isinstance(obj, type).
|
||||
"""
|
||||
if type_id == PY_DT_INT:
|
||||
if type_id == DT_INT_VAL:
|
||||
return isinstance(data, int)
|
||||
if type_id == PY_DT_FLOAT:
|
||||
if type_id == DT_FLOAT_VAL:
|
||||
return isinstance(data, float)
|
||||
if type_id == PY_DT_BOOL:
|
||||
if type_id == DT_BOOL_VAL:
|
||||
return isinstance(data, bool)
|
||||
if is_number(type_id):
|
||||
return isinstance(data, (int, float, bool))
|
||||
|
@ -157,7 +157,7 @@ def is_instance_of(data, type_id):
|
|||
return isinstance(data, tuple)
|
||||
if is_list(type_id):
|
||||
return isinstance(data, list)
|
||||
if type_id == PY_DT_TENSOR:
|
||||
if type_id == DT_TENSOR_VAL:
|
||||
return isinstance(data, Tensor)
|
||||
return False
|
||||
|
||||
|
@ -192,7 +192,7 @@ def do_type_cast(data, dst_type):
|
|||
"""Type conversion."""
|
||||
if is_instance_of(data, dst_type):
|
||||
return data
|
||||
if dst_type == PY_DT_FLOAT:
|
||||
if dst_type == DT_FLOAT_VAL:
|
||||
if isinstance(data, int):
|
||||
return int_to_float(data)
|
||||
elif is_tuple(dst_type):
|
||||
|
@ -202,7 +202,7 @@ def do_type_cast(data, dst_type):
|
|||
return list_to_tuple(data)
|
||||
if isinstance(data, Tensor):
|
||||
return tensor_to_tuple(data)
|
||||
elif dst_type == PY_DT_TENSOR:
|
||||
elif dst_type == DT_TENSOR_VAL:
|
||||
if isinstance(data, (int, float, bool)):
|
||||
return scalar_to_tensor(data)
|
||||
if isinstance(data, tuple):
|
||||
|
@ -211,7 +211,7 @@ def do_type_cast(data, dst_type):
|
|||
return list_to_tensor(data)
|
||||
elif is_number(dst_type):
|
||||
if isinstance(data, Tensor):
|
||||
if dst_type == PY_DT_INT:
|
||||
if dst_type == DT_INT_VAL:
|
||||
data = ops.cast(data, ms.int64)
|
||||
ret = TensorToScalar()(data)
|
||||
return ret
|
||||
|
@ -222,6 +222,11 @@ def type_it(data, src_type, dst_type):
|
|||
"""
|
||||
cast operator argument data type.
|
||||
"""
|
||||
if not isinstance(src_type, tuple):
|
||||
src_type = int(src_type)
|
||||
else:
|
||||
src_type = tuple((int(t) for t in src_type))
|
||||
dst_type = int(dst_type)
|
||||
if not is_instance_in(data, src_type) and not is_instance_of(data, dst_type):
|
||||
support_list = get_support_dtype_list(src_type, dst_type)
|
||||
raise TypeError(f"For type conversion here, only support <{support_list}>, but got {type(data)}.")
|
||||
|
|
|
@ -14,10 +14,7 @@
|
|||
# ============================================================================
|
||||
"""Operator argument handle function."""
|
||||
|
||||
from mindspore.ops_generate.gen_ops_inner_prim import DtypeToEnum
|
||||
|
||||
ops_dtype_to_enum = DtypeToEnum()
|
||||
|
||||
from mindspore.ops_generate.gen_ops_inner_prim import DtypeToEnum, StringToEnum
|
||||
|
||||
def to_kernel_size(kernel_size):
|
||||
"""
|
||||
|
@ -131,9 +128,9 @@ def to_3d_paddings(pad):
|
|||
return pad
|
||||
raise ValueError(f"For arg 'pad', the value is invalid: {pad}.")
|
||||
|
||||
dtype_to_type_id = DtypeToEnum()
|
||||
|
||||
def dtype_to_enum(dtype):
|
||||
"""
|
||||
convert mindspore.dtype to enum.
|
||||
"""
|
||||
return ops_dtype_to_enum(dtype)
|
||||
# string to enum
|
||||
# A function for converting str type to enum type are written here,
|
||||
# but the backend supports str input, and converting str input to enum input is not necessary.
|
||||
str_to_enum = StringToEnum()
|
||||
|
|
|
@ -1,241 +0,0 @@
|
|||
|
||||
# #enum TypeId
|
||||
# type_id:
|
||||
# kTypeUnknown: 0
|
||||
# #Meta types.
|
||||
# kMetaTypeBegin: 0
|
||||
# kMetaTypeType: 1
|
||||
# kMetaTypeAny: 2
|
||||
# kMetaTypeObject: 3
|
||||
# kMetaTypeTypeType: 4
|
||||
# kMetaTypeProblem: 5
|
||||
# kMetaTypeExternal: 6
|
||||
# kMetaTypeNone: 7
|
||||
# kMetaTypeNull: 8
|
||||
# kMetaTypeEllipsis: 9
|
||||
# kMetaTypeEnd: 10
|
||||
# #
|
||||
# # Object types
|
||||
# #
|
||||
# kObjectTypeBegin: 10
|
||||
# kObjectTypeNumber: 11
|
||||
# kObjectTypeString: 12
|
||||
# kObjectTypeList: 13
|
||||
# kObjectTypeTuple: 14
|
||||
# kObjectTypeSlice: 15
|
||||
# kObjectTypeKeyword: 16
|
||||
# kObjectTypeTensorType: 17
|
||||
# kObjectTypeRowTensorType: 18
|
||||
# kObjectTypeCOOTensorType: 19
|
||||
# kObjectTypeUndeterminedType: 20
|
||||
# kObjectTypeClass: 21
|
||||
# kObjectTypeDictionary: 22
|
||||
# kObjectTypeFunction: 23
|
||||
# kObjectTypeJTagged: 24
|
||||
# kObjectTypeSymbolicKeyType: 25
|
||||
# kObjectTypeEnvType: 26
|
||||
# kObjectTypeRefKey: 27
|
||||
# kObjectTypeRef: 28
|
||||
# kObjectTypeEnd: 29
|
||||
# #
|
||||
# # Number Types
|
||||
# #
|
||||
# kNumberTypeBegin: 29
|
||||
# kNumberTypeBool: 30
|
||||
# kNumberTypeInt: 31
|
||||
# kNumberTypeInt8: 32
|
||||
# kNumberTypeInt16: 33
|
||||
# kNumberTypeInt32: 34
|
||||
# kNumberTypeInt64: 35
|
||||
# kNumberTypeUInt: 36
|
||||
# kNumberTypeUInt8: 37
|
||||
# kNumberTypeUInt16: 38
|
||||
# kNumberTypeUInt32: 39
|
||||
# kNumberTypeUInt64: 40
|
||||
# kNumberTypeFloat: 41
|
||||
# kNumberTypeFloat16: 42
|
||||
# kNumberTypeFloat32: 43
|
||||
# kNumberTypeFloat64: 44
|
||||
# kNumberTypeBFloat16: 45
|
||||
# kNumberTypeDouble: 46
|
||||
# kNumberTypeComplex: 47
|
||||
# kNumberTypeComplex64: 48
|
||||
# kNumberTypeComplex128: 49
|
||||
# kNumberTypeInt4: 50
|
||||
# kNumberTypeGLUInt: 51
|
||||
# kNumberTypeEnd: 52
|
||||
# #
|
||||
# # Monad Types
|
||||
# #
|
||||
# kMonadTypeBegin: 52
|
||||
# kObjectTypeMonad: 53
|
||||
# kObjectTypeUMonad: 54
|
||||
# kObjectTypeIOMonad: 55
|
||||
# kMonadTypeEnd: 56
|
||||
# #
|
||||
# # Sparse Types
|
||||
# #
|
||||
# kSparseTypeBegin: 56
|
||||
# kObjectTypeCSRTensorType: 57
|
||||
# kObjectTypeSparseTensorType: 58
|
||||
# kObjectTypeMapTensorType: 59
|
||||
# kSparseTypeEnd: 60
|
||||
|
||||
# #enum OpDtype
|
||||
op_dtype:
|
||||
PY_DT_BEGIN: 0
|
||||
PY_DT_BOOL: 1
|
||||
PY_DT_INT: 2
|
||||
PY_DT_FLOAT: 3
|
||||
PY_DT_NUMBER: 4
|
||||
PY_DT_TENSOR: 5
|
||||
PY_DT_STR: 6
|
||||
PY_DT_ANY: 7
|
||||
PY_DT_TUPLE_BOOL: 8
|
||||
PY_DT_TUPLE_INT: 9
|
||||
PY_DT_TUPLE_FLOAT: 10
|
||||
PY_DT_TUPLE_NUMBER: 11
|
||||
PY_DT_TUPLE_TENSOR: 12
|
||||
PY_DT_TUPLE_STR: 13
|
||||
PY_DT_TUPLE_ANY: 14
|
||||
PY_DT_LIST_BOOL: 15
|
||||
PY_DT_LIST_INT: 16
|
||||
PY_DT_LIST_FLOAT: 17
|
||||
PY_DT_LIST_NUMBER: 18
|
||||
PY_DT_LIST_TENSOR: 19
|
||||
PY_DT_LIST_STR: 20
|
||||
PY_DT_LIST_ANY: 21
|
||||
PY_DT_TYPE: 22
|
||||
PY_DT_END: 23
|
||||
|
||||
# enum Format
|
||||
# the value must be consistent with enum defined in `mindspore/core/mindapi/base/format.h`
|
||||
py_format:
|
||||
NCHW: 0
|
||||
NHWC: 1
|
||||
NHWC4: 2
|
||||
HWKC: 3
|
||||
HWCK: 4
|
||||
KCHW: 5
|
||||
CKHW: 6
|
||||
KHWC: 7
|
||||
CHWK: 8
|
||||
HW: 9
|
||||
HW4: 10
|
||||
NC: 11
|
||||
NC4: 12
|
||||
NC4HW4: 13
|
||||
NCDHW: 14
|
||||
NWC: 15
|
||||
NCW: 16
|
||||
NDHWC: 17
|
||||
NC8HW8: 18
|
||||
FRACTAL_NZ: 19
|
||||
ND: 20
|
||||
NC1HWC0: 21
|
||||
FRACTAL_Z: 22
|
||||
NC1C0HWPAD: 23
|
||||
NHWC1C0: 24
|
||||
FSR_NCHW: 25
|
||||
FRACTAL_DECONV: 26
|
||||
C1HWNC0: 27
|
||||
FRACTAL_DECONV_TRANSPOSE: 28
|
||||
FRACTAL_DECONV_SP_STRIDE_TRANS: 29
|
||||
NC1HWC0_C04: 30
|
||||
FRACTAL_Z_C04: 31
|
||||
CHWN: 32
|
||||
FRACTAL_DECONV_SP_STRIDE8_TRANS: 33
|
||||
HWCN: 34
|
||||
NC1KHKWHWC0: 35
|
||||
BN_WEIGHT: 36
|
||||
FILTER_HWCK: 37
|
||||
LOOKUP_LOOKUPS: 38
|
||||
LOOKUP_KEYS: 39
|
||||
LOOKUP_VALUE: 40
|
||||
LOOKUP_OUTPUT: 41
|
||||
LOOKUP_HITS: 42
|
||||
C1HWNCoC0: 43
|
||||
MD: 44
|
||||
FRACTAL_ZZ: 45
|
||||
DHWCN: 46
|
||||
NDC1HWC0: 47
|
||||
FRACTAL_Z_3D: 48
|
||||
CN: 49
|
||||
DHWNC: 50
|
||||
FRACTAL_Z_3D_TRANSPOSE: 51
|
||||
FRACTAL_ZN_LSTM: 52
|
||||
FRACTAL_Z_G: 53
|
||||
ND_RNN_BIAS: 54
|
||||
FRACTAL_ZN_RNN: 55
|
||||
NYUV: 56
|
||||
NYUV_A: 57
|
||||
NCL: 58
|
||||
|
||||
# enum PadMode
|
||||
pad_mode:
|
||||
PAD: 0
|
||||
SAME: 1
|
||||
VALID: 2
|
||||
|
||||
# enum Reduction
|
||||
reduction:
|
||||
NONE: 0
|
||||
MEAN: 1
|
||||
SUM: 2
|
||||
ADD: 3
|
||||
|
||||
# enum Direction
|
||||
direction:
|
||||
UNIDIRECTIONAL: 0
|
||||
|
||||
# enum Activation
|
||||
activation:
|
||||
TANH: 0
|
||||
|
||||
# enum GateOrder
|
||||
gate_order:
|
||||
RZH: 0
|
||||
ZRH: 1
|
||||
|
||||
# enum CellType
|
||||
cell_type:
|
||||
LSTM: 0
|
||||
|
||||
# enum Group
|
||||
group:
|
||||
SYNC_BN_GROUP0: 0
|
||||
|
||||
# enum InterpolationMode
|
||||
interpolation_mode:
|
||||
BILINEAR: 0
|
||||
NEAREST: 1
|
||||
|
||||
# enum NormMode
|
||||
norm_mode:
|
||||
BACKWARD: 0
|
||||
FORWARD: 1
|
||||
ORTHO: 2
|
||||
|
||||
# enum GridSamplerPaddingMode
|
||||
grid_sampler_padding_mode:
|
||||
ZEROS: 0
|
||||
BORDER: 1
|
||||
REFLECTION: 2
|
||||
|
||||
# enum PadFillingMode
|
||||
pad_filling_mode:
|
||||
CONSTANT: 0
|
||||
REFLECT: 1
|
||||
EDGE: 2
|
||||
CIRCULAR: 3
|
||||
SYMMETRIC: 4
|
||||
|
||||
# enum CoordinateTransformationMode
|
||||
coordinate_transformation_mode:
|
||||
ALIGN_CORNERS: 0
|
||||
HALF_PIXEL: 1
|
||||
|
||||
# enum KVCacheAlignMode
|
||||
k_v_cache_align_mode:
|
||||
RIGHT: 0
|
||||
LEFT: 1
|
|
@ -20,7 +20,8 @@ import re
|
|||
import shutil
|
||||
import pathlib
|
||||
import gen_utils
|
||||
from gen_utils import py_licence_str, cc_license_str, check_change_and_replace_file, merge_files, safe_load_yaml
|
||||
from gen_utils import (py_licence_str, cc_license_str, check_change_and_replace_file, merge_files,
|
||||
safe_load_yaml, convert_dtype_str)
|
||||
from pyboost_utils import get_pyboost_name, is_pyboost_enable, AclnnUtils, get_dtypes
|
||||
import template
|
||||
from template import CppTemplate
|
||||
|
@ -353,6 +354,12 @@ def {func_name}({', '.join(arg for arg in func_args)}):
|
|||
|
||||
return gen_py
|
||||
|
||||
def get_dtype(arg_info):
|
||||
dtype = arg_info.get('dtype')
|
||||
# Currently, TypeId is represented by int
|
||||
if dtype == 'TypeId':
|
||||
dtype = 'int'
|
||||
return dtype
|
||||
|
||||
def process_args(args):
|
||||
"""
|
||||
|
@ -365,7 +372,7 @@ def process_args(args):
|
|||
init_args_with_default = []
|
||||
args_handlers = {}
|
||||
for arg_name, arg_info in args.items():
|
||||
dtype = arg_info.get('dtype')
|
||||
dtype = get_dtype(arg_info)
|
||||
default_value = arg_info.get('default')
|
||||
has_default = 'default' in arg_info.keys()
|
||||
is_optional = arg_info.get('default') == "None" if has_default else False
|
||||
|
@ -621,7 +628,7 @@ namespace mindspore::ops {
|
|||
if not is_prim_init:
|
||||
continue
|
||||
|
||||
dtype = arg_info.get('dtype')
|
||||
dtype = get_dtype(arg_info)
|
||||
if dtype == "str":
|
||||
dtype = "std::string"
|
||||
if dtype == "tuple[int]":
|
||||
|
@ -671,8 +678,8 @@ std::unordered_map<std::string, OpDefPtr> gOpDefTable = {{"""
|
|||
for i, (arg_name, arg_info) in enumerate(args.items()):
|
||||
args_dict[arg_name] = i
|
||||
cc_index_str += f"""{{"{arg_name}", {i}}},\n"""
|
||||
dtype = arg_info.get('dtype')
|
||||
cc_dtype_str = 'DT_' + dtype.replace('[', '_').replace(']', '').upper()
|
||||
dtype = get_dtype(arg_info)
|
||||
cc_dtype_str = convert_dtype_str(dtype)
|
||||
|
||||
is_prim_init = 1 if arg_info.get('prim_init') else 0
|
||||
arg_handler = arg_info.get('arg_handler')
|
||||
|
@ -723,7 +730,7 @@ from mindspore.common._decorator import deprecated
|
|||
from mindspore.ops._primitive_cache import _get_cache_prim
|
||||
from mindspore.ops.auto_generate.gen_arg_dtype_cast import type_it
|
||||
from mindspore.ops.auto_generate.gen_arg_handler import *
|
||||
from mindspore.ops.auto_generate.gen_enum_def import OpDtype
|
||||
from mindspore._c_expression import OpDtype
|
||||
from mindspore.common._stub_tensor import _convert_stub
|
||||
"""
|
||||
|
||||
|
@ -891,67 +898,10 @@ def generate_aclnn_reg_file(work_path, yaml_str):
|
|||
check_change_and_replace_file(register_file, tmp_register_file)
|
||||
|
||||
|
||||
eum_py_header = f"""
|
||||
\"\"\"Operator argument enum definition.\"\"\"
|
||||
|
||||
from enum import Enum
|
||||
"""
|
||||
|
||||
eum_cc_header = f"""
|
||||
#ifndef MINDSPORE_CORE_OPS_GEN_ENUM_DEF_
|
||||
#define MINDSPORE_CORE_OPS_GEN_ENUM_DEF_
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
namespace mindspore::MsPyEnum {{
|
||||
"""
|
||||
|
||||
eum_cc_end = f"""}} // namespace mindspore::MsPyEnum
|
||||
#endif // MINDSPORE_CORE_OPS_GEN_ENUM_DEF_
|
||||
"""
|
||||
|
||||
|
||||
def generate_enum_code(yaml_data):
|
||||
def generate_arg_handler_files(work_path):
|
||||
"""
|
||||
Generate python and c++ enum definition
|
||||
Generate arg handler files.
|
||||
"""
|
||||
gen_eum_py_func = ''
|
||||
gen_eum_py_def = eum_py_header
|
||||
gen_eum_cc_def = eum_cc_header
|
||||
for enum_name, enum_data in yaml_data.items():
|
||||
class_name = ''.join(word.capitalize() for word in enum_name.split('_'))
|
||||
gen_eum_py_func += f"""\n
|
||||
def {enum_name}_to_enum({enum_name}_str):
|
||||
\"""
|
||||
convert {enum_name} string to enum.
|
||||
\"""
|
||||
if not isinstance({enum_name}_str, str):
|
||||
raise TypeError(f"The {enum_name} should be string, but got {{{enum_name}_str}}")
|
||||
{enum_name}_str = {enum_name}_str.upper()\n"""
|
||||
gen_eum_py_def += f"""\n\nclass {class_name}(Enum):\n"""
|
||||
gen_eum_cc_def += f"""enum {class_name} : int64_t {{\n"""
|
||||
|
||||
for enum_key, enum_value in enum_data.items():
|
||||
gen_eum_py_func += f""" if {enum_name}_str == "{enum_key}":
|
||||
return {enum_value}\n"""
|
||||
gen_eum_py_def += f""" {enum_key} = {enum_value}\n"""
|
||||
gen_eum_cc_def += f""" {enum_key} = {enum_value},\n"""
|
||||
|
||||
gen_eum_py_func += f""" raise ValueError(f"Invalid {class_name}: {{{enum_name}_str}}")\n"""
|
||||
gen_eum_cc_def += f"""}};\n\n"""
|
||||
gen_eum_cc_def += eum_cc_end
|
||||
|
||||
return gen_eum_py_func, gen_eum_py_def, gen_eum_cc_def
|
||||
|
||||
|
||||
def generate_enum_files(work_path):
|
||||
"""
|
||||
Generate python function and c++ definition from enum yaml.
|
||||
"""
|
||||
enum_yaml_path = os.path.join(work_path, 'mindspore/python/mindspore/ops_generate/enum.yaml')
|
||||
yaml_str = safe_load_yaml(enum_yaml_path)
|
||||
py_enum_func, py_enum_def, cc_enum_def = generate_enum_code(yaml_str)
|
||||
|
||||
dst_dir = os.path.join(work_path, 'mindspore/python/mindspore/ops/auto_generate')
|
||||
src_arg_handler_path = os.path.join(work_path, 'mindspore/python/mindspore/ops_generate/arg_handler.py')
|
||||
dst_arg_handler_path = os.path.join(dst_dir, 'gen_arg_handler.py')
|
||||
|
@ -959,8 +909,6 @@ def generate_enum_files(work_path):
|
|||
if not os.path.exists(dst_dir):
|
||||
os.makedirs(dst_dir)
|
||||
shutil.copy(src_arg_handler_path, tmp_dst_arg_handler_path)
|
||||
with open(tmp_dst_arg_handler_path, 'a') as py_file:
|
||||
py_file.write(py_enum_func)
|
||||
check_change_and_replace_file(dst_arg_handler_path, tmp_dst_arg_handler_path)
|
||||
|
||||
src_arg_dtype_cast_path = os.path.join(work_path, 'mindspore/python/mindspore/ops_generate/arg_dtype_cast.py')
|
||||
|
@ -969,18 +917,6 @@ def generate_enum_files(work_path):
|
|||
shutil.copy(src_arg_dtype_cast_path, tmp_arg_dtype_cast_path)
|
||||
check_change_and_replace_file(dst_arg_dtype_cast_path, tmp_arg_dtype_cast_path)
|
||||
|
||||
enum_def_py_path = os.path.join(work_path, 'mindspore/python/mindspore/ops/auto_generate/gen_enum_def.py')
|
||||
tmp_enum_def_py_path = os.path.join(work_path, 'mindspore/python/mindspore/ops/auto_generate/tmp_gen_enum_def.py')
|
||||
with open(tmp_enum_def_py_path, 'w') as cc_file:
|
||||
cc_file.write(py_licence_str + py_enum_def)
|
||||
check_change_and_replace_file(enum_def_py_path, tmp_enum_def_py_path)
|
||||
|
||||
enum_def_cc_path = os.path.join(work_path, 'mindspore/core/ops/auto_generate/gen_enum_def.h')
|
||||
tmp_enum_def_cc_path = os.path.join(work_path, 'mindspore/core/ops/auto_generate/tmp_gen_enum_def.h')
|
||||
with open(tmp_enum_def_cc_path, 'w') as cc_file:
|
||||
cc_file.write(cc_license_str + cc_enum_def)
|
||||
check_change_and_replace_file(enum_def_cc_path, tmp_enum_def_cc_path)
|
||||
|
||||
|
||||
def main():
|
||||
current_path = os.path.dirname(os.path.abspath(__file__))
|
||||
|
@ -1005,8 +941,8 @@ def main():
|
|||
cc_path = os.path.join(work_path, 'mindspore/core/ops/auto_generate/')
|
||||
pathlib.Path(cc_path).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# generate enum code from enum.yaml
|
||||
generate_enum_files(work_path)
|
||||
# generate arg_handler files
|
||||
generate_arg_handler_files(work_path)
|
||||
|
||||
# generate ops python files
|
||||
generate_ops_py_files(work_path, safe_load_yaml(ops_yaml_path), safe_load_yaml(doc_yaml_path), "gen")
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
from mindspore.ops.primitive import Primitive, prim_attr_register
|
||||
from mindspore._c_expression import typing
|
||||
from mindspore._c_expression import op_enum
|
||||
|
||||
|
||||
class DtypeToEnum(Primitive):
|
||||
|
@ -41,3 +42,28 @@ class DtypeToEnum(Primitive):
|
|||
if not isinstance(dtype, typing.Type):
|
||||
raise TypeError(f"For dtype_to_enum function, the input should be mindpsore dtype, but got {dtype}.")
|
||||
return typing.type_to_type_id(dtype)
|
||||
|
||||
|
||||
class StringToEnum(Primitive):
|
||||
r"""
|
||||
Convert string to enum.
|
||||
|
||||
Inputs:
|
||||
- **enum_str** (str) - The str data.
|
||||
|
||||
Outputs:
|
||||
An integer.
|
||||
|
||||
Supported Platforms:
|
||||
``CPU``
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
"""Initialize"""
|
||||
|
||||
def __call__(self, enum_str):
|
||||
"""Run in PyNative mode"""
|
||||
if not isinstance(enum_str, str):
|
||||
raise TypeError(f"For StringToEnum op, the input should be a str, but got {type(enum_str)}.")
|
||||
return op_enum.str_to_enum(enum_str)
|
||||
|
|
|
@ -53,30 +53,36 @@ cc_license_str = f"""/**
|
|||
* limitations under the License.
|
||||
*/"""
|
||||
|
||||
def convert_dtype_str(dtype_str):
|
||||
"""
|
||||
Convert dtype str to expression in ops file
|
||||
"""
|
||||
return 'DT_' + dtype_str.replace('[', '_').replace(']', '').upper()
|
||||
|
||||
|
||||
def get_type_str(type_str):
|
||||
"""
|
||||
Get the unified type str for operator arg dtype.
|
||||
"""
|
||||
# add more type here
|
||||
type_kind_dict = {
|
||||
'int': 'OpDtype.PY_DT_INT',
|
||||
'float': 'OpDtype.PY_DT_FLOAT',
|
||||
'bool': 'OpDtype.PY_DT_BOOL',
|
||||
'number': 'OpDtype.PY_DT_NUMBER',
|
||||
'tuple[int]': 'OpDtype.PY_DT_TUPLE_ANY',
|
||||
'tuple[float]': 'OpDtype.PY_DT_TUPLE_ANY',
|
||||
'tuple[bool]': 'OpDtype.PY_DT_TUPLE_ANY',
|
||||
'tuple[tensor]': 'OpDtype.PY_DT_TUPLE_ANY',
|
||||
'list[int]': 'OpDtype.PY_DT_LIST_ANY',
|
||||
'list[float]': 'OpDtype.PY_DT_LIST_ANY',
|
||||
'list[bool]': 'OpDtype.PY_DT_LIST_ANY',
|
||||
'list[tensor]': 'OpDtype.PY_DT_LIST_ANY',
|
||||
'tensor': 'OpDtype.PY_DT_TENSOR',
|
||||
'type': 'OpDtype.PY_DT_TYPE',
|
||||
type_kind_set = {
|
||||
'int',
|
||||
'float',
|
||||
'bool',
|
||||
'number',
|
||||
'tuple[int]',
|
||||
'tuple[float]',
|
||||
'tuple[bool]',
|
||||
'tuple[tensor]',
|
||||
'list[int]',
|
||||
'list[float]',
|
||||
'list[bool]',
|
||||
'list[tensor]',
|
||||
'tensor',
|
||||
'type',
|
||||
}
|
||||
if type_str in type_kind_dict:
|
||||
return type_kind_dict[type_str]
|
||||
if type_str in type_kind_set:
|
||||
return "OpDtype." + convert_dtype_str(type_str)
|
||||
raise TypeError(f"""Unsupported type {type_str} for args.""")
|
||||
|
||||
|
||||
|
@ -158,10 +164,10 @@ def get_assign_str_by_type_it(arg_info, arg_name, dtype):
|
|||
type_cast_tuple = tuple(ct.strip() for ct in type_cast.split(","))
|
||||
assign_str += f'type_it({arg_name}, '
|
||||
if len(type_cast_tuple) == 1:
|
||||
assign_str += get_type_str(type_cast_tuple[0]) + '.value, '
|
||||
assign_str += get_type_str(type_cast_tuple[0]) + ', '
|
||||
else:
|
||||
assign_str += '(' + ', '.join(get_type_str(ct) + '.value' for ct in type_cast_tuple) + '), '
|
||||
assign_str += get_type_str(dtype) + '.value)'
|
||||
assign_str += '(' + ', '.join(get_type_str(ct) for ct in type_cast_tuple) + '), '
|
||||
assign_str += get_type_str(dtype) + ')'
|
||||
else:
|
||||
assign_str = arg_name
|
||||
return assign_str
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
|
||||
#include "ops/test_value_utils.h"
|
||||
#include "ops/auto_generate/gen_ops_name.h"
|
||||
#include "ops/auto_generate/gen_enum_def.h"
|
||||
#include "ops/op_enum.h"
|
||||
#include "ops/test_ops.h"
|
||||
#include "ops/ops_func_impl/fft_with_size.h"
|
||||
|
||||
|
@ -42,7 +42,7 @@ TEST_P(TestFFT, dyn_shape) {
|
|||
auto signal_ndim = param.signal_ndim == -1 ? Any->ToAbstract() : CreatePyInt(param.signal_ndim)->ToAbstract();
|
||||
auto inverse = CreateScalar(param.inverse)->ToAbstract();
|
||||
auto real = CreateScalar(param.real)->ToAbstract();
|
||||
auto norm = CreateScalar(static_cast<int64_t>(MsPyEnum::NormMode::BACKWARD))->ToAbstract();
|
||||
auto norm = CreateScalar(static_cast<int64_t>(ops::NormMode::BACKWARD))->ToAbstract();
|
||||
auto onesided = CreateScalar(param.onesided)->ToAbstract();
|
||||
auto signal_sizes = param.signal_sizes->ToAbstract();
|
||||
|
||||
|
|
|
@ -29,6 +29,7 @@
|
|||
#include "ops/test_value_utils.h"
|
||||
#include "utils/ms_context.h"
|
||||
#include "utils/tensor_construct_utils.h"
|
||||
#include "mindapi/base/types.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
|
@ -99,31 +100,238 @@ TEST_P(TestNLLLoss, dyn_shape) {
|
|||
INSTANTIATE_TEST_CASE_P(
|
||||
TestNLLLossGroup, TestNLLLoss,
|
||||
testing::Values(
|
||||
NLLLossParams{{-1, 3}, kFloat32, {-1}, kInt32, {3}, kFloat32, 0, true, {{-1}, {}}, {kFloat32, kFloat32}},
|
||||
NLLLossParams{{-1, 3}, kFloat32, {-1}, kInt32, {3}, kFloat32, 1, true, {{}, {}}, {kFloat32, kFloat32}},
|
||||
NLLLossParams{{-1, 3}, kFloat32, {-1}, kInt32, {3}, kFloat32, 2, true, {{}, {}}, {kFloat32, kFloat32}},
|
||||
NLLLossParams{{-1, 3}, kFloat32, {-2}, kInt32, {-2}, kFloat32, 0, true, {{-1}, {}}, {kFloat32, kFloat32}},
|
||||
NLLLossParams{{-1, 3}, kFloat32, {-2}, kInt32, {-2}, kFloat32, 1, true, {{}, {}}, {kFloat32, kFloat32}},
|
||||
NLLLossParams{{-1, 3}, kFloat32, {-2}, kInt32, {-2}, kFloat32, 2, true, {{}, {}}, {kFloat32, kFloat32}},
|
||||
NLLLossParams{{-2}, kFloat32, {-2}, kInt32, {-2}, kFloat32, 0, true, {{-1}, {}}, {kFloat32, kFloat32}},
|
||||
NLLLossParams{{-2}, kFloat32, {-2}, kInt32, {-2}, kFloat32, 1, true, {{}, {}}, {kFloat32, kFloat32}},
|
||||
NLLLossParams{{-2}, kFloat32, {-2}, kInt32, {-2}, kFloat32, 2, true, {{}, {}}, {kFloat32, kFloat32}},
|
||||
NLLLossParams{{2, 3}, kFloat32, {2}, kInt32, {3}, kFloat32, 0, true, {{2}, {}}, {kFloat32, kFloat32}},
|
||||
NLLLossParams{{2, 3}, kFloat32, {2}, kInt32, {3}, kFloat32, 1, true, {{}, {}}, {kFloat32, kFloat32}},
|
||||
NLLLossParams{{2, 3}, kFloat32, {2}, kInt32, {3}, kFloat32, 2, true, {{}, {}}, {kFloat32, kFloat32}},
|
||||
NLLLossParams{{-1, 3},
|
||||
kFloat32,
|
||||
{-1},
|
||||
kInt32,
|
||||
{3},
|
||||
kFloat32,
|
||||
static_cast<int64_t>(Reduction::NONE),
|
||||
true,
|
||||
{{-1}, {}},
|
||||
{kFloat32, kFloat32}},
|
||||
NLLLossParams{{-1, 3},
|
||||
kFloat32,
|
||||
{-1},
|
||||
kInt32,
|
||||
{3},
|
||||
kFloat32,
|
||||
static_cast<int64_t>(Reduction::MEAN),
|
||||
true,
|
||||
{{}, {}},
|
||||
{kFloat32, kFloat32}},
|
||||
NLLLossParams{{-1, 3},
|
||||
kFloat32,
|
||||
{-1},
|
||||
kInt32,
|
||||
{3},
|
||||
kFloat32,
|
||||
static_cast<int64_t>(Reduction::REDUCTION_SUM),
|
||||
true,
|
||||
{{}, {}},
|
||||
{kFloat32, kFloat32}},
|
||||
NLLLossParams{{-1, 3},
|
||||
kFloat32,
|
||||
{-2},
|
||||
kInt32,
|
||||
{-2},
|
||||
kFloat32,
|
||||
static_cast<int64_t>(Reduction::NONE),
|
||||
true,
|
||||
{{-1}, {}},
|
||||
{kFloat32, kFloat32}},
|
||||
NLLLossParams{{-1, 3},
|
||||
kFloat32,
|
||||
{-2},
|
||||
kInt32,
|
||||
{-2},
|
||||
kFloat32,
|
||||
static_cast<int64_t>(Reduction::MEAN),
|
||||
true,
|
||||
{{}, {}},
|
||||
{kFloat32, kFloat32}},
|
||||
NLLLossParams{{-1, 3},
|
||||
kFloat32,
|
||||
{-2},
|
||||
kInt32,
|
||||
{-2},
|
||||
kFloat32,
|
||||
static_cast<int64_t>(Reduction::REDUCTION_SUM),
|
||||
true,
|
||||
{{}, {}},
|
||||
{kFloat32, kFloat32}},
|
||||
NLLLossParams{{-2},
|
||||
kFloat32,
|
||||
{-2},
|
||||
kInt32,
|
||||
{-2},
|
||||
kFloat32,
|
||||
static_cast<int64_t>(Reduction::NONE),
|
||||
true,
|
||||
{{-1}, {}},
|
||||
{kFloat32, kFloat32}},
|
||||
NLLLossParams{{-2},
|
||||
kFloat32,
|
||||
{-2},
|
||||
kInt32,
|
||||
{-2},
|
||||
kFloat32,
|
||||
static_cast<int64_t>(Reduction::MEAN),
|
||||
true,
|
||||
{{}, {}},
|
||||
{kFloat32, kFloat32}},
|
||||
NLLLossParams{{-2},
|
||||
kFloat32,
|
||||
{-2},
|
||||
kInt32,
|
||||
{-2},
|
||||
kFloat32,
|
||||
static_cast<int64_t>(Reduction::REDUCTION_SUM),
|
||||
true,
|
||||
{{}, {}},
|
||||
{kFloat32, kFloat32}},
|
||||
NLLLossParams{{2, 3},
|
||||
kFloat32,
|
||||
{2},
|
||||
kInt32,
|
||||
{3},
|
||||
kFloat32,
|
||||
static_cast<int64_t>(Reduction::NONE),
|
||||
true,
|
||||
{{2}, {}},
|
||||
{kFloat32, kFloat32}},
|
||||
NLLLossParams{{2, 3},
|
||||
kFloat32,
|
||||
{2},
|
||||
kInt32,
|
||||
{3},
|
||||
kFloat32,
|
||||
static_cast<int64_t>(Reduction::MEAN),
|
||||
true,
|
||||
{{}, {}},
|
||||
{kFloat32, kFloat32}},
|
||||
NLLLossParams{{2, 3},
|
||||
kFloat32,
|
||||
{2},
|
||||
kInt32,
|
||||
{3},
|
||||
kFloat32,
|
||||
static_cast<int64_t>(Reduction::REDUCTION_SUM),
|
||||
true,
|
||||
{{}, {}},
|
||||
{kFloat32, kFloat32}},
|
||||
NLLLossParams{{2, 3}, kFloat32, {2}, kInt32, {3}, kFloat32, -1, true, {{-1}, {}}, {kFloat32, kFloat32}},
|
||||
NLLLossParams{{-1, -1}, kFloat32, {-1}, kInt32, {-1}, kFloat32, 0, true, {{-1}, {}}, {kFloat32, kFloat32}},
|
||||
NLLLossParams{{-1, -1}, kFloat32, {-1}, kInt32, {-1}, kFloat32, 1, true, {{}, {}}, {kFloat32, kFloat32}},
|
||||
NLLLossParams{{-1, -1}, kFloat32, {-1}, kInt32, {-1}, kFloat32, 2, true, {{}, {}}, {kFloat32, kFloat32}},
|
||||
NLLLossParams{{-1, -1},
|
||||
kFloat32,
|
||||
{-1},
|
||||
kInt32,
|
||||
{-1},
|
||||
kFloat32,
|
||||
static_cast<int64_t>(Reduction::NONE),
|
||||
true,
|
||||
{{-1}, {}},
|
||||
{kFloat32, kFloat32}},
|
||||
NLLLossParams{{-1, -1},
|
||||
kFloat32,
|
||||
{-1},
|
||||
kInt32,
|
||||
{-1},
|
||||
kFloat32,
|
||||
static_cast<int64_t>(Reduction::MEAN),
|
||||
true,
|
||||
{{}, {}},
|
||||
{kFloat32, kFloat32}},
|
||||
NLLLossParams{{-1, -1},
|
||||
kFloat32,
|
||||
{-1},
|
||||
kInt32,
|
||||
{-1},
|
||||
kFloat32,
|
||||
static_cast<int64_t>(Reduction::REDUCTION_SUM),
|
||||
true,
|
||||
{{}, {}},
|
||||
{kFloat32, kFloat32}},
|
||||
NLLLossParams{{-1, -1}, kFloat32, {-1}, kInt32, {-1}, kFloat32, -1, true, {{-1}, {}}, {kFloat32, kFloat32}},
|
||||
NLLLossParams{{-2}, kFloat32, {-1}, kInt32, {-1}, kFloat32, 0, true, {{-1}, {}}, {kFloat32, kFloat32}},
|
||||
NLLLossParams{{-2}, kFloat32, {-1}, kInt32, {-1}, kFloat32, 1, true, {{}, {}}, {kFloat32, kFloat32}},
|
||||
NLLLossParams{{-2}, kFloat32, {-1}, kInt32, {-1}, kFloat32, 2, true, {{}, {}}, {kFloat32, kFloat32}},
|
||||
NLLLossParams{{-2},
|
||||
kFloat32,
|
||||
{-1},
|
||||
kInt32,
|
||||
{-1},
|
||||
kFloat32,
|
||||
static_cast<int64_t>(Reduction::NONE),
|
||||
true,
|
||||
{{-1}, {}},
|
||||
{kFloat32, kFloat32}},
|
||||
NLLLossParams{{-2},
|
||||
kFloat32,
|
||||
{-1},
|
||||
kInt32,
|
||||
{-1},
|
||||
kFloat32,
|
||||
static_cast<int64_t>(Reduction::MEAN),
|
||||
true,
|
||||
{{}, {}},
|
||||
{kFloat32, kFloat32}},
|
||||
NLLLossParams{{-2},
|
||||
kFloat32,
|
||||
{-1},
|
||||
kInt32,
|
||||
{-1},
|
||||
kFloat32,
|
||||
static_cast<int64_t>(Reduction::REDUCTION_SUM),
|
||||
true,
|
||||
{{}, {}},
|
||||
{kFloat32, kFloat32}},
|
||||
NLLLossParams{{-2}, kFloat32, {-1}, kInt32, {-1}, kFloat32, -1, true, {{-1}, {}}, {kFloat32, kFloat32}},
|
||||
NLLLossParams{{2, 3, 4}, kFloat32, {2}, kInt32, {3}, kFloat32, 0, false, {{1}, {}}, {kFloat32, kFloat32}},
|
||||
NLLLossParams{{2, 3}, kFloat32, {2, 3}, kInt32, {3}, kFloat32, 0, false, {{1}, {}}, {kFloat32, kFloat32}},
|
||||
NLLLossParams{{2, 3}, kFloat32, {2}, kInt32, {2, 3}, kFloat32, 0, false, {{1}, {}}, {kFloat32, kFloat32}},
|
||||
NLLLossParams{{2, 3}, kFloat32, {2}, kInt32, {2}, kFloat32, 0, false, {{1}, {}}, {kFloat32, kFloat32}},
|
||||
NLLLossParams{{2, 3}, kFloat32, {3}, kInt32, {2}, kFloat32, 0, false, {{1}, {}}, {kFloat32, kFloat32}}));
|
||||
NLLLossParams{{2, 3, 4},
|
||||
kFloat32,
|
||||
{2},
|
||||
kInt32,
|
||||
{3},
|
||||
kFloat32,
|
||||
static_cast<int64_t>(Reduction::NONE),
|
||||
false,
|
||||
{{1}, {}},
|
||||
{kFloat32, kFloat32}},
|
||||
NLLLossParams{{2, 3},
|
||||
kFloat32,
|
||||
{2, 3},
|
||||
kInt32,
|
||||
{3},
|
||||
kFloat32,
|
||||
static_cast<int64_t>(Reduction::NONE),
|
||||
false,
|
||||
{{1}, {}},
|
||||
{kFloat32, kFloat32}},
|
||||
NLLLossParams{{2, 3},
|
||||
kFloat32,
|
||||
{2},
|
||||
kInt32,
|
||||
{2, 3},
|
||||
kFloat32,
|
||||
static_cast<int64_t>(Reduction::NONE),
|
||||
false,
|
||||
{{1}, {}},
|
||||
{kFloat32, kFloat32}},
|
||||
NLLLossParams{{2, 3},
|
||||
kFloat32,
|
||||
{2},
|
||||
kInt32,
|
||||
{2},
|
||||
kFloat32,
|
||||
static_cast<int64_t>(Reduction::NONE),
|
||||
false,
|
||||
{{1}, {}},
|
||||
{kFloat32, kFloat32}},
|
||||
NLLLossParams{{2, 3},
|
||||
kFloat32,
|
||||
{3},
|
||||
kInt32,
|
||||
{2},
|
||||
kFloat32,
|
||||
static_cast<int64_t>(Reduction::NONE),
|
||||
false,
|
||||
{{1}, {}},
|
||||
{kFloat32, kFloat32}}));
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
|
Loading…
Reference in New Issue