static code clean

This commit is contained in:
chenfei 2022-09-19 16:07:01 +08:00
parent e7a1be0416
commit 6f45ecd676
5 changed files with 16 additions and 20 deletions

View File

@ -24,6 +24,7 @@
#include <set>
#include "utils/hash_map.h"
#include "ir/func_graph_cloner.h"
#include "frontend/optimizer/irpass.h"
#include "frontend/optimizer/optimizer.h"
#include "frontend/optimizer/anf_visitor.h"
@ -264,7 +265,7 @@ class ChoicePartialEliminater : public AnfVisitor {
MS_LOG(EXCEPTION) << "Can't find parameter of arg:" << arg->DebugString();
}
std::vector<AnfNodePtr> GetFuncGraphNewParameters(
static std::vector<AnfNodePtr> GetFuncGraphNewParameters(
const FuncGraphPtr &func_graph, const std::vector<AnfNodePtr> &new_args,
const HashMap<FuncGraphPtr, HashMap<AnfNodePtr, size_t>> &all_old_args_index_map) {
MS_EXCEPTION_IF_NULL(func_graph);

View File

@ -17,10 +17,6 @@
#ifndef MINDSPORE_CCSRC_INCLUDE_COMMON_PYBIND_API_API_REGISTER_H_
#define MINDSPORE_CCSRC_INCLUDE_COMMON_PYBIND_API_API_REGISTER_H_
#include <map>
#include <string>
#include <memory>
#include <functional>
#include <vector>
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"

View File

@ -17,17 +17,12 @@
#ifndef MINDSPORE_CORE_IR_DTYPE_H_
#define MINDSPORE_CORE_IR_DTYPE_H_
#include <cstddef>
#include <iostream>
#include <initializer_list>
#include <memory>
#include <utility>
#include <sstream>
#include <string>
#include <vector>
#include <unordered_map>
#include <type_traits>
#include <algorithm>
#include "base/base.h"
#include "ir/named.h"

View File

@ -52,7 +52,7 @@ class ArgMaxAbsInfer : public abstract::OpInferBase {
axis = axis < 0 ? axis + SizeToLong(x_rank) : axis;
auto out_shape_vector = shape_vector;
(void)out_shape_vector.erase(out_shape_vector.cbegin() + LongToSize(axis));
(void)out_shape_vector.erase(out_shape_vector.cbegin() + axis);
return std::make_shared<abstract::Shape>(out_shape_vector);
}

View File

@ -57,6 +57,18 @@ bool CheckScalarOrTensor(ShapeVector input) {
return flag;
}
void SelectInferShapeCheck(const std::vector<int64_t> &x_shape, const std::vector<int64_t> &y_shape,
const std::vector<int64_t> &cond_shape, size_t shape_size) {
for (size_t i = 0; i < shape_size; i++) {
if ((x_shape[i] > 0 && cond_shape[i] > 0 && x_shape[i] != cond_shape[i]) ||
(x_shape[i] > 0 && y_shape[i] > 0 && x_shape[i] != y_shape[i])) {
MS_EXCEPTION(ValueError)
<< "For 'Select', shape of tensor condition, x and y must be the same. But got condition shape: " << cond_shape
<< ", x shape: " << x_shape << ", y shape: " << y_shape << ".";
}
}
}
abstract::BaseShapePtr SelectInferShape(const PrimitivePtr &, const std::vector<AbstractBasePtr> &input_args) {
auto cond_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kSelectCondIndex]->BuildShape())[kShape];
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kSelectXIndex]->BuildShape())[kShape];
@ -72,16 +84,8 @@ abstract::BaseShapePtr SelectInferShape(const PrimitivePtr &, const std::vector<
MS_EXCEPTION(ValueError)
<< "For 'Select', shape size of tensor condition, x and y must be equal. But got condition size: "
<< cond_shape_size << ", x size: " << x_shape_size << ", y size: " << y_shape_size << ".";
} else {
for (size_t i = 0; i < x_shape_size; i++) {
if ((x_shape[i] > 0 && cond_shape[i] > 0 && x_shape[i] != cond_shape[i]) ||
(x_shape[i] > 0 && y_shape[i] > 0 && x_shape[i] != y_shape[i])) {
MS_EXCEPTION(ValueError)
<< "For 'Select', shape of tensor condition, x and y must be the same. But got condition shape: "
<< cond_shape << ", x shape: " << x_shape << ", y shape: " << y_shape << ".";
}
}
}
SelectInferShapeCheck(x_shape, y_shape, cond_shape, x_shape_size);
} else {
if (!(CheckScalarOrTensor(cond_shape) && CheckScalarOrTensor(x_shape) && CheckScalarOrTensor(y_shape))) {
MS_EXCEPTION(ValueError) << "For 'Select', when any of cond, x, y is of scalar type, "