!18257 ops clean code master

Merge pull request !18257 from wanyiming/clean_ops_train
This commit is contained in:
i-robot 2021-06-21 19:42:14 +08:00 committed by Gitee
commit 7c1ea7e3fa
6 changed files with 12 additions and 12 deletions

View File

@ -33,8 +33,8 @@ abstract::ShapePtr AddNInferShape(const PrimitivePtr &primitive, const std::vect
auto elements = input_args[0]->isa<abstract::AbstractTuple>()
? input_args[0]->cast<abstract::AbstractTuplePtr>()->elements()
: input_args[0]->cast<abstract::AbstractListPtr>()->elements();
CheckAndConvertUtils::CheckInteger("concat element num", SizeToLong(elements.size()), kGreaterEqual, 1,
primitive->name());
(void)CheckAndConvertUtils::CheckInteger("concat element num", SizeToLong(elements.size()), kGreaterEqual, 1,
primitive->name());
primitive->AddAttr("n", MakeValue(SizeToLong(elements.size())));
auto shape_0 = elements[0]->BuildShape();
auto element0_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(shape_0);

View File

@ -33,14 +33,14 @@ AbstractBasePtr ConcatInfer(const abstract::AnalysisEnginePtr &, const Primitive
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, 1, prim_name);
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, 1, prim_name);
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
auto input_tuple = input_args[0]->cast<abstract::AbstractTuplePtr>();
MS_EXCEPTION_IF_NULL(input_tuple);
auto elements = input_tuple->elements();
CheckAndConvertUtils::CheckInteger("concat element num", elements.size(), kGreaterEqual, 1, prim_name);
(void)CheckAndConvertUtils::CheckInteger("concat element num", elements.size(), kGreaterEqual, 1, prim_name);
auto element0 = elements[0]->cast<abstract::AbstractTensorPtr>();
MS_EXCEPTION_IF_NULL(element0);
auto element0_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(element0->BuildShape())[kShape];
@ -56,8 +56,8 @@ AbstractBasePtr ConcatInfer(const abstract::AnalysisEnginePtr &, const Primitive
for (size_t i = 1; i < elements.size(); ++i) {
std::string elementi = "element" + std::to_string(i);
auto elementi_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(elements[i]->BuildShape())[kShape];
CheckAndConvertUtils::CheckInteger(elementi + " shape rank", SizeToLong(elementi_shape.size()), kEqual,
element0_shape.size(), prim_name);
(void)CheckAndConvertUtils::CheckInteger(elementi + " shape rank", SizeToLong(elementi_shape.size()), kEqual,
SizeToLong(element0_shape.size()), prim_name);
for (size_t j = 0; j < element0_rank; ++j) {
if (j != axis && elementi_shape[j] != element0_shape[j]) {
MS_LOG(EXCEPTION) << "element " << i << " shape in input can not concat with first element.";

View File

@ -32,9 +32,9 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto min_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
auto max_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape];
CheckAndConvertUtils::CheckInteger("x_rank", SizeToLong(in_shape.size()), kGreaterEqual, 1, prim_name);
(void)CheckAndConvertUtils::CheckInteger("x_rank", SizeToLong(in_shape.size()), kGreaterEqual, 1, prim_name);
CheckAndConvertUtils::Check("min_shape", min_shape, kEqual, "max_shape", max_shape, prim_name);
CheckAndConvertUtils::CheckInteger("min_shape", SizeToLong(min_shape.size()), kEqual, 1, prim_name);
(void)CheckAndConvertUtils::CheckInteger("min_shape", SizeToLong(min_shape.size()), kEqual, 1, prim_name);
int64_t shape_val = 1;
for (size_t i = 0; i < in_shape.size(); i++) {
shape_val = shape_val * in_shape[i];

View File

@ -30,8 +30,8 @@ AbstractBasePtr SigmoidCrossEntropyWithLogitsInfer(const abstract::AnalysisEngin
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
CheckAndConvertUtils::CheckInteger("sigmoid_cross_extropy_with_logits_infer", SizeToLong(input_args.size()), kEqual,
2, prim_name);
(void)CheckAndConvertUtils::CheckInteger("sigmoid_cross_extropy_with_logits_infer", SizeToLong(input_args.size()),
kEqual, 2, prim_name);
// Infer shape
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];

View File

@ -37,7 +37,7 @@ AbstractBasePtr SmoothL1LossInfer(const abstract::AnalysisEnginePtr &, const Pri
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
CheckAndConvertUtils::CheckInteger("smooth_l1_loss_infer", input_args.size(), kEqual, 2, prim_name);
(void)CheckAndConvertUtils::CheckInteger("smooth_l1_loss_infer", input_args.size(), kEqual, 2, prim_name);
// Infer shape
auto prediction = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];

View File

@ -32,7 +32,7 @@ AbstractBasePtr TopKInfer(const abstract::AnalysisEnginePtr &, const PrimitivePt
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
CheckAndConvertUtils::CheckInteger("top_k_infer", input_args.size(), kEqual, 2, prim_name);
(void)CheckAndConvertUtils::CheckInteger("top_k_infer", input_args.size(), kEqual, 2, prim_name);
// Infer dtype
auto output1_type = kInt32;