Clear the warning information under the ops directory

This commit is contained in:
shen_jingxing 2021-09-28 09:40:39 +08:00
parent 70c8263a17
commit bda235ca77
17 changed files with 112 additions and 107 deletions

View File

@ -20,6 +20,7 @@
"mindspore/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_generate_strategy.cc" "containerOutOfBounds" "mindspore/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_generate_strategy.cc" "containerOutOfBounds"
"mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/embedding_look_up_cpu_kernel.cc" "containerOutOfBounds" "mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/embedding_look_up_cpu_kernel.cc" "containerOutOfBounds"
"mindspore/mindspore/core/ops/strided_slice.cc" "zerodivcond" "mindspore/mindspore/core/ops/strided_slice.cc" "zerodivcond"
"mindspore/mindspore/core/ops/avg_pool_3d.cc" "zerodivcond"
"mindspore/mindspore/ccsrc/runtime/hccl_adapter/hccl_adapter.cc" "useStlAlgorithm" "mindspore/mindspore/ccsrc/runtime/hccl_adapter/hccl_adapter.cc" "useStlAlgorithm"
"mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/cast_gpu_kernel.cc" "unknownMacro" "mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/cast_gpu_kernel.cc" "unknownMacro"

View File

@ -25,8 +25,7 @@
namespace mindspore { namespace mindspore {
namespace ops { namespace ops {
namespace { namespace {
abstract::TupleShapePtr LayerNormBetaGammaBackpropInferShape(const PrimitivePtr &primitive, abstract::TupleShapePtr LayerNormBetaGammaBackpropInferShape(const PrimitivePtr &primitive) {
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive); MS_EXCEPTION_IF_NULL(primitive);
ValuePtr gamma_value_ptr = primitive->GetAttr("shape_gamma"); ValuePtr gamma_value_ptr = primitive->GetAttr("shape_gamma");
MS_EXCEPTION_IF_NULL(gamma_value_ptr); MS_EXCEPTION_IF_NULL(gamma_value_ptr);
@ -56,7 +55,7 @@ AbstractBasePtr LayerNormBetaGammaBackpropInfer(const abstract::AnalysisEnginePt
const int64_t input_num = 4; const int64_t input_num = 4;
(void)CheckAndConvertUtils::CheckInteger("LayerNormBetaGammaBackprop infer", SizeToLong(input_args.size()), (void)CheckAndConvertUtils::CheckInteger("LayerNormBetaGammaBackprop infer", SizeToLong(input_args.size()),
kGreaterEqual, input_num, primitive->name()); kGreaterEqual, input_num, primitive->name());
return abstract::MakeAbstract(LayerNormBetaGammaBackpropInferShape(primitive, input_args), return abstract::MakeAbstract(LayerNormBetaGammaBackpropInferShape(primitive),
LayerNormBetaGammaBackpropInferType(primitive, input_args)); LayerNormBetaGammaBackpropInferType(primitive, input_args));
} }
REGISTER_PRIMITIVE_EVAL_IMPL(LayerNormBetaGammaBackprop, prim::kPrimLayerNormBetaGammaBackprop, REGISTER_PRIMITIVE_EVAL_IMPL(LayerNormBetaGammaBackprop, prim::kPrimLayerNormBetaGammaBackprop,

View File

@ -32,15 +32,13 @@ void ImpleAbs(void *origin, void *target, size_t size) {
MS_EXCEPTION_IF_NULL(target); MS_EXCEPTION_IF_NULL(target);
auto origin_data = reinterpret_cast<T *>(origin); auto origin_data = reinterpret_cast<T *>(origin);
auto target_data = reinterpret_cast<T *>(target); auto target_data = reinterpret_cast<T *>(target);
MS_EXCEPTION_IF_NULL(origin_data);
MS_EXCEPTION_IF_NULL(target_data);
auto zero_val = static_cast<T>(0); auto zero_val = static_cast<T>(0);
for (size_t i = 0; i < size; ++i) { for (size_t i = 0; i < size; ++i) {
target_data[i] = origin_data[i] >= zero_val ? origin_data[i] : -origin_data[i]; target_data[i] = origin_data[i] >= zero_val ? origin_data[i] : -origin_data[i];
} }
} }
abstract::ShapePtr AbsInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { abstract::ShapePtr AbsInferShape(const std::vector<AbstractBasePtr> &input_args) {
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape]; auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape];
return std::make_shared<abstract::Shape>(in_shape); return std::make_shared<abstract::Shape>(in_shape);
} }
@ -57,7 +55,7 @@ AbstractBasePtr AbsInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr
const int64_t input_num = 1; const int64_t input_num = 1;
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name()); CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
return abstract::MakeAbstract(AbsInferShape(primitive, input_args), AbsInferType(primitive, input_args)); return abstract::MakeAbstract(AbsInferShape(input_args), AbsInferType(primitive, input_args));
} }
ValuePtr AbsInferValue(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) { ValuePtr AbsInferValue(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
@ -77,53 +75,53 @@ ValuePtr AbsInferValue(const PrimitivePtr &prim, const std::vector<AbstractBaseP
auto data_size = x_tensor->DataSize(); auto data_size = x_tensor->DataSize();
auto dtype = x_tensor->data_type(); auto dtype = x_tensor->data_type();
auto shape = AbsInferShape(prim, input_args); auto shape = AbsInferShape(input_args);
auto result_tensor = std::make_shared<tensor::Tensor>(dtype, shape->shape()); auto result_tensor = std::make_shared<tensor::Tensor>(dtype, shape->shape());
auto x_datac = x_tensor->data_c(); auto x_datac = x_tensor->data_c();
auto result_datac = result_tensor->data_c(); auto result_datac = result_tensor->data_c();
switch (dtype) { switch (dtype) {
case kNumberTypeInt8: { case kNumberTypeInt8: {
ImpleAbs<int8_t>(x_datac, result_datac, data_size); ImpleAbs<int8_t>(x_datac, result_datac, IntToSize(data_size));
break; break;
} }
case kNumberTypeInt16: { case kNumberTypeInt16: {
ImpleAbs<int16_t>(x_datac, result_datac, data_size); ImpleAbs<int16_t>(x_datac, result_datac, IntToSize(data_size));
break; break;
} }
case kNumberTypeInt32: { case kNumberTypeInt32: {
ImpleAbs<int32_t>(x_datac, result_datac, data_size); ImpleAbs<int32_t>(x_datac, result_datac, IntToSize(data_size));
break; break;
} }
case kNumberTypeInt64: { case kNumberTypeInt64: {
ImpleAbs<int64_t>(x_datac, result_datac, data_size); ImpleAbs<int64_t>(x_datac, result_datac, IntToSize(data_size));
break; break;
} }
case kNumberTypeUInt8: { case kNumberTypeUInt8: {
ImpleAbs<uint8_t>(x_datac, result_datac, data_size); ImpleAbs<uint8_t>(x_datac, result_datac, IntToSize(data_size));
break; break;
} }
case kNumberTypeUInt16: { case kNumberTypeUInt16: {
ImpleAbs<uint16_t>(x_datac, result_datac, data_size); ImpleAbs<uint16_t>(x_datac, result_datac, IntToSize(data_size));
break; break;
} }
case kNumberTypeUInt32: { case kNumberTypeUInt32: {
ImpleAbs<uint32_t>(x_datac, result_datac, data_size); ImpleAbs<uint32_t>(x_datac, result_datac, IntToSize(data_size));
break; break;
} }
case kNumberTypeUInt64: { case kNumberTypeUInt64: {
ImpleAbs<uint64_t>(x_datac, result_datac, data_size); ImpleAbs<uint64_t>(x_datac, result_datac, IntToSize(data_size));
break; break;
} }
case kNumberTypeFloat16: { case kNumberTypeFloat16: {
ImpleAbs<float16>(x_datac, result_datac, data_size); ImpleAbs<float16>(x_datac, result_datac, IntToSize(data_size));
break; break;
} }
case kNumberTypeFloat32: { case kNumberTypeFloat32: {
ImpleAbs<float>(x_datac, result_datac, data_size); ImpleAbs<float>(x_datac, result_datac, IntToSize(data_size));
break; break;
} }
case kNumberTypeFloat64: { case kNumberTypeFloat64: {
ImpleAbs<double>(x_datac, result_datac, data_size); ImpleAbs<double>(x_datac, result_datac, IntToSize(data_size));
break; break;
} }
default: { default: {

View File

@ -44,10 +44,11 @@ abstract::TupleShapePtr InferShape(const PrimitivePtr &primitive, const std::vec
auto l2_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex6]->BuildShape())[kShape]; auto l2_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex6]->BuildShape())[kShape];
auto global_step_shape = auto global_step_shape =
CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex7]->BuildShape())[kShape]; CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex7]->BuildShape())[kShape];
(void)CheckAndConvertUtils::CheckInteger("lr_shape size", lr_shape.size(), kEqual, 0, primitive->name()); const int64_t input_num_ = 0;
(void)CheckAndConvertUtils::CheckInteger("l1_shape size", l1_shape.size(), kEqual, 0, primitive->name()); (void)CheckAndConvertUtils::CheckInteger("lr_shape size", lr_shape.size(), kEqual, input_num_, primitive->name());
(void)CheckAndConvertUtils::CheckInteger("l2_shape size", l2_shape.size(), kEqual, 0, primitive->name()); (void)CheckAndConvertUtils::CheckInteger("l1_shape size", l1_shape.size(), kEqual, input_num_, primitive->name());
(void)CheckAndConvertUtils::CheckInteger("global_step_shape size", global_step_shape.size(), kEqual, 0, (void)CheckAndConvertUtils::CheckInteger("l2_shape size", l2_shape.size(), kEqual, input_num_, primitive->name());
(void)CheckAndConvertUtils::CheckInteger("global_step_shape size", global_step_shape.size(), kEqual, input_num_,
primitive->name()); primitive->name());
return std::make_shared<abstract::TupleShape>( return std::make_shared<abstract::TupleShape>(
std::vector<abstract::BaseShapePtr>{var_shape, gradient_accumulator_shape, gradient_squared_accumulator_shape}); std::vector<abstract::BaseShapePtr>{var_shape, gradient_accumulator_shape, gradient_squared_accumulator_shape});
@ -72,25 +73,25 @@ TuplePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr>
const std::set<TypePtr> valid_types = {kFloat16, kFloat32}; const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
// gradient_accumulator、gradient_squared_accumulator、grad must have the same type as var // gradient_accumulator、gradient_squared_accumulator、grad must have the same type as var
std::map<std::string, TypePtr> args; std::map<std::string, TypePtr> args;
args.insert({"var_type", var_type}); (void)args.insert({"var_type", var_type});
args.insert({"gradient_accumulator_type", gradient_accumulator_type}); (void)args.insert({"gradient_accumulator_type", gradient_accumulator_type});
args.insert({"gradient_squared_accumulator_type", gradient_squared_accumulator_type}); (void)args.insert({"gradient_squared_accumulator_type", gradient_squared_accumulator_type});
args.insert({"grad_type", grad_type}); (void)args.insert({"grad_type", grad_type});
CheckAndConvertUtils::CheckTensorTypeSame(args, valid_types, prim_name); (void)CheckAndConvertUtils::CheckTensorTypeSame(args, valid_types, prim_name);
// lr、l1、l2、global_step_type must be a scalar type // lr、l1、l2、global_step_type must be a scalar type
std::map<std::string, TypePtr> args_lr; std::map<std::string, TypePtr> args_lr;
std::map<std::string, TypePtr> args_l1; std::map<std::string, TypePtr> args_l1;
std::map<std::string, TypePtr> args_l2; std::map<std::string, TypePtr> args_l2;
std::map<std::string, TypePtr> args_global_step; std::map<std::string, TypePtr> args_global_step;
args_lr.insert({"lr_type", lr_type}); (void)args_lr.insert({"lr_type", lr_type});
CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args_lr, valid_types, prim_name); (void)CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args_lr, valid_types, prim_name);
args_l1.insert({"l1_type", l1_type}); (void)args_l1.insert({"l1_type", l1_type});
CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args_l1, valid_types, prim_name); (void)CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args_l1, valid_types, prim_name);
args_l2.insert({"l2_type", l2_type}); (void)args_l2.insert({"l2_type", l2_type});
CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args_l2, valid_types, prim_name); (void)CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args_l2, valid_types, prim_name);
args_global_step.insert({"global_step_type", global_step_type}); (void)args_global_step.insert({"global_step_type", global_step_type});
const std::set<TypePtr> valid_types1 = {kInt32, kInt64}; const std::set<TypePtr> valid_types1 = {kInt32, kInt64};
CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args_global_step, valid_types1, prim_name); (void)CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args_global_step, valid_types1, prim_name);
return std::make_shared<Tuple>( return std::make_shared<Tuple>(
std::vector<TypePtr>{var_type, gradient_accumulator_type, gradient_squared_accumulator_type}); std::vector<TypePtr>{var_type, gradient_accumulator_type, gradient_squared_accumulator_type});
} }

View File

@ -32,8 +32,7 @@ constexpr size_t kStridesDims = 3;
constexpr size_t kPadDims = 6; constexpr size_t kPadDims = 6;
void GetAttrs(const PrimitivePtr &primitive, std::vector<int64_t> *kernel_size, std::vector<int64_t> *strides, void GetAttrs(const PrimitivePtr &primitive, std::vector<int64_t> *kernel_size, std::vector<int64_t> *strides,
int64_t *pad_mode, std::vector<int64_t> *pad_list, bool *ceil_mode, bool *count_include_pad, int64_t *pad_mode, std::vector<int64_t> *pad_list, bool *ceil_mode, bool *count_include_pad) {
int64_t *divisor_override) {
MS_EXCEPTION_IF_NULL(primitive); MS_EXCEPTION_IF_NULL(primitive);
// attr kernel size // attr kernel size
*kernel_size = GetValue<std::vector<int64_t>>(primitive->GetAttr(kKernelSize)); *kernel_size = GetValue<std::vector<int64_t>>(primitive->GetAttr(kKernelSize));
@ -56,8 +55,6 @@ void GetAttrs(const PrimitivePtr &primitive, std::vector<int64_t> *kernel_size,
CheckAndConvertUtils::GetPadModEnumValue(primitive->GetAttr(kPadMode), pad_mode, true); CheckAndConvertUtils::GetPadModEnumValue(primitive->GetAttr(kPadMode), pad_mode, true);
// attr ceil mode // attr ceil mode
*ceil_mode = GetValue<bool>(primitive->GetAttr(kCeilMode)); *ceil_mode = GetValue<bool>(primitive->GetAttr(kCeilMode));
// attr divisor override
*divisor_override = GetValue<int64_t>(primitive->GetAttr(kDivisorOverride));
} }
std::vector<int64_t> GetOutputShape(const std::vector<int64_t> &in_shape, int64_t kernel_d, int64_t kernel_h, std::vector<int64_t> GetOutputShape(const std::vector<int64_t> &in_shape, int64_t kernel_d, int64_t kernel_h,
@ -70,9 +67,12 @@ std::vector<int64_t> GetOutputShape(const std::vector<int64_t> &in_shape, int64_
int64_t out_h = 0; int64_t out_h = 0;
int64_t out_w = 0; int64_t out_w = 0;
if (ceil_mode) { if (ceil_mode) {
out_d = std::floor((in_d + pad_list[0] + pad_list[1] - kernel_d + stride_d - 1) / stride_d + 1); out_d =
out_h = std::floor((in_h + pad_list[2] + pad_list[3] - kernel_h + stride_h - 1) / stride_h + 1); static_cast<int64_t>(std::floor((in_d + pad_list[0] + pad_list[1] - kernel_d + stride_d - 1) / stride_d + 1));
out_w = std::floor((in_w + pad_list[4] + pad_list[5] - kernel_w + stride_w - 1) / stride_w + 1); out_h =
static_cast<int64_t>(std::floor((in_h + pad_list[2] + pad_list[3] - kernel_h + stride_h - 1) / stride_h + 1));
out_w =
static_cast<int64_t>(std::floor((in_w + pad_list[4] + pad_list[5] - kernel_w + stride_w - 1) / stride_w + 1));
if ((out_d - 1) * stride_d >= in_d + pad_list[0]) { if ((out_d - 1) * stride_d >= in_d + pad_list[0]) {
out_d--; out_d--;
} }
@ -83,9 +83,9 @@ std::vector<int64_t> GetOutputShape(const std::vector<int64_t> &in_shape, int64_
out_w--; out_w--;
} }
} else { } else {
out_d = std::floor((in_d + pad_list[0] + pad_list[1] - kernel_d) / stride_d + 1); out_d = static_cast<int64_t>(std::floor((in_d + pad_list[0] + pad_list[1] - kernel_d) / stride_d + 1));
out_h = std::floor((in_h + pad_list[2] + pad_list[3] - kernel_h) / stride_h + 1); out_h = static_cast<int64_t>(std::floor((in_h + pad_list[2] + pad_list[3] - kernel_h) / stride_h + 1));
out_w = std::floor((in_w + pad_list[4] + pad_list[5] - kernel_w) / stride_w + 1); out_w = static_cast<int64_t>(std::floor((in_w + pad_list[4] + pad_list[5] - kernel_w) / stride_w + 1));
} }
std::vector<int64_t> output_shape = {in_shape[0], in_shape[1], out_d, out_h, out_w}; std::vector<int64_t> output_shape = {in_shape[0], in_shape[1], out_d, out_h, out_w};
return output_shape; return output_shape;
@ -97,6 +97,9 @@ void GetPadsByPadding(int64_t in_d, int64_t in_h, int64_t in_w, int64_t kernel_d
if (pad_mode == PadMode::VALID) { if (pad_mode == PadMode::VALID) {
(void)pad_list->insert(pad_list->begin(), kPadDims, 0); (void)pad_list->insert(pad_list->begin(), kPadDims, 0);
} else if (pad_mode == PadMode::SAME) { } else if (pad_mode == PadMode::SAME) {
if (stride_d == 0 || stride_h == 0 || stride_w == 0) {
MS_LOG(EXCEPTION) << "stride_d or stride_h or stride_w must be non-zero";
}
int64_t tail_d = in_d % stride_d; int64_t tail_d = in_d % stride_d;
int64_t tail_h = in_h % stride_h; int64_t tail_h = in_h % stride_h;
int64_t tail_w = in_w % stride_w; int64_t tail_w = in_w % stride_w;
@ -130,8 +133,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
int64_t pad_mode = 0; int64_t pad_mode = 0;
bool ceil_mode = false; bool ceil_mode = false;
bool count_include_pad = true; bool count_include_pad = true;
int64_t divisor_override = 0; GetAttrs(primitive, &kernel_size, &strides, &pad_mode, &pad_list, &ceil_mode, &count_include_pad);
GetAttrs(primitive, &kernel_size, &strides, &pad_mode, &pad_list, &ceil_mode, &count_include_pad, &divisor_override);
auto in_d = in_shape[2]; auto in_d = in_shape[2];
auto in_h = in_shape[3]; auto in_h = in_shape[3];
auto in_w = in_shape[4]; auto in_w = in_shape[4];

View File

@ -22,7 +22,7 @@
namespace mindspore { namespace mindspore {
namespace ops { namespace ops {
namespace { namespace {
abstract::ShapePtr CosInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { abstract::ShapePtr CosInferShape(const std::vector<AbstractBasePtr> &input_args) {
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
return std::make_shared<abstract::Shape>(in_shape); return std::make_shared<abstract::Shape>(in_shape);
} }
@ -39,7 +39,7 @@ AbstractBasePtr CosInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr
const int64_t input_num = 1; const int64_t input_num = 1;
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name()); CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
return abstract::MakeAbstract(CosInferShape(primitive, input_args), CosInferType(primitive, input_args)); return abstract::MakeAbstract(CosInferShape(input_args), CosInferType(primitive, input_args));
} }
} // namespace } // namespace
REGISTER_PRIMITIVE_EVAL_IMPL(Cos, prim::kPrimCos, CosInfer, nullptr, true); REGISTER_PRIMITIVE_EVAL_IMPL(Cos, prim::kPrimCos, CosInfer, nullptr, true);

View File

@ -39,7 +39,7 @@ abstract::ShapePtr DiagPartInferShape(const PrimitivePtr &primitive, const std::
for (size_t i = 0; i < length; i++) { for (size_t i = 0; i < length; i++) {
CheckAndConvertUtils::Check("input_shape[i + rank(input_shape) / 2]", input_shape[i + length], kEqual, CheckAndConvertUtils::Check("input_shape[i + rank(input_shape) / 2]", input_shape[i + length], kEqual,
"input_shape[i]", input_shape[i], op_name, ValueError); "input_shape[i]", input_shape[i], op_name, ValueError);
out_shape.emplace_back(input_shape[i]); (void)out_shape.emplace_back(input_shape[i]);
} }
return std::make_shared<abstract::Shape>(out_shape); return std::make_shared<abstract::Shape>(out_shape);
} }

View File

@ -27,7 +27,8 @@ namespace {
abstract::ShapePtr ErfinvInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { abstract::ShapePtr ErfinvInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive); MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name(); auto prim_name = primitive->name();
CheckAndConvertUtils::CheckInteger("input_x numbers", input_args.size(), kEqual, 1, prim_name); const int64_t input_num = 1;
(void)CheckAndConvertUtils::CheckInteger("input_x numbers", input_args.size(), kEqual, input_num, prim_name);
for (const auto &item : input_args) { for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item); MS_EXCEPTION_IF_NULL(item);
} }
@ -39,13 +40,14 @@ abstract::ShapePtr ErfinvInferShape(const PrimitivePtr &primitive, const std::ve
TypePtr ErfinvInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) { TypePtr ErfinvInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(prim); MS_EXCEPTION_IF_NULL(prim);
auto op_name = prim->name(); auto op_name = prim->name();
CheckAndConvertUtils::CheckInteger("input_x number", input_args.size(), kEqual, 1, op_name); const int64_t input_num = 1;
(void)CheckAndConvertUtils::CheckInteger("input_x number", input_args.size(), kEqual, input_num, op_name);
for (const auto &item : input_args) { for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item); MS_EXCEPTION_IF_NULL(item);
} }
const std::set<TypePtr> valid_types = {kFloat16, kFloat32}; const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
auto infer_type = input_args[0]->BuildType(); auto infer_type = input_args[0]->BuildType();
CheckAndConvertUtils::CheckTensorTypeValid("input_x", infer_type, valid_types, prim->name()); (void)CheckAndConvertUtils::CheckTensorTypeValid("input_x", infer_type, valid_types, prim->name());
return infer_type; return infer_type;
} }
} // namespace } // namespace

View File

@ -39,16 +39,16 @@ abstract::ShapePtr IndexAddInferShape(const PrimitivePtr &primitive, const std::
CheckAndConvertUtils::CheckInRange("axis", axis, kIncludeNeither, {-x_rank - 1, x_rank}, prim_name); CheckAndConvertUtils::CheckInRange("axis", axis, kIncludeNeither, {-x_rank - 1, x_rank}, prim_name);
auto idx_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape]; auto idx_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
auto idx_rank = SizeToLong(idx_shape.size()); auto idx_rank = SizeToLong(idx_shape.size());
CheckAndConvertUtils::CheckInteger("idx size", idx_rank, kEqual, 1, prim_name); (void)CheckAndConvertUtils::CheckInteger("idx size", idx_rank, kEqual, 1, prim_name);
auto axis_rank = axis; auto axis_rank = axis;
if (axis < 0) { if (axis < 0) {
axis_rank = axis + x_rank; axis_rank = axis + x_rank;
} }
CheckAndConvertUtils::Check("size of indices", idx_shape[0], kEqual, "dimension of y[axis]", y_shape[axis_rank], (void)CheckAndConvertUtils::Check("size of indices", idx_shape[0], kEqual, "dimension of y[axis]", y_shape[axis_rank],
prim_name); prim_name);
for (int dim = 0; dim < x_rank; dim = dim + 1) { for (int dim = 0; dim < x_rank; dim = dim + 1) {
if (dim != axis_rank) { if (dim != axis_rank) {
CheckAndConvertUtils::Check("x dim", x_shape[dim], kEqual, "y dim", y_shape[dim], prim_name); (void)CheckAndConvertUtils::Check("x dim", x_shape[dim], kEqual, "y dim", y_shape[dim], prim_name);
} }
} }
return std::make_shared<abstract::Shape>(x_shape); return std::make_shared<abstract::Shape>(x_shape);
@ -66,8 +66,8 @@ TypePtr IndexAddInferType(const PrimitivePtr &prim, const std::vector<AbstractBa
auto var_type = input_args[kInputIndex0]->BuildType(); auto var_type = input_args[kInputIndex0]->BuildType();
auto indices_type = input_args[kInputIndex1]->BuildType(); auto indices_type = input_args[kInputIndex1]->BuildType();
auto updates_type = input_args[kInputIndex2]->BuildType(); auto updates_type = input_args[kInputIndex2]->BuildType();
CheckAndConvertUtils::CheckTensorTypeValid("indices type", indices_type, indices_types, prim->name()); (void)CheckAndConvertUtils::CheckTensorTypeValid("indices type", indices_type, indices_types, prim->name());
CheckAndConvertUtils::CheckTensorTypeValid("input_y type", updates_type, valid_types, prim->name()); (void)CheckAndConvertUtils::CheckTensorTypeValid("input_y type", updates_type, valid_types, prim->name());
return CheckAndConvertUtils::CheckTensorTypeValid("input_x type", var_type, valid_types, prim->name()); return CheckAndConvertUtils::CheckTensorTypeValid("input_x type", var_type, valid_types, prim->name());
} }
} // namespace } // namespace

View File

@ -41,8 +41,8 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
auto weight_shape = weight_shape_map[kShape]; auto weight_shape = weight_shape_map[kShape];
auto broadcast_shape = CalBroadCastShape(start_shape, end_shape, op_name, "start", "end"); auto broadcast_shape = CalBroadCastShape(start_shape, end_shape, op_name, "start", "end");
if (input_args[kInputIndex2]->isa<abstract::AbstractTensor>()) { if (input_args[kInputIndex2]->isa<abstract::AbstractTensor>()) {
CalBroadCastShape(start_shape, weight_shape, op_name, "start", "weight"); (void)CalBroadCastShape(start_shape, weight_shape, op_name, "start", "weight");
CalBroadCastShape(end_shape, weight_shape, op_name, "end", "weight"); (void)CalBroadCastShape(end_shape, weight_shape, op_name, "end", "weight");
broadcast_shape = CalBroadCastShape(broadcast_shape, weight_shape, op_name); broadcast_shape = CalBroadCastShape(broadcast_shape, weight_shape, op_name);
} }
return std::make_shared<abstract::Shape>(broadcast_shape); return std::make_shared<abstract::Shape>(broadcast_shape);
@ -56,8 +56,8 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
const int64_t input_num = 3; const int64_t input_num = 3;
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num, op_name); (void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num, op_name);
std::map<std::string, TypePtr> types; std::map<std::string, TypePtr> types;
types.emplace("start", input_args[0]->BuildType()); (void)types.emplace("start", input_args[0]->BuildType());
types.emplace("end", input_args[1]->BuildType()); (void)types.emplace("end", input_args[1]->BuildType());
if (input_args[kInputIndex2]->isa<abstract::AbstractTensor>()) { if (input_args[kInputIndex2]->isa<abstract::AbstractTensor>()) {
(void)types.emplace("weight", input_args[kInputIndex2]->BuildType()); (void)types.emplace("weight", input_args[kInputIndex2]->BuildType());
} else { } else {

View File

@ -55,7 +55,7 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
auto op_name = prim->name(); auto op_name = prim->name();
const int64_t input_num = 3; const int64_t input_num = 3;
(void)CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kEqual, input_num, op_name); (void)CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kEqual, input_num, op_name);
CheckAndConvertUtils::CheckTensorTypeValid("mask", input_args[1]->BuildType(), {kBool}, op_name); (void)CheckAndConvertUtils::CheckTensorTypeValid("mask", input_args[1]->BuildType(), {kBool}, op_name);
if (input_args[kInputIndex2]->isa<abstract::AbstractTensor>()) { if (input_args[kInputIndex2]->isa<abstract::AbstractTensor>()) {
std::map<std::string, TypePtr> types; std::map<std::string, TypePtr> types;
(void)types.emplace("input", input_args[kInputIndex0]->BuildType()); (void)types.emplace("input", input_args[kInputIndex0]->BuildType());

View File

@ -99,8 +99,9 @@ void Check(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &in
} }
// check empty input // check empty input
auto send_rank_ids = GetValue<std::vector<int64_t>>(primitive->GetAttr(kSendRankIds)); auto send_rank_ids = GetValue<std::vector<int64_t>>(primitive->GetAttr(kSendRankIds));
const int64_t input_num = 0;
if (send_rank_ids.empty()) { if (send_rank_ids.empty()) {
(void)CheckAndConvertUtils::CheckInteger("input_numbers", input_args.size(), kEqual, 0, prim_name); (void)CheckAndConvertUtils::CheckInteger("input_numbers", input_args.size(), kEqual, input_num, prim_name);
return; return;
} }
// check input shape & attr send shape // check input shape & attr send shape

View File

@ -31,7 +31,7 @@ abstract::ShapePtr OnesInferShape(const PrimitivePtr &primitive, const std::vect
// check // check
auto shape_value = input_args[0]->BuildValue(); auto shape_value = input_args[0]->BuildValue();
std::vector<int64_t> out_shape = CheckAndConvertUtils::CheckAttrIntOrTupleInt("shape", shape_value, prim_name); std::vector<int64_t> out_shape = CheckAndConvertUtils::CheckAttrIntOrTupleInt("shape", shape_value, prim_name);
CheckAndConvertUtils::CheckPositiveVector("shape", out_shape, prim_name); (void)CheckAndConvertUtils::CheckPositiveVector("shape", out_shape, prim_name);
return std::make_shared<abstract::Shape>(out_shape); return std::make_shared<abstract::Shape>(out_shape);
} }

View File

@ -28,7 +28,8 @@ namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive); MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name(); auto prim_name = primitive->name();
CheckAndConvertUtils::CheckInteger("input numbers", input_args.size(), kEqual, 1, prim_name); const int64_t input_num = 1;
(void)CheckAndConvertUtils::CheckInteger("input numbers", input_args.size(), kEqual, input_num, prim_name);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto axis = GetValue<int64_t>(primitive->GetAttr(kAxis)); auto axis = GetValue<int64_t>(primitive->GetAttr(kAxis));
auto x_rank = SizeToLong(x_shape.size()); auto x_rank = SizeToLong(x_shape.size());

View File

@ -43,19 +43,21 @@ abstract::TupleShapePtr InferShape(const PrimitivePtr &primitive, const std::vec
auto grad_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[4]->BuildShape())[kShape]; auto grad_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[4]->BuildShape())[kShape];
auto indices_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[5]->BuildShape())[kShape]; auto indices_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[5]->BuildShape())[kShape];
// Args lr must be scalar // Args lr must be scalar
(void)CheckAndConvertUtils::CheckInteger("size of lr_shape", lr_shape.size(), kEqual, 0, primitive->name()); const int64_t input_num = 0;
(void)CheckAndConvertUtils::CheckInteger("size of lr_shape", lr_shape.size(), kEqual, input_num, primitive->name());
// Shape of var、ms、mom、grad must be same // Shape of var、ms、mom、grad must be same
std::map<std::string, ShapeVector> same_shape_args_map; std::map<std::string, ShapeVector> same_shape_args_map;
same_shape_args_map.insert({"shape of ms ", ms_shape}); (void)same_shape_args_map.insert({"shape of ms ", ms_shape});
same_shape_args_map.insert({"shape of mom ", mom_shape}); (void)same_shape_args_map.insert({"shape of mom ", mom_shape});
same_shape_args_map.insert({"shape of grad ", grad_shape}); (void)same_shape_args_map.insert({"shape of grad ", grad_shape});
for (auto &elem : same_shape_args_map) { for (auto &elem : same_shape_args_map) {
CheckAndConvertUtils::Check(elem.first, elem.second, kEqual, "var shape", var_shape, prim_name); CheckAndConvertUtils::Check(elem.first, elem.second, kEqual, "var shape", var_shape, prim_name);
} }
// Indices must be rank 1 // Indices must be rank 1
(void)CheckAndConvertUtils::CheckInteger("indices dim", indices_shape.size(), kEqual, 1, prim_name); const int64_t input_num1 = 1;
(void)CheckAndConvertUtils::CheckInteger("indices dim", indices_shape.size(), kEqual, input_num1, prim_name);
// Dimension of var must be equal or greater than 1 // Dimension of var must be equal or greater than 1
(void)CheckAndConvertUtils::CheckInteger("dimension of var", var_shape.size(), kGreaterEqual, 1, prim_name); (void)CheckAndConvertUtils::CheckInteger("dimension of var", var_shape.size(), kGreaterEqual, input_num1, prim_name);
// Indices shape must be equal to the first dimension of var // Indices shape must be equal to the first dimension of var
CheckAndConvertUtils::Check("indices shape", indices_shape[0], kEqual, "the first dimension of var", var_shape[0], CheckAndConvertUtils::Check("indices shape", indices_shape[0], kEqual, "the first dimension of var", var_shape[0],
prim_name); prim_name);
@ -79,18 +81,18 @@ TuplePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr>
const std::set<TypePtr> valid_types = {kFloat16, kFloat32}; const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
// Args ms、mom、grad must have the same type as var // Args ms、mom、grad must have the same type as var
std::map<std::string, TypePtr> args; std::map<std::string, TypePtr> args;
args.insert({"var", var_type}); (void)args.insert({"var", var_type});
args.insert({"ms", ms_type}); (void)args.insert({"ms", ms_type});
args.insert({"mom", mom_type}); (void)args.insert({"mom", mom_type});
args.insert({"grad", grad_type}); (void)args.insert({"grad", grad_type});
(void)CheckAndConvertUtils::CheckTensorTypeSame(args, valid_types, prim_name); (void)CheckAndConvertUtils::CheckTensorTypeSame(args, valid_types, prim_name);
// Args lr must be a scalar type // Args lr must be a scalar type
std::map<std::string, TypePtr> args2; std::map<std::string, TypePtr> args2;
args2.insert({"lr", lr_type}); (void)args2.insert({"lr", lr_type});
(void)CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args2, valid_types, prim_name); (void)CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args2, valid_types, prim_name);
// Check indices type // Check indices type
std::map<std::string, TypePtr> args3; std::map<std::string, TypePtr> args3;
args3.insert({"indices", indices_type}); (void)args3.insert({"indices", indices_type});
const std::set<TypePtr> valid_types1 = {kInt32, kInt64}; const std::set<TypePtr> valid_types1 = {kInt32, kInt64};
(void)CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args3, valid_types1, prim_name); (void)CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args3, valid_types1, prim_name);
return std::make_shared<Tuple>(std::vector<TypePtr>{var_type, ms_type, mom_type}); return std::make_shared<Tuple>(std::vector<TypePtr>{var_type, ms_type, mom_type});

View File

@ -36,42 +36,42 @@ abstract::TupleShapePtr InferShape(const PrimitivePtr &primitive, const std::vec
if (split_dim < 0) { if (split_dim < 0) {
split_dim += x_rank; split_dim += x_rank;
} }
auto shape_of_split_dim = x_shape[split_dim]; auto shape_of_split_dim = x_shape[LongToSize(split_dim)];
auto num_split = GetValue<int64_t>(primitive->GetAttr("num_split")); auto num_split = GetValue<int64_t>(primitive->GetAttr("num_split"));
CheckAndConvertUtils::CheckInteger("num_split", num_split, kGreaterEqual, 1, prim_name); (void)CheckAndConvertUtils::CheckInteger("num_split", num_split, kGreaterEqual, 1, prim_name);
auto size_splits = GetValue<std::vector<int64_t>>(primitive->GetAttr(kSizeSplits)); auto size_splits = GetValue<std::vector<int64_t>>(primitive->GetAttr(kSizeSplits));
CheckAndConvertUtils::Check("num_split", num_split, kEqual, "rank of size_splits", SizeToLong(size_splits.size()), CheckAndConvertUtils::Check("num_split", num_split, kEqual, "rank of size_splits", SizeToLong(size_splits.size()),
prim_name); prim_name);
auto default_idx = std::find(size_splits.begin(), size_splits.end(), -1); auto default_idx = std::find(size_splits.begin(), size_splits.end(), -1);
if (default_idx == size_splits.end()) { if (default_idx == size_splits.end()) {
int sum_of_size_splits = 0; int64_t sum_of_size_splits = 0;
for (int64_t i = 0; i < num_split; i++) { for (int64_t i = 0; i < num_split; i++) {
CheckAndConvertUtils::CheckInRange("elements of size_splits", size_splits[i], kIncludeBoth, (void)CheckAndConvertUtils::CheckInRange("elements of size_splits", size_splits[i], kIncludeBoth,
{0, shape_of_split_dim}, prim_name); {0, shape_of_split_dim}, prim_name);
sum_of_size_splits += size_splits[i]; sum_of_size_splits += size_splits[i];
} }
CheckAndConvertUtils::Check("sum of size_splits", sum_of_size_splits, kEqual, "dimension of value along split_dim", CheckAndConvertUtils::Check("sum of size_splits", sum_of_size_splits, kEqual, "dimension of value along split_dim",
shape_of_split_dim, prim_name); shape_of_split_dim, prim_name);
} else { } else {
size_splits.erase(default_idx); (void)size_splits.erase(default_idx);
auto excessive_default_idx = std::find(size_splits.begin(), size_splits.end(), -1); auto excessive_default_idx = std::find(size_splits.begin(), size_splits.end(), -1);
if (excessive_default_idx != size_splits.end()) { if (excessive_default_idx != size_splits.end()) {
MS_EXCEPTION(ValueError) << "Got more than one default value -1 in size_splits."; MS_EXCEPTION(ValueError) << "Got more than one default value -1 in size_splits.";
} else { } else {
int sum_of_size_splits = 0; int64_t sum_of_size_splits = 0;
for (int64_t i = 0; i < num_split - 1; i++) { for (int64_t i = 0; i < num_split - 1; i++) {
CheckAndConvertUtils::CheckInRange("elements of size_splits", size_splits[i], kIncludeBoth, (void)CheckAndConvertUtils::CheckInRange("elements of size_splits", size_splits[i], kIncludeBoth,
{0, shape_of_split_dim}, prim_name); {0, shape_of_split_dim}, prim_name);
sum_of_size_splits += size_splits[i]; sum_of_size_splits += size_splits[i];
} }
auto default_value = shape_of_split_dim - sum_of_size_splits; auto default_value = shape_of_split_dim - sum_of_size_splits;
size_splits.insert(default_idx, default_value); (void)size_splits.insert(default_idx, default_value);
} }
} }
std::vector<abstract::BaseShapePtr> shape_tuple; std::vector<abstract::BaseShapePtr> shape_tuple;
for (int64_t i = 0; i < num_split; i++) { for (int64_t i = 0; i < num_split; i++) {
auto shape = x_shape; auto shape = x_shape;
shape[split_dim] = size_splits[i]; shape[split_dim] = size_splits[LongToSize(i)];
abstract::ShapePtr out_shape = std::make_shared<abstract::Shape>(shape); abstract::ShapePtr out_shape = std::make_shared<abstract::Shape>(shape);
shape_tuple.push_back(out_shape); shape_tuple.push_back(out_shape);
} }

View File

@ -27,14 +27,12 @@ void ImpleSquare(void *origin, void *target, size_t size) {
MS_EXCEPTION_IF_NULL(target); MS_EXCEPTION_IF_NULL(target);
auto origin_data = reinterpret_cast<T *>(origin); auto origin_data = reinterpret_cast<T *>(origin);
auto target_data = reinterpret_cast<T *>(target); auto target_data = reinterpret_cast<T *>(target);
MS_EXCEPTION_IF_NULL(origin_data);
MS_EXCEPTION_IF_NULL(target_data);
for (size_t i = 0; i < size; ++i) { for (size_t i = 0; i < size; ++i) {
target_data[i] = origin_data[i] * origin_data[i]; target_data[i] = origin_data[i] * origin_data[i];
} }
} }
abstract::ShapePtr SquareInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { abstract::ShapePtr SquareInferShape(const std::vector<AbstractBasePtr> &input_args) {
auto shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape()); auto shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape());
auto in_shape = shape_map[kShape]; auto in_shape = shape_map[kShape];
auto min_shape = shape_map[kMinShape]; auto min_shape = shape_map[kMinShape];
@ -54,7 +52,7 @@ AbstractBasePtr SquareInfer(const abstract::AnalysisEnginePtr &, const Primitive
const int64_t input_num = 1; const int64_t input_num = 1;
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name()); CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
return abstract::MakeAbstract(SquareInferShape(primitive, input_args), SquareInferType(primitive, input_args)); return abstract::MakeAbstract(SquareInferShape(input_args), SquareInferType(primitive, input_args));
} }
ValuePtr SquareInferValue(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) { ValuePtr SquareInferValue(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
@ -74,53 +72,53 @@ ValuePtr SquareInferValue(const PrimitivePtr &prim, const std::vector<AbstractBa
auto data_size = x_tensor->DataSize(); auto data_size = x_tensor->DataSize();
auto dtype = x_tensor->data_type(); auto dtype = x_tensor->data_type();
auto shape = SquareInferShape(prim, input_args)->shape(); auto shape = SquareInferShape(input_args)->shape();
auto result_tensor = std::make_shared<tensor::Tensor>(dtype, shape); // same shape and dtype auto result_tensor = std::make_shared<tensor::Tensor>(dtype, shape); // same shape and dtype
auto x_datac = x_tensor->data_c(); auto x_datac = x_tensor->data_c();
auto result_datac = result_tensor->data_c(); auto result_datac = result_tensor->data_c();
switch (dtype) { switch (dtype) {
case kNumberTypeInt8: { case kNumberTypeInt8: {
ImpleSquare<int8_t>(x_datac, result_datac, data_size); ImpleSquare<int8_t>(x_datac, result_datac, IntToSize(data_size));
break; break;
} }
case kNumberTypeInt16: { case kNumberTypeInt16: {
ImpleSquare<int16_t>(x_datac, result_datac, data_size); ImpleSquare<int16_t>(x_datac, result_datac, IntToSize(data_size));
break; break;
} }
case kNumberTypeInt32: { case kNumberTypeInt32: {
ImpleSquare<int32_t>(x_datac, result_datac, data_size); ImpleSquare<int32_t>(x_datac, result_datac, IntToSize(data_size));
break; break;
} }
case kNumberTypeInt64: { case kNumberTypeInt64: {
ImpleSquare<int64_t>(x_datac, result_datac, data_size); ImpleSquare<int64_t>(x_datac, result_datac, IntToSize(data_size));
break; break;
} }
case kNumberTypeUInt8: { case kNumberTypeUInt8: {
ImpleSquare<uint8_t>(x_datac, result_datac, data_size); ImpleSquare<uint8_t>(x_datac, result_datac, IntToSize(data_size));
break; break;
} }
case kNumberTypeUInt16: { case kNumberTypeUInt16: {
ImpleSquare<uint16_t>(x_datac, result_datac, data_size); ImpleSquare<uint16_t>(x_datac, result_datac, IntToSize(data_size));
break; break;
} }
case kNumberTypeUInt32: { case kNumberTypeUInt32: {
ImpleSquare<uint32_t>(x_datac, result_datac, data_size); ImpleSquare<uint32_t>(x_datac, result_datac, IntToSize(data_size));
break; break;
} }
case kNumberTypeUInt64: { case kNumberTypeUInt64: {
ImpleSquare<uint64_t>(x_datac, result_datac, data_size); ImpleSquare<uint64_t>(x_datac, result_datac, IntToSize(data_size));
break; break;
} }
case kNumberTypeFloat16: { case kNumberTypeFloat16: {
ImpleSquare<float16>(x_datac, result_datac, data_size); ImpleSquare<float16>(x_datac, result_datac, IntToSize(data_size));
break; break;
} }
case kNumberTypeFloat32: { case kNumberTypeFloat32: {
ImpleSquare<float>(x_datac, result_datac, data_size); ImpleSquare<float>(x_datac, result_datac, IntToSize(data_size));
break; break;
} }
case kNumberTypeFloat64: { case kNumberTypeFloat64: {
ImpleSquare<double>(x_datac, result_datac, data_size); ImpleSquare<double>(x_datac, result_datac, IntToSize(data_size));
break; break;
} }
default: { default: {