forked from mindspore-Ecosystem/mindspore
!21892 clean pclint warning for r1.3
Merge pull request !21892 from wangnan39/pclint_clean_ops
This commit is contained in:
commit
57fa49c7af
|
@ -35,7 +35,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
|
||||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
std::map<std::string, TypePtr> types;
|
||||
types.emplace("input_x", input_args[0]->BuildType());
|
||||
(void)types.emplace("input_x", input_args[0]->BuildType());
|
||||
return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
|
||||
}
|
||||
} // namespace
|
||||
|
|
|
@ -42,8 +42,8 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
|
|||
auto op_name = prim->name();
|
||||
(void)CheckAndConvertUtils::CheckInteger("Add infer", SizeToLong(input_args.size()), kGreaterEqual, 2, op_name);
|
||||
std::map<std::string, TypePtr> types;
|
||||
types.emplace("x", input_args[0]->BuildType());
|
||||
types.emplace("y", input_args[1]->BuildType());
|
||||
(void)types.emplace("x", input_args[0]->BuildType());
|
||||
(void)types.emplace("y", input_args[1]->BuildType());
|
||||
return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
|
||||
}
|
||||
} // namespace
|
||||
|
|
|
@ -73,17 +73,17 @@ TypePtr AddNInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePt
|
|||
(void)CheckAndConvertUtils::CheckInteger("concat element num", SizeToLong(elements.size()), kGreaterEqual, 1,
|
||||
prim->name());
|
||||
std::map<std::string, TypePtr> types;
|
||||
types.emplace("element_0", elements[0]->BuildType());
|
||||
(void)types.emplace("element_0", elements[0]->BuildType());
|
||||
for (size_t i = 0; i < elements.size(); ++i) {
|
||||
if (elements[i]->BuildType()->type_id() == kObjectTypeUndeterminedType) {
|
||||
return elements[0]->BuildType();
|
||||
}
|
||||
std::string element_i = "element_" + std::to_string(i);
|
||||
types.emplace(element_i, elements[i]->BuildType());
|
||||
(void)types.emplace(element_i, elements[i]->BuildType());
|
||||
}
|
||||
std::set<TypePtr> valid_types = common_valid_types;
|
||||
valid_types.insert(kBool);
|
||||
CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim_name);
|
||||
return elements[0]->BuildType();
|
||||
}
|
||||
} // namespace
|
||||
|
|
|
@ -76,8 +76,8 @@ AbstractBasePtr ApplyMomentumInfer(const abstract::AnalysisEnginePtr &, const Pr
|
|||
auto g_type = input_args[3]->BuildType();
|
||||
auto m_type = input_args[4]->BuildType();
|
||||
const std::set<TypePtr> valid_types = {kFloat16, kFloat32, kFloat64};
|
||||
CheckAndConvertUtils::CheckTensorTypeValid("v_type", v_tensor_type, valid_types, prim_name);
|
||||
CheckAndConvertUtils::CheckTensorTypeValid("a_type", a_tensor_type, valid_types, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("v_type", v_tensor_type, valid_types, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("a_type", a_tensor_type, valid_types, prim_name);
|
||||
std::map<std::string, TypePtr> args;
|
||||
args.insert({"l_type", l_type});
|
||||
args.insert({"g_type", g_type});
|
||||
|
|
|
@ -33,7 +33,7 @@ AbstractBasePtr InferImplAssign(const abstract::AnalysisEnginePtr &, const Primi
|
|||
(void)CheckAndConvertUtils::CheckInteger(
|
||||
"Assign infer", SizeToLong(CheckAndConvertUtils::GetRemoveMonadAbsNum(args_spec_list)), kEqual, 2, prim_name);
|
||||
auto check_types = common_valid_types;
|
||||
check_types.emplace(kBool);
|
||||
(void)check_types.emplace(kBool);
|
||||
auto variable_type = args_spec_list[0]->BuildType();
|
||||
auto value_type = args_spec_list[1]->BuildType();
|
||||
CheckAndConvertUtils::CheckScalarOrTensorTypesSame(std::map<std::string, TypePtr>{{"value", value_type}}, check_types,
|
||||
|
@ -41,7 +41,7 @@ AbstractBasePtr InferImplAssign(const abstract::AnalysisEnginePtr &, const Primi
|
|||
if (variable_type->isa<RefKeyType>()) {
|
||||
return args_spec_list[1]->Broaden();
|
||||
}
|
||||
CheckAndConvertUtils::CheckTensorTypeValid("variable", variable_type, check_types, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("variable", variable_type, check_types, prim_name);
|
||||
return args_spec_list[0];
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(Assign, prim::kPrimAssign, InferImplAssign, nullptr, true);
|
||||
|
|
|
@ -32,8 +32,8 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
std::map<std::string, TypePtr> types;
|
||||
types.emplace("x", input_args[0]->BuildType());
|
||||
types.emplace("w", input_args[1]->BuildType());
|
||||
(void)types.emplace("x", input_args[0]->BuildType());
|
||||
(void)types.emplace("w", input_args[1]->BuildType());
|
||||
// check_scalar_or_tensor_types_same
|
||||
return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, "AssignAdd");
|
||||
}
|
||||
|
|
|
@ -79,11 +79,11 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
|
|||
auto x_type = input_args[i]->BuildType();
|
||||
MS_EXCEPTION_IF_NULL(x_type);
|
||||
std::set<TypePtr> valid_x_type = {kTensorType};
|
||||
CheckAndConvertUtils::CheckTensorTypeValid("x_dtype", x_type, valid_x_type, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("x_dtype", x_type, valid_x_type, prim_name);
|
||||
}
|
||||
std::map<std::string, TypePtr> types;
|
||||
types.emplace("input_x", input_args[0]->BuildType());
|
||||
types.emplace("bias", input_args[1]->BuildType());
|
||||
(void)types.emplace("input_x", input_args[0]->BuildType());
|
||||
(void)types.emplace("bias", input_args[1]->BuildType());
|
||||
return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim_name);
|
||||
}
|
||||
} // namespace
|
||||
|
|
|
@ -57,12 +57,12 @@ TypePtr BinaryCrossEntroyInferType(const PrimitivePtr &prim, const std::vector<A
|
|||
}
|
||||
const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
|
||||
std::map<std::string, TypePtr> types;
|
||||
types.emplace("x_shape", input_args[0]->BuildType());
|
||||
types.emplace("y_shape", input_args[1]->BuildType());
|
||||
(void)types.emplace("x_shape", input_args[0]->BuildType());
|
||||
(void)types.emplace("y_shape", input_args[1]->BuildType());
|
||||
auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
|
||||
if (input_args[3]->BuildType() != nullptr) {
|
||||
types.emplace("x_shape", input_args[0]->BuildType());
|
||||
types.emplace("weight_shape", input_args[2]->BuildType());
|
||||
(void)types.emplace("x_shape", input_args[0]->BuildType());
|
||||
(void)types.emplace("weight_shape", input_args[2]->BuildType());
|
||||
infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
|
||||
}
|
||||
return infer_type;
|
||||
|
|
|
@ -58,7 +58,7 @@ AbstractBasePtr BroadcastInfer(const abstract::AnalysisEnginePtr &, const Primit
|
|||
for (size_t i = 0; i < input_args.size(); i++) {
|
||||
auto out_type = input_args[i]->BuildType()->cast<TensorTypePtr>()->element();
|
||||
output_types.push_back(out_type);
|
||||
CheckAndConvertUtils::CheckTensorTypeValid("index_type", out_type, valid_types, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("index_type", out_type, valid_types, prim_name);
|
||||
}
|
||||
return std::make_shared<abstract::AbstractTensor>(x_type, in_shape);
|
||||
}
|
||||
|
|
|
@ -67,7 +67,7 @@ TypePtr BroadcastToInferType(const PrimitivePtr &prim, const std::vector<Abstrac
|
|||
}
|
||||
auto x_dtype = input_args[0]->BuildType()->cast<TensorTypePtr>();
|
||||
std::set<TypePtr> template_types = {kTensorType};
|
||||
CheckAndConvertUtils::CheckSubClass("x_dtype", x_dtype, template_types, prim->name());
|
||||
(void)CheckAndConvertUtils::CheckSubClass("x_dtype", x_dtype, template_types, prim->name());
|
||||
return x_dtype->element();
|
||||
}
|
||||
} // namespace
|
||||
|
|
|
@ -52,7 +52,7 @@ AbstractBasePtr ConcatInfer(const abstract::AnalysisEnginePtr &, const Primitive
|
|||
auto axis = axis_temp < 0 ? LongToSize(axis_temp) + element0_rank : LongToSize(axis_temp);
|
||||
|
||||
std::map<std::string, TypePtr> types;
|
||||
types.emplace("element0", element0->BuildType());
|
||||
(void)types.emplace("element0", element0->BuildType());
|
||||
int64_t all_shp = element0_shape[axis];
|
||||
for (size_t i = 1; i < elements.size(); ++i) {
|
||||
std::string elementi = "element" + std::to_string(i);
|
||||
|
@ -65,7 +65,7 @@ AbstractBasePtr ConcatInfer(const abstract::AnalysisEnginePtr &, const Primitive
|
|||
}
|
||||
}
|
||||
all_shp = all_shp == -1 || elementi_shape[axis] == -1 ? -1 : all_shp + elementi_shape[axis];
|
||||
types.emplace(elementi, elements[i]->BuildType());
|
||||
(void)types.emplace(elementi, elements[i]->BuildType());
|
||||
}
|
||||
auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, all_types, prim_name);
|
||||
auto ret_shape = element0_shape;
|
||||
|
|
|
@ -320,9 +320,9 @@ AbstractBasePtr Conv2dInfer(const abstract::AnalysisEnginePtr &, const Primitive
|
|||
primitive->name());
|
||||
const std::set<TypePtr> valid_types = {kInt8, kInt32, kInt64, kFloat16, kFloat32};
|
||||
std::map<std::string, TypePtr> types;
|
||||
types.emplace("x", input_args[0]->BuildType());
|
||||
types.emplace("w", input_args[1]->BuildType());
|
||||
CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, primitive->name());
|
||||
(void)types.emplace("x", input_args[0]->BuildType());
|
||||
(void)types.emplace("w", input_args[1]->BuildType());
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, primitive->name());
|
||||
return abstract::MakeAbstract(Conv2dInferShape(primitive, input_args), Conv2dInferType(primitive, input_args));
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(Conv2D, prim::kPrimConv2D, Conv2dInfer, nullptr, true);
|
||||
|
|
|
@ -36,7 +36,7 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
|
|||
MS_LOG(EXCEPTION) << "nullptr";
|
||||
}
|
||||
std::map<std::string, TypePtr> types;
|
||||
types.emplace("x", input_args[0]->BuildType());
|
||||
(void)types.emplace("x", input_args[0]->BuildType());
|
||||
return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
|
||||
}
|
||||
} // namespace
|
||||
|
|
|
@ -57,8 +57,8 @@ TypePtr CTCLossV2GradInferType(const PrimitivePtr &primitive, const std::vector<
|
|||
std::map<std::string, TypePtr> types;
|
||||
MS_EXCEPTION_IF_NULL(input_args[0]);
|
||||
MS_EXCEPTION_IF_NULL(input_args[1]);
|
||||
types.emplace("grad_out", input_args[0]->BuildType());
|
||||
types.emplace("log_probs", input_args[1]->BuildType());
|
||||
(void)types.emplace("grad_out", input_args[0]->BuildType());
|
||||
(void)types.emplace("log_probs", input_args[1]->BuildType());
|
||||
auto out_type = CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, name);
|
||||
return out_type;
|
||||
}
|
||||
|
|
|
@ -82,9 +82,9 @@ abstract::TupleShapePtr InferShape(const PrimitivePtr &primitive, const std::vec
|
|||
|
||||
TuplePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto op_name = primitive->name();
|
||||
CheckAndConvertUtils::CheckTensorTypeValid("labels_indices", input_args[1]->BuildType(), {kInt64}, op_name);
|
||||
CheckAndConvertUtils::CheckTensorTypeValid("labels_values", input_args[2]->BuildType(), {kInt32}, op_name);
|
||||
CheckAndConvertUtils::CheckTensorTypeValid("sequence_length", input_args[3]->BuildType(), {kInt32}, op_name);
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("labels_indices", input_args[1]->BuildType(), {kInt64}, op_name);
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("labels_values", input_args[2]->BuildType(), {kInt32}, op_name);
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("sequence_length", input_args[3]->BuildType(), {kInt32}, op_name);
|
||||
const std::set<TypePtr> valid_types = {kFloat16, kFloat32, kFloat64};
|
||||
auto type = CheckAndConvertUtils::CheckTensorTypeValid("inputs", input_args[0]->BuildType(), valid_types, op_name);
|
||||
return std::make_shared<Tuple>(std::vector<TypePtr>{type, type});
|
||||
|
|
|
@ -37,8 +37,8 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
|
|||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
std::map<std::string, TypePtr> types;
|
||||
types.emplace("x", input_args[0]->BuildType());
|
||||
types.emplace("y", input_args[1]->BuildType());
|
||||
(void)types.emplace("x", input_args[0]->BuildType());
|
||||
(void)types.emplace("y", input_args[1]->BuildType());
|
||||
return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
|
||||
}
|
||||
} // namespace
|
||||
|
|
|
@ -46,8 +46,8 @@ AbstractBasePtr DropoutInfer(const abstract::AnalysisEnginePtr &, const Primitiv
|
|||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
(void)CheckAndConvertUtils::CheckInteger("x_shape", SizeToLong(x_shape.size()), kGreaterEqual, 1, prim_name);
|
||||
std::vector<int64_t> out_shape;
|
||||
out_shape.insert(out_shape.end(), x_shape.begin(), x_shape.end());
|
||||
out_shape.insert(out_shape.end(), x_shape.begin(), x_shape.end());
|
||||
(void)out_shape.insert(out_shape.end(), x_shape.begin(), x_shape.end());
|
||||
(void)out_shape.insert(out_shape.end(), x_shape.begin(), x_shape.end());
|
||||
auto infer_shape = std::make_shared<abstract::Shape>(out_shape);
|
||||
|
||||
// Infer type
|
||||
|
|
|
@ -79,7 +79,8 @@ TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBaseP
|
|||
|
||||
if (keep_prop->isa<abstract::AbstractTensor>()) {
|
||||
const std::set<TypePtr> keep_prop_valid_types = {kFloat16, kFloat32, kFloat64};
|
||||
CheckAndConvertUtils::CheckTensorTypeValid("keep prop", keep_prop->BuildType(), keep_prop_valid_types, op_name);
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("keep prop", keep_prop->BuildType(), keep_prop_valid_types,
|
||||
op_name);
|
||||
if (keep_prop_value->isa<tensor::Tensor>()) {
|
||||
auto keep_prop_tensor = keep_prop_value->cast<tensor::TensorPtr>();
|
||||
MS_EXCEPTION_IF_NULL(keep_prop_tensor);
|
||||
|
@ -106,7 +107,7 @@ TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBaseP
|
|||
MS_EXCEPTION(TypeError) << "The DropoutDoMask's keep_prop input must be a float number or tensor.";
|
||||
}
|
||||
|
||||
CheckAndConvertUtils::CheckTensorTypeValid("inputs", input_args[1]->BuildType(), {kUInt8}, op_name);
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("inputs", input_args[1]->BuildType(), {kUInt8}, op_name);
|
||||
const std::set<TypePtr> input_valid_types = {kFloat16, kFloat32, kInt32};
|
||||
return CheckAndConvertUtils::CheckTensorTypeValid("inputs", input_args[0]->BuildType(), input_valid_types, op_name);
|
||||
}
|
||||
|
|
|
@ -138,7 +138,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto op_name = primitive->name();
|
||||
const std::set<TypePtr> valid_types = {kFloat32, kFloat16};
|
||||
CheckAndConvertUtils::CheckTensorTypeValid("inputs", input_args[1]->BuildType(), valid_types, op_name);
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("inputs", input_args[1]->BuildType(), valid_types, op_name);
|
||||
return kUInt8;
|
||||
}
|
||||
} // namespace
|
||||
|
|
|
@ -44,7 +44,7 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
|
|||
}
|
||||
std::map<std::string, TypePtr> types;
|
||||
const std::set<TypePtr> valid_types = {kFloat16, kFloat32, kFloat64};
|
||||
types.emplace("x", input_args[0]->BuildType());
|
||||
(void)types.emplace("x", input_args[0]->BuildType());
|
||||
return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
|
||||
}
|
||||
} // namespace
|
||||
|
|
|
@ -53,21 +53,21 @@ AbstractBasePtr EmbeddingLookupInfer(const abstract::AnalysisEnginePtr &, const
|
|||
MS_EXCEPTION_IF_NULL(indices->shape());
|
||||
auto indices_shp = indices->shape()->shape();
|
||||
ShapeVector shape;
|
||||
shape.insert(shape.end(), indices_shp.begin(), indices_shp.end());
|
||||
shape.insert(shape.end(), params_shp.begin() + 1, params_shp.end());
|
||||
(void)shape.insert(shape.end(), indices_shp.begin(), indices_shp.end());
|
||||
(void)shape.insert(shape.end(), params_shp.begin() + 1, params_shp.end());
|
||||
auto indices_max_shape = indices->shape()->max_shape();
|
||||
ShapeVector max_shape;
|
||||
if (!indices_max_shape.empty()) {
|
||||
max_shape.insert(max_shape.end(), indices_max_shape.begin(), indices_max_shape.end());
|
||||
max_shape.insert(max_shape.end(), params_shp.begin() + 1, params_shp.end());
|
||||
(void)max_shape.insert(max_shape.end(), indices_max_shape.begin(), indices_max_shape.end());
|
||||
(void)max_shape.insert(max_shape.end(), params_shp.begin() + 1, params_shp.end());
|
||||
} else {
|
||||
max_shape = shape;
|
||||
}
|
||||
auto indices_min_shape = indices->shape()->min_shape();
|
||||
ShapeVector min_shape;
|
||||
if (!indices_min_shape.empty()) {
|
||||
min_shape.insert(min_shape.end(), indices_min_shape.begin(), indices_min_shape.end());
|
||||
min_shape.insert(min_shape.end(), params_shp.begin() + 1, params_shp.end());
|
||||
(void)min_shape.insert(min_shape.end(), indices_min_shape.begin(), indices_min_shape.end());
|
||||
(void)min_shape.insert(min_shape.end(), params_shp.begin() + 1, params_shp.end());
|
||||
} else {
|
||||
min_shape = shape;
|
||||
}
|
||||
|
|
|
@ -48,8 +48,8 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
|
|||
MS_LOG(EXCEPTION) << "nullptr";
|
||||
}
|
||||
std::map<std::string, TypePtr> types;
|
||||
types.emplace("x", input_args[0]->BuildType());
|
||||
types.emplace("y", input_args[1]->BuildType());
|
||||
(void)types.emplace("x", input_args[0]->BuildType());
|
||||
(void)types.emplace("y", input_args[1]->BuildType());
|
||||
return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, op_name);
|
||||
}
|
||||
} // namespace
|
||||
|
|
|
@ -51,9 +51,9 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
|
|||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
std::map<std::string, TypePtr> types;
|
||||
types.emplace("x", input_args[0]->BuildType());
|
||||
(void)types.emplace("x", input_args[0]->BuildType());
|
||||
std::set<TypePtr> valid_params_types = {kTensorType};
|
||||
CheckAndConvertUtils::CheckSubClass("x_type", input_args[0]->BuildType(), valid_params_types, prim->name());
|
||||
(void)CheckAndConvertUtils::CheckSubClass("x_type", input_args[0]->BuildType(), valid_params_types, prim->name());
|
||||
return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
|
||||
}
|
||||
} // namespace
|
||||
|
|
|
@ -44,13 +44,13 @@ AbstractBasePtr ExpandDimsInfer(const abstract::AnalysisEnginePtr &, const Primi
|
|||
dim_val += x_shape.size() + 1;
|
||||
}
|
||||
auto out_shape = x_shape;
|
||||
out_shape.insert(out_shape.begin() + dim_val, 1, 1);
|
||||
(void)out_shape.insert(out_shape.begin() + dim_val, 1, 1);
|
||||
|
||||
// Infer type
|
||||
const int64_t x_index = 0;
|
||||
auto x_type = CheckAndConvertUtils::GetInputTensorType(input_args, x_index, prim_name);
|
||||
std::set<TypePtr> valid_x_type = {kTensorType};
|
||||
CheckAndConvertUtils::CheckSubClass("x_type", x_type, valid_x_type, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckSubClass("x_type", x_type, valid_x_type, prim_name);
|
||||
return std::make_shared<abstract::AbstractTensor>(x_type, out_shape);
|
||||
}
|
||||
REGISTER_PRIMITIVE_C(kNameExpandDims, ExpandDims);
|
||||
|
|
|
@ -51,9 +51,9 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
|
|||
MS_LOG(EXCEPTION) << "nullptr";
|
||||
}
|
||||
std::map<std::string, TypePtr> types;
|
||||
types.emplace("x", input_args[0]->BuildType());
|
||||
types.emplace("min", input_args[1]->BuildType());
|
||||
types.emplace("max", input_args[2]->BuildType());
|
||||
(void)types.emplace("x", input_args[0]->BuildType());
|
||||
(void)types.emplace("min", input_args[1]->BuildType());
|
||||
(void)types.emplace("max", input_args[2]->BuildType());
|
||||
return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
|
||||
}
|
||||
} // namespace
|
||||
|
|
|
@ -42,7 +42,7 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
|
|||
}
|
||||
auto infer_type = input_args[0]->BuildType()->cast<TensorTypePtr>()->element();
|
||||
const std::set<TypePtr> valid_types = {kTensorType};
|
||||
CheckAndConvertUtils::CheckSubClass("infer type", input_args[0]->BuildType(), valid_types, prim->name());
|
||||
(void)CheckAndConvertUtils::CheckSubClass("infer type", input_args[0]->BuildType(), valid_types, prim->name());
|
||||
return infer_type;
|
||||
}
|
||||
} // namespace
|
||||
|
|
|
@ -34,7 +34,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
const std::set<TypePtr> valid_types = {kFloat16, kFloat32, kFloat64};
|
||||
std::map<std::string, TypePtr> types;
|
||||
types.emplace("x", input_args[0]->BuildType());
|
||||
(void)types.emplace("x", input_args[0]->BuildType());
|
||||
return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
|
||||
}
|
||||
} // namespace
|
||||
|
|
|
@ -47,8 +47,8 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
|
|||
MS_LOG(EXCEPTION) << "nullptr";
|
||||
}
|
||||
std::map<std::string, TypePtr> types;
|
||||
types.emplace("x", input_args[0]->BuildType());
|
||||
types.emplace("y", input_args[1]->BuildType());
|
||||
(void)types.emplace("x", input_args[0]->BuildType());
|
||||
(void)types.emplace("y", input_args[1]->BuildType());
|
||||
return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
|
||||
}
|
||||
} // namespace
|
||||
|
|
|
@ -47,8 +47,8 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
|
|||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
std::map<std::string, TypePtr> types;
|
||||
types.emplace("x", input_args[0]->BuildType());
|
||||
types.emplace("y", input_args[1]->BuildType());
|
||||
(void)types.emplace("x", input_args[0]->BuildType());
|
||||
(void)types.emplace("y", input_args[1]->BuildType());
|
||||
return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
|
||||
}
|
||||
} // namespace
|
||||
|
|
|
@ -33,10 +33,10 @@ AbstractBasePtr GatherInfer(const abstract::AnalysisEnginePtr &, const Primitive
|
|||
CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(op_name, input_args, 1);
|
||||
// check
|
||||
std::set<TypePtr> valid_params_types = {kTensorType};
|
||||
CheckAndConvertUtils::CheckSubClass("params_type", input_args[0]->BuildType(), valid_params_types, op_name);
|
||||
(void)CheckAndConvertUtils::CheckSubClass("params_type", input_args[0]->BuildType(), valid_params_types, op_name);
|
||||
std::set<TypePtr> int_types = {kInt8, kInt16, kInt32, kInt64};
|
||||
CheckAndConvertUtils::CheckTensorTypeValid("index_type", input_args[1]->BuildType(), int_types, op_name);
|
||||
CheckAndConvertUtils::CheckTypeValid("axis_type", input_args[2]->BuildType(), int_types, op_name);
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("index_type", input_args[1]->BuildType(), int_types, op_name);
|
||||
(void)CheckAndConvertUtils::CheckTypeValid("axis_type", input_args[2]->BuildType(), int_types, op_name);
|
||||
|
||||
bool ind_dyn = (!indices->shape()->min_shape().empty() && !indices->shape()->max_shape().empty());
|
||||
bool param_dyn = (!params->shape()->min_shape().empty() && !params->shape()->max_shape().empty());
|
||||
|
@ -75,7 +75,7 @@ AbstractBasePtr GatherInfer(const abstract::AnalysisEnginePtr &, const Primitive
|
|||
}
|
||||
auto calc_shape = [axis_val](const ShapeVector &ind_vec, const ShapeVector ¶ms_vec) -> ShapeVector {
|
||||
ShapeVector out_vec;
|
||||
std::copy(params_vec.begin(), params_vec.begin() + axis_val, std::back_inserter(out_vec));
|
||||
(void)std::copy(params_vec.begin(), params_vec.begin() + axis_val, std::back_inserter(out_vec));
|
||||
copy(ind_vec.begin(), ind_vec.end(), std::back_inserter(out_vec));
|
||||
copy(params_vec.begin() + axis_val + 1, params_vec.end(), std::back_inserter(out_vec));
|
||||
return out_vec;
|
||||
|
|
|
@ -66,8 +66,8 @@ AbstractBasePtr GatherDInfer(const abstract::AnalysisEnginePtr &, const Primitiv
|
|||
auto prim_name = primitive->name();
|
||||
// check
|
||||
std::set<TypePtr> valid_types = {kInt32, kInt64};
|
||||
CheckAndConvertUtils::CheckTensorTypeValid("index_type", input_args[2]->BuildType(), valid_types, prim_name);
|
||||
CheckAndConvertUtils::CheckSubClass("dim_type", input_args[1]->BuildType(), valid_types, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("index_type", input_args[2]->BuildType(), valid_types, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckSubClass("dim_type", input_args[1]->BuildType(), valid_types, prim_name);
|
||||
return abstract::MakeAbstract(GatherDInferShape(primitive, input_args), GatherDInferType(primitive, input_args));
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(GatherD, prim::kPrimGatherD, GatherDInfer, nullptr, true);
|
||||
|
|
|
@ -54,7 +54,7 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
|
|||
MS_LOG(EXCEPTION) << "nullptr";
|
||||
}
|
||||
std::map<std::string, TypePtr> types;
|
||||
types.emplace("input_x", input_args[0]->BuildType());
|
||||
(void)types.emplace("input_x", input_args[0]->BuildType());
|
||||
return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
|
||||
}
|
||||
} // namespace
|
||||
|
|
|
@ -48,7 +48,7 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
|
|||
std::map<std::string, TypePtr> types;
|
||||
const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
|
||||
MS_EXCEPTION_IF_NULL(input_args[0]);
|
||||
types.emplace("x", input_args[0]->BuildType());
|
||||
(void)types.emplace("x", input_args[0]->BuildType());
|
||||
return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
|
||||
}
|
||||
} // namespace
|
||||
|
|
|
@ -40,12 +40,12 @@ abstract::ShapePtr BinaryCrossEntroyGradInferShape(const PrimitivePtr &primitive
|
|||
TypePtr BinaryCrossEntroyGradInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
|
||||
std::map<std::string, TypePtr> types;
|
||||
types.emplace("x_shape", input_args[0]->BuildType());
|
||||
types.emplace("y_shape", input_args[1]->BuildType());
|
||||
(void)types.emplace("x_shape", input_args[0]->BuildType());
|
||||
(void)types.emplace("y_shape", input_args[1]->BuildType());
|
||||
auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
|
||||
if (input_args[3]->BuildType() != nullptr) {
|
||||
types.emplace("x_shape", input_args[0]->BuildType());
|
||||
types.emplace("weight_shape", input_args[2]->BuildType());
|
||||
(void)types.emplace("x_shape", input_args[0]->BuildType());
|
||||
(void)types.emplace("weight_shape", input_args[2]->BuildType());
|
||||
infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
|
||||
}
|
||||
return infer_type;
|
||||
|
|
|
@ -39,8 +39,8 @@ TypePtr Conv2DBackpropFilterInferType(const PrimitivePtr &prim, const std::vecto
|
|||
auto prim_name = prim->name();
|
||||
// check
|
||||
std::map<std::string, TypePtr> types;
|
||||
types.emplace("doutput", input_args[0]->BuildType());
|
||||
types.emplace("x", input_args[1]->BuildType());
|
||||
(void)types.emplace("doutput", input_args[0]->BuildType());
|
||||
(void)types.emplace("x", input_args[1]->BuildType());
|
||||
std::set<TypePtr> valid_x_type = {kInt8, kInt32, kFloat16, kFloat32};
|
||||
return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_x_type, prim_name);
|
||||
}
|
||||
|
|
|
@ -82,8 +82,8 @@ TypePtr Conv2DBackpropInputInferType(const PrimitivePtr &prim, const std::vector
|
|||
auto prim_name = prim->name();
|
||||
// check
|
||||
std::map<std::string, TypePtr> types;
|
||||
types.emplace("doutput", input_args[0]->BuildType());
|
||||
types.emplace("w", input_args[1]->BuildType());
|
||||
(void)types.emplace("doutput", input_args[0]->BuildType());
|
||||
(void)types.emplace("w", input_args[1]->BuildType());
|
||||
std::set<TypePtr> valid_x_type = {kInt8, kInt32, kFloat16, kFloat32};
|
||||
return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_x_type, prim_name);
|
||||
}
|
||||
|
|
|
@ -36,7 +36,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
}
|
||||
auto dout = CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, 0);
|
||||
auto out = CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, 1);
|
||||
abstract::CheckShapeSame(prim_name, out, dout);
|
||||
(void)abstract::CheckShapeSame(prim_name, out, dout);
|
||||
auto x = input_args[0]->BuildShape();
|
||||
MS_EXCEPTION_IF_NULL(x);
|
||||
auto shape_element = x->cast<abstract::ShapePtr>();
|
||||
|
@ -51,7 +51,7 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
|
|||
MS_EXCEPTION_IF_NULL(input_args[0]);
|
||||
auto dout = CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, 0);
|
||||
auto out = CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, 1);
|
||||
abstract::CheckDtypeSame(prim_name, out, dout);
|
||||
(void)abstract::CheckDtypeSame(prim_name, out, dout);
|
||||
auto x_type = input_args[0]->BuildType();
|
||||
MS_EXCEPTION_IF_NULL(x_type);
|
||||
if (!x_type->isa<TensorType>()) {
|
||||
|
|
|
@ -34,7 +34,7 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
|
|||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
std::map<std::string, TypePtr> types;
|
||||
types.emplace("x", input_args[0]->BuildType());
|
||||
(void)types.emplace("x", input_args[0]->BuildType());
|
||||
return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
|
||||
}
|
||||
} // namespace
|
||||
|
|
|
@ -36,9 +36,9 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
|
|||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
std::map<std::string, TypePtr> types;
|
||||
types.emplace("x", input_args[0]->BuildType());
|
||||
types.emplace("y", input_args[1]->BuildType());
|
||||
CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
|
||||
(void)types.emplace("x", input_args[0]->BuildType());
|
||||
(void)types.emplace("y", input_args[1]->BuildType());
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
|
||||
return kBool;
|
||||
}
|
||||
} // namespace
|
||||
|
|
|
@ -37,8 +37,8 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
|
|||
MS_LOG(EXCEPTION) << "nullptr";
|
||||
}
|
||||
std::map<std::string, TypePtr> types;
|
||||
types.emplace("x", input_args[0]->BuildType());
|
||||
types.emplace("y", input_args[1]->BuildType());
|
||||
(void)types.emplace("x", input_args[0]->BuildType());
|
||||
(void)types.emplace("y", input_args[1]->BuildType());
|
||||
return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
|
||||
}
|
||||
} // namespace
|
||||
|
|
|
@ -52,9 +52,9 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
|
|||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
std::map<std::string, TypePtr> types;
|
||||
types.emplace("x", input_args[0]->BuildType());
|
||||
(void)types.emplace("x", input_args[0]->BuildType());
|
||||
std::set<TypePtr> valid_params_types = {kTensorType};
|
||||
CheckAndConvertUtils::CheckSubClass("x_type", input_args[0]->BuildType(), valid_params_types, op_name);
|
||||
(void)CheckAndConvertUtils::CheckSubClass("x_type", input_args[0]->BuildType(), valid_params_types, op_name);
|
||||
return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
|
||||
}
|
||||
} // namespace
|
||||
|
|
|
@ -38,8 +38,8 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
|
|||
}
|
||||
std::map<std::string, TypePtr> types;
|
||||
const std::set<TypePtr> valid_types = {kBool};
|
||||
types.emplace("x", input_args[0]->BuildType());
|
||||
types.emplace("y", input_args[1]->BuildType());
|
||||
(void)types.emplace("x", input_args[0]->BuildType());
|
||||
(void)types.emplace("y", input_args[1]->BuildType());
|
||||
return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
|
||||
}
|
||||
} // namespace
|
||||
|
|
|
@ -36,8 +36,8 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
std::map<std::string, TypePtr> types;
|
||||
const std::set<TypePtr> valid_types = {kBool};
|
||||
types.emplace("x", input_args[0]->BuildType());
|
||||
types.emplace("y", input_args[1]->BuildType());
|
||||
(void)types.emplace("x", input_args[0]->BuildType());
|
||||
(void)types.emplace("y", input_args[1]->BuildType());
|
||||
return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
|
||||
}
|
||||
} // namespace
|
||||
|
|
|
@ -75,11 +75,8 @@ void LRN::Init(const int64_t depth_radius, const float bias, const float alpha,
|
|||
}
|
||||
|
||||
namespace {
|
||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
abstract::ShapePtr InferShape(const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
(void)CheckAndConvertUtils::CheckInteger("input shape", SizeToLong(in_shape.size()), kEqual, 4, prim_name);
|
||||
|
||||
return std::make_shared<abstract::Shape>(in_shape);
|
||||
}
|
||||
|
@ -90,15 +87,17 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
|
|||
MS_LOG(EXCEPTION) << "nullptr";
|
||||
}
|
||||
std::map<std::string, TypePtr> types;
|
||||
types.emplace("x", input_args[0]->BuildType());
|
||||
(void)types.emplace("x", input_args[0]->BuildType());
|
||||
return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
|
||||
}
|
||||
} // namespace
|
||||
|
||||
AbstractBasePtr LrnInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
|
||||
InferShape(primitive, input_args)->shape());
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
const int64_t input_nums = 4;
|
||||
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_nums, primitive->name());
|
||||
return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args), InferShape(input_args));
|
||||
}
|
||||
REGISTER_PRIMITIVE_C(kNameLRN, LRN);
|
||||
} // namespace ops
|
||||
|
|
|
@ -98,8 +98,8 @@ TypePtr MatMulInferType(const PrimitivePtr &prim, const std::vector<AbstractBase
|
|||
MS_EXCEPTION_IF_NULL(prim);
|
||||
const std::set<TypePtr> valid_types = {kInt8, kInt16, kInt32, kInt64, kFloat16, kFloat32, kFloat64};
|
||||
std::map<std::string, TypePtr> types;
|
||||
types.emplace("x", input_args[0]->BuildType());
|
||||
types.emplace("w", input_args[1]->BuildType());
|
||||
(void)types.emplace("x", input_args[0]->BuildType());
|
||||
(void)types.emplace("w", input_args[1]->BuildType());
|
||||
return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
|
||||
}
|
||||
} // namespace
|
||||
|
|
|
@ -34,8 +34,8 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
|
|||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
std::map<std::string, TypePtr> types;
|
||||
types.emplace("x", input_args[0]->BuildType());
|
||||
types.emplace("y", input_args[1]->BuildType());
|
||||
(void)types.emplace("x", input_args[0]->BuildType());
|
||||
(void)types.emplace("y", input_args[1]->BuildType());
|
||||
return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
|
||||
}
|
||||
} // namespace
|
||||
|
|
|
@ -40,8 +40,8 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
|
|||
MS_LOG(EXCEPTION) << "nullptr";
|
||||
}
|
||||
std::map<std::string, TypePtr> types;
|
||||
types.emplace("x", input_args[0]->BuildType());
|
||||
types.emplace("y", input_args[1]->BuildType());
|
||||
(void)types.emplace("x", input_args[0]->BuildType());
|
||||
(void)types.emplace("y", input_args[1]->BuildType());
|
||||
return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
|
||||
}
|
||||
} // namespace
|
||||
|
|
|
@ -44,8 +44,8 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
|
|||
auto op_name = prim->name();
|
||||
(void)CheckAndConvertUtils::CheckInteger("Mul infer", input_args.size(), kGreaterEqual, 2, op_name);
|
||||
std::map<std::string, TypePtr> types;
|
||||
types.emplace("x", input_args[0]->BuildType());
|
||||
types.emplace("y", input_args[1]->BuildType());
|
||||
(void)types.emplace("x", input_args[0]->BuildType());
|
||||
(void)types.emplace("y", input_args[1]->BuildType());
|
||||
return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
|
||||
}
|
||||
} // namespace
|
||||
|
|
|
@ -52,9 +52,9 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
|
|||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
std::map<std::string, TypePtr> types;
|
||||
types.emplace("x", input_args[0]->BuildType());
|
||||
(void)types.emplace("x", input_args[0]->BuildType());
|
||||
std::set<TypePtr> valid_params_types = {kTensorType};
|
||||
CheckAndConvertUtils::CheckSubClass("x_type", input_args[0]->BuildType(), valid_params_types, op_name);
|
||||
(void)CheckAndConvertUtils::CheckSubClass("x_type", input_args[0]->BuildType(), valid_params_types, op_name);
|
||||
return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
|
||||
}
|
||||
} // namespace
|
||||
|
|
|
@ -45,8 +45,8 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
|
|||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
std::map<std::string, TypePtr> types;
|
||||
types.emplace("x", input_args[0]->BuildType());
|
||||
types.emplace("y", input_args[1]->BuildType());
|
||||
(void)types.emplace("x", input_args[0]->BuildType());
|
||||
(void)types.emplace("y", input_args[1]->BuildType());
|
||||
return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, op_name);
|
||||
}
|
||||
} // namespace
|
||||
|
|
|
@ -39,15 +39,15 @@ abstract::ShapePtr OneHotInferShape(const PrimitivePtr &primitive, const std::ve
|
|||
(void)CheckAndConvertUtils::CheckInteger("depth value", depth_val, kGreaterEqual, 0, op_name);
|
||||
if (min_shape.size() == 0 || max_shape.size() == 0) {
|
||||
if (axis >= 0) {
|
||||
in_shape.insert(in_shape.begin() + axis, depth_val);
|
||||
(void)in_shape.insert(in_shape.begin() + axis, depth_val);
|
||||
} else {
|
||||
in_shape.push_back(depth_val);
|
||||
}
|
||||
} else {
|
||||
if (axis >= 0) {
|
||||
in_shape.insert(in_shape.begin() + axis, depth_val);
|
||||
min_shape.insert(min_shape.begin() + axis, depth_val);
|
||||
max_shape.insert(max_shape.begin() + axis, depth_val);
|
||||
(void)in_shape.insert(in_shape.begin() + axis, depth_val);
|
||||
(void)min_shape.insert(min_shape.begin() + axis, depth_val);
|
||||
(void)max_shape.insert(max_shape.begin() + axis, depth_val);
|
||||
} else {
|
||||
in_shape.push_back(depth_val);
|
||||
min_shape.push_back(depth_val);
|
||||
|
@ -59,8 +59,9 @@ abstract::ShapePtr OneHotInferShape(const PrimitivePtr &primitive, const std::ve
|
|||
|
||||
TypePtr OneHotInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto op_name = prim->name();
|
||||
CheckAndConvertUtils::CheckTensorTypeValid("indices", input_args[0]->BuildType(), {kInt32, kInt64}, op_name);
|
||||
CheckAndConvertUtils::CheckTypeValid("depth", input_args[1]->BuildType(), {kInt8, kInt16, kInt32, kInt64}, op_name);
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("indices", input_args[0]->BuildType(), {kInt32, kInt64}, op_name);
|
||||
(void)CheckAndConvertUtils::CheckTypeValid("depth", input_args[1]->BuildType(), {kInt8, kInt16, kInt32, kInt64},
|
||||
op_name);
|
||||
std::map<std::string, TypePtr> args = {{"on_value", input_args[2]->BuildType()},
|
||||
{"off_dtype", input_args[3]->BuildType()}};
|
||||
return CheckAndConvertUtils::CheckTensorTypeSame(args, {kFloat16, kFloat32}, op_name);
|
||||
|
|
|
@ -35,9 +35,9 @@ std::vector<int64_t> CalBroadCastShape(std::vector<int64_t> x_shape, std::vector
|
|||
auto length = x_length < y_length ? x_length : y_length;
|
||||
std::vector<int64_t> broadcast_shape;
|
||||
if (x_length == length) {
|
||||
std::copy(y_shape.begin(), y_shape.end() - SizeToLong(length), std::back_inserter(broadcast_shape));
|
||||
(void)std::copy(y_shape.begin(), y_shape.end() - SizeToLong(length), std::back_inserter(broadcast_shape));
|
||||
} else {
|
||||
std::copy(x_shape.begin(), x_shape.end() - SizeToLong(length), std::back_inserter(broadcast_shape));
|
||||
(void)std::copy(x_shape.begin(), x_shape.end() - SizeToLong(length), std::back_inserter(broadcast_shape));
|
||||
}
|
||||
for (int64_t i = -length; i < 0; i++) {
|
||||
if (x_shape[x_length + i] == 1) {
|
||||
|
|
|
@ -42,7 +42,7 @@ std::vector<int64_t> _get_pack_shape(std::vector<BaseShapePtr> x_shapes, std::ve
|
|||
"shape in input can't pack with first element.";
|
||||
}
|
||||
}
|
||||
output_shape.insert(output_shape.begin() + axis, N);
|
||||
(void)output_shape.insert(output_shape.begin() + axis, N);
|
||||
return output_shape;
|
||||
}
|
||||
} // namespace
|
||||
|
|
|
@ -33,8 +33,8 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
|
|||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
std::map<std::string, TypePtr> types;
|
||||
types.emplace("x", input_args[0]->BuildType());
|
||||
types.emplace("y", input_args[1]->BuildType());
|
||||
(void)types.emplace("x", input_args[0]->BuildType());
|
||||
(void)types.emplace("y", input_args[1]->BuildType());
|
||||
return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
|
||||
}
|
||||
} // namespace
|
||||
|
|
|
@ -24,7 +24,7 @@ TypePtr RankInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePt
|
|||
auto op_name = prim->name();
|
||||
MS_EXCEPTION_IF_NULL(input_args[0]);
|
||||
auto infer_dtype = input_args[0]->BuildType();
|
||||
CheckAndConvertUtils::CheckTensorTypeValid("x", infer_dtype, {kTensorType}, op_name);
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("x", infer_dtype, {kTensorType}, op_name);
|
||||
return kTypeNone;
|
||||
}
|
||||
} // namespace
|
||||
|
|
|
@ -44,8 +44,8 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
|
|||
auto op_name = prim->name();
|
||||
(void)CheckAndConvertUtils::CheckInteger("RealDiv infer", SizeToLong(input_args.size()), kGreaterEqual, 2, op_name);
|
||||
std::map<std::string, TypePtr> types;
|
||||
types.emplace("x", input_args[0]->BuildType());
|
||||
types.emplace("y", input_args[1]->BuildType());
|
||||
(void)types.emplace("x", input_args[0]->BuildType());
|
||||
(void)types.emplace("y", input_args[1]->BuildType());
|
||||
return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
|
||||
}
|
||||
} // namespace
|
||||
|
|
|
@ -52,9 +52,9 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
|
|||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
std::map<std::string, TypePtr> types;
|
||||
types.emplace("x", input_args[0]->BuildType());
|
||||
(void)types.emplace("x", input_args[0]->BuildType());
|
||||
std::set<TypePtr> valid_params_types = {kTensorType};
|
||||
CheckAndConvertUtils::CheckSubClass("x_type", input_args[0]->BuildType(), valid_params_types, op_name);
|
||||
(void)CheckAndConvertUtils::CheckSubClass("x_type", input_args[0]->BuildType(), valid_params_types, op_name);
|
||||
return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
|
||||
}
|
||||
} // namespace
|
||||
|
|
|
@ -31,7 +31,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
(void)CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kGreaterEqual, 1, prim_name);
|
||||
CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, 0);
|
||||
(void)CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, 0);
|
||||
auto x = input_args[0]->BuildShape();
|
||||
MS_EXCEPTION_IF_NULL(x);
|
||||
auto shape_element = x->cast<abstract::ShapePtr>();
|
||||
|
|
|
@ -46,7 +46,7 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
|
|||
MS_LOG(EXCEPTION) << "nullptr";
|
||||
}
|
||||
std::map<std::string, TypePtr> types;
|
||||
types.emplace("x", input_args[0]->BuildType());
|
||||
(void)types.emplace("x", input_args[0]->BuildType());
|
||||
return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
|
||||
}
|
||||
} // namespace
|
||||
|
|
|
@ -45,9 +45,9 @@ std::vector<int64_t> GetOutputMaskShape(const std::vector<int64_t> &input_shape,
|
|||
}
|
||||
}
|
||||
if (x_dtype == kUInt8 || x_dtype == kInt8) {
|
||||
mask_shape.insert(mask_shape.end(), 4);
|
||||
(void)mask_shape.insert(mask_shape.end(), 4);
|
||||
} else {
|
||||
mask_shape.insert(mask_shape.end(), 2);
|
||||
(void)mask_shape.insert(mask_shape.end(), 2);
|
||||
}
|
||||
return mask_shape;
|
||||
}
|
||||
|
|
|
@ -53,7 +53,7 @@ AbstractBasePtr ResizeBilinearInfer(const abstract::AnalysisEnginePtr &, const P
|
|||
(void)CheckAndConvertUtils::CheckInteger("input_shape_rank", SizeToLong(input_shape.size()), kEqual, 4, prim_name);
|
||||
std::vector<int64_t> out_shape = {input_shape[0], input_shape[1]};
|
||||
auto size = GetValue<std::vector<int64_t>>(primitive->GetAttr(kSize));
|
||||
out_shape.insert(out_shape.end(), size.begin(), size.end());
|
||||
(void)out_shape.insert(out_shape.end(), size.begin(), size.end());
|
||||
|
||||
// Infer type
|
||||
const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
|
||||
|
|
|
@ -57,7 +57,7 @@ AbstractBasePtr ReverseSequenceInfer(const abstract::AnalysisEnginePtr &, const
|
|||
input_shape[LongToSize(batch_dim)], prim_name);
|
||||
// infer type
|
||||
std::set<TypePtr> valid_x_types(common_valid_types);
|
||||
valid_x_types.emplace(kBool);
|
||||
(void)valid_x_types.emplace(kBool);
|
||||
const std::set<TypePtr> valid_seq_types = {kInt32, kInt64};
|
||||
auto x_type = input_args[0]->BuildType()->cast<TensorTypePtr>()->element();
|
||||
auto seq_type = input_args[1]->BuildType()->cast<TensorTypePtr>()->element();
|
||||
|
|
|
@ -44,8 +44,8 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
|
|||
const std::set<TypePtr> update_valid_types = {kTensorType};
|
||||
auto indices_type = input_args[0]->BuildType();
|
||||
auto update_type = input_args[1]->BuildType();
|
||||
CheckAndConvertUtils::CheckTypeValid("update type", update_type, update_valid_types, prim->name());
|
||||
CheckAndConvertUtils::CheckTensorTypeValid("indices type", indices_type, indices_valid_types, prim->name());
|
||||
(void)CheckAndConvertUtils::CheckTypeValid("update type", update_type, update_valid_types, prim->name());
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("indices type", indices_type, indices_valid_types, prim->name());
|
||||
return input_args[1]->BuildType();
|
||||
}
|
||||
} // namespace
|
||||
|
|
|
@ -37,7 +37,7 @@ AbstractBasePtr ShapeInfer(const abstract::AnalysisEnginePtr &, const PrimitiveP
|
|||
auto in_shape = shape_map[kShape];
|
||||
// infer type
|
||||
std::set<TypePtr> valid_params_types = {kTensorType};
|
||||
CheckAndConvertUtils::CheckSubClass("shape type", input_args[0]->BuildType(), valid_params_types, op_name);
|
||||
(void)CheckAndConvertUtils::CheckSubClass("shape type", input_args[0]->BuildType(), valid_params_types, op_name);
|
||||
AbstractBasePtrList abs_list;
|
||||
(void)std::transform(in_shape.begin(), in_shape.end(), std::back_inserter(abs_list),
|
||||
[](int64_t item) -> std::shared_ptr<abstract::AbstractScalar> {
|
||||
|
@ -53,7 +53,7 @@ ValuePtr ShapeInferValue(const PrimitivePtr &primitive, const std::vector<Abstra
|
|||
(void)CheckAndConvertUtils::CheckInteger("shape infer", input_args.size(), kEqual, 1, op_name);
|
||||
MS_EXCEPTION_IF_NULL(input_args[0]);
|
||||
std::set<TypePtr> valid_params_types = {kTensorType};
|
||||
CheckAndConvertUtils::CheckSubClass("shape type", input_args[0]->BuildType(), valid_params_types, op_name);
|
||||
(void)CheckAndConvertUtils::CheckSubClass("shape type", input_args[0]->BuildType(), valid_params_types, op_name);
|
||||
auto shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape());
|
||||
auto inshape = shape_map[kShape];
|
||||
auto value = MakeValue(inshape);
|
||||
|
|
|
@ -48,7 +48,7 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
|
|||
MS_LOG(EXCEPTION) << "nullptr";
|
||||
}
|
||||
std::map<std::string, TypePtr> types;
|
||||
types.emplace("x", input_args[0]->BuildType());
|
||||
(void)types.emplace("x", input_args[0]->BuildType());
|
||||
return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
|
||||
}
|
||||
} // namespace
|
||||
|
|
|
@ -33,8 +33,8 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
const std::set<TypePtr> valid_types = {kInt32, kFloat16, kFloat32};
|
||||
std::map<std::string, TypePtr> types;
|
||||
types.emplace("x", input_args[0]->BuildType());
|
||||
types.emplace("y", input_args[1]->BuildType());
|
||||
(void)types.emplace("x", input_args[0]->BuildType());
|
||||
(void)types.emplace("y", input_args[1]->BuildType());
|
||||
return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
|
||||
}
|
||||
} // namespace
|
||||
|
|
|
@ -32,8 +32,8 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape];
|
||||
auto len = SizeToLong(in_shape.size());
|
||||
if (axis.empty()) {
|
||||
std::copy_if(in_shape.begin(), in_shape.end(), std::back_inserter(infer_shape),
|
||||
[](int64_t value) { return value != 1; });
|
||||
(void)std::copy_if(in_shape.begin(), in_shape.end(), std::back_inserter(infer_shape),
|
||||
[](int64_t value) { return value != 1; });
|
||||
} else {
|
||||
for (auto &item : axis) {
|
||||
CheckAndConvertUtils::CheckInRange<int64_t>("axis_or_elememt", item, kIncludeBoth, {-len, len + 1}, op_name);
|
||||
|
|
|
@ -44,7 +44,7 @@ abstract::AbstractBasePtr StackInfer(const PrimitivePtr &primitive, const std::v
|
|||
}
|
||||
}
|
||||
std::vector<int64_t> infer_shape = input_shape;
|
||||
infer_shape.insert(infer_shape.begin() + GetValue<int64_t>(primitive->GetAttr(kAxis)), input_args.size());
|
||||
(void)infer_shape.insert(infer_shape.begin() + GetValue<int64_t>(primitive->GetAttr(kAxis)), input_args.size());
|
||||
|
||||
auto infer_type0 = input_args[0]->BuildType()->cast<TensorTypePtr>()->element();
|
||||
for (size_t i = 1; i < input_args.size(); i++) {
|
||||
|
|
|
@ -44,8 +44,8 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
|
|||
auto op_name = prim->name();
|
||||
(void)CheckAndConvertUtils::CheckInteger("Sub infer", SizeToLong(input_args.size()), kGreaterEqual, 2, op_name);
|
||||
std::map<std::string, TypePtr> types;
|
||||
types.emplace("x", input_args[0]->BuildType());
|
||||
types.emplace("y", input_args[1]->BuildType());
|
||||
(void)types.emplace("x", input_args[0]->BuildType());
|
||||
(void)types.emplace("y", input_args[1]->BuildType());
|
||||
return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
|
||||
}
|
||||
} // namespace
|
||||
|
|
|
@ -61,7 +61,7 @@ AbstractBasePtr TensorListStackInfer(const abstract::AnalysisEnginePtr &, const
|
|||
MS_LOG(ERROR) << "ele_shape->data_c() is nullptr";
|
||||
}
|
||||
auto input1_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
|
||||
input1_shape.insert(input1_shape.begin(), 1);
|
||||
(void)input1_shape.insert(input1_shape.begin(), 1);
|
||||
return std::make_shared<abstract::AbstractTensor>(input_args[0]->BuildType(), input1_shape);
|
||||
}
|
||||
REGISTER_PRIMITIVE_C(kNameTensorListStack, TensorListStack);
|
||||
|
|
|
@ -32,7 +32,7 @@ std::vector<int64_t> GetInferShape(const std::vector<int64_t> &input_shape, cons
|
|||
}
|
||||
if (len_sub > 0) {
|
||||
for (int64_t i = 0; i < len_sub; i++) {
|
||||
infer_shape.insert(infer_shape.begin(), 1);
|
||||
(void)infer_shape.insert(infer_shape.begin(), 1);
|
||||
}
|
||||
multiples_w = multiples_v;
|
||||
}
|
||||
|
|
|
@ -61,7 +61,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
}
|
||||
}
|
||||
std::vector<int64_t> in_shape(p_value);
|
||||
std::transform(in_shape.begin(), in_shape.end(), in_shape.begin(), [x_shape](int i) { return x_shape[i]; });
|
||||
(void)std::transform(in_shape.begin(), in_shape.end(), in_shape.begin(), [x_shape](int i) { return x_shape[i]; });
|
||||
if (!x_min_shape.empty() && !x_max_shape.empty()) {
|
||||
std::vector<int64_t> min_shape;
|
||||
std::vector<int64_t> max_shape;
|
||||
|
|
|
@ -27,8 +27,8 @@ AbstractBasePtr UnpackInfer(const abstract::AnalysisEnginePtr &, const Primitive
|
|||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
MS_EXCEPTION_IF_NULL(input_args[0]);
|
||||
CheckAndConvertUtils::CheckSubClass("x", input_args[0]->BuildType(), {TypeIdToType(kObjectTypeTensorType)},
|
||||
prim_name);
|
||||
(void)CheckAndConvertUtils::CheckSubClass("x", input_args[0]->BuildType(), {TypeIdToType(kObjectTypeTensorType)},
|
||||
prim_name);
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
int64_t dim = SizeToLong(x_shape.size());
|
||||
int64_t axis = GetValue<int64_t>(primitive->GetAttr(kAxis));
|
||||
|
@ -41,7 +41,7 @@ AbstractBasePtr UnpackInfer(const abstract::AnalysisEnginePtr &, const Primitive
|
|||
(void)CheckAndConvertUtils::CheckInteger("The dimension which to unpack divides output_num", output_valid_check,
|
||||
kEqual, 0, prim_name);
|
||||
std::vector<int64_t> infer_shape(x_shape.begin(), x_shape.begin() + axis);
|
||||
infer_shape.insert(infer_shape.end(), x_shape.begin() + axis + 1, x_shape.end());
|
||||
(void)infer_shape.insert(infer_shape.end(), x_shape.begin() + axis + 1, x_shape.end());
|
||||
AbstractBasePtrList output;
|
||||
auto tensor_type = input_args[0]->BuildType()->cast<TensorTypePtr>();
|
||||
MS_EXCEPTION_IF_NULL(tensor_type);
|
||||
|
|
|
@ -45,7 +45,8 @@ AbstractBasePtr UnsqueezeInfer(const abstract::AnalysisEnginePtr &, const Primit
|
|||
auto dim_rank = dims.size();
|
||||
std::vector<int64_t> out_shape;
|
||||
if (dim_rank == 0) {
|
||||
std::copy_if(input_shape.begin(), input_shape.end(), out_shape.begin(), [](const auto item) { return item == 1; });
|
||||
(void)std::copy_if(input_shape.begin(), input_shape.end(), out_shape.begin(),
|
||||
[](const auto item) { return item == 1; });
|
||||
} else {
|
||||
auto sz = input_rank + dim_rank;
|
||||
size_t in_itr = 0;
|
||||
|
|
|
@ -38,7 +38,7 @@ AbstractBasePtr UnstackInfer(const abstract::AnalysisEnginePtr &, const Primitiv
|
|||
(void)CheckAndConvertUtils::CheckInteger("The dimension which to unstack divides output_num", output_valid_check,
|
||||
kEqual, 0, prim_name);
|
||||
std::vector<int64_t> infer_shape(x_shape.begin(), x_shape.begin() + axis);
|
||||
infer_shape.insert(infer_shape.end(), x_shape.begin() + axis + 1, x_shape.end());
|
||||
(void)infer_shape.insert(infer_shape.end(), x_shape.begin() + axis + 1, x_shape.end());
|
||||
AbstractBasePtrList output;
|
||||
auto tensor_type = input_args[0]->BuildType()->cast<TensorTypePtr>();
|
||||
MS_EXCEPTION_IF_NULL(tensor_type);
|
||||
|
|
Loading…
Reference in New Issue