forked from mindspore-Ecosystem/mindspore
static code clean
This commit is contained in:
parent
e7a1be0416
commit
6f45ecd676
|
@ -24,6 +24,7 @@
|
|||
#include <set>
|
||||
|
||||
#include "utils/hash_map.h"
|
||||
#include "ir/func_graph_cloner.h"
|
||||
#include "frontend/optimizer/irpass.h"
|
||||
#include "frontend/optimizer/optimizer.h"
|
||||
#include "frontend/optimizer/anf_visitor.h"
|
||||
|
@ -264,7 +265,7 @@ class ChoicePartialEliminater : public AnfVisitor {
|
|||
MS_LOG(EXCEPTION) << "Can't find parameter of arg:" << arg->DebugString();
|
||||
}
|
||||
|
||||
std::vector<AnfNodePtr> GetFuncGraphNewParameters(
|
||||
static std::vector<AnfNodePtr> GetFuncGraphNewParameters(
|
||||
const FuncGraphPtr &func_graph, const std::vector<AnfNodePtr> &new_args,
|
||||
const HashMap<FuncGraphPtr, HashMap<AnfNodePtr, size_t>> &all_old_args_index_map) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
|
|
|
@ -17,10 +17,6 @@
|
|||
#ifndef MINDSPORE_CCSRC_INCLUDE_COMMON_PYBIND_API_API_REGISTER_H_
|
||||
#define MINDSPORE_CCSRC_INCLUDE_COMMON_PYBIND_API_API_REGISTER_H_
|
||||
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <functional>
|
||||
#include <vector>
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "pybind11/stl.h"
|
||||
|
|
|
@ -17,17 +17,12 @@
|
|||
#ifndef MINDSPORE_CORE_IR_DTYPE_H_
|
||||
#define MINDSPORE_CORE_IR_DTYPE_H_
|
||||
|
||||
#include <cstddef>
|
||||
#include <iostream>
|
||||
#include <initializer_list>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <unordered_map>
|
||||
#include <type_traits>
|
||||
#include <algorithm>
|
||||
#include "base/base.h"
|
||||
#include "ir/named.h"
|
||||
|
||||
|
|
|
@ -52,7 +52,7 @@ class ArgMaxAbsInfer : public abstract::OpInferBase {
|
|||
axis = axis < 0 ? axis + SizeToLong(x_rank) : axis;
|
||||
|
||||
auto out_shape_vector = shape_vector;
|
||||
(void)out_shape_vector.erase(out_shape_vector.cbegin() + LongToSize(axis));
|
||||
(void)out_shape_vector.erase(out_shape_vector.cbegin() + axis);
|
||||
return std::make_shared<abstract::Shape>(out_shape_vector);
|
||||
}
|
||||
|
||||
|
|
|
@ -57,6 +57,18 @@ bool CheckScalarOrTensor(ShapeVector input) {
|
|||
return flag;
|
||||
}
|
||||
|
||||
void SelectInferShapeCheck(const std::vector<int64_t> &x_shape, const std::vector<int64_t> &y_shape,
|
||||
const std::vector<int64_t> &cond_shape, size_t shape_size) {
|
||||
for (size_t i = 0; i < shape_size; i++) {
|
||||
if ((x_shape[i] > 0 && cond_shape[i] > 0 && x_shape[i] != cond_shape[i]) ||
|
||||
(x_shape[i] > 0 && y_shape[i] > 0 && x_shape[i] != y_shape[i])) {
|
||||
MS_EXCEPTION(ValueError)
|
||||
<< "For 'Select', shape of tensor condition, x and y must be the same. But got condition shape: " << cond_shape
|
||||
<< ", x shape: " << x_shape << ", y shape: " << y_shape << ".";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
abstract::BaseShapePtr SelectInferShape(const PrimitivePtr &, const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto cond_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kSelectCondIndex]->BuildShape())[kShape];
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kSelectXIndex]->BuildShape())[kShape];
|
||||
|
@ -72,16 +84,8 @@ abstract::BaseShapePtr SelectInferShape(const PrimitivePtr &, const std::vector<
|
|||
MS_EXCEPTION(ValueError)
|
||||
<< "For 'Select', shape size of tensor condition, x and y must be equal. But got condition size: "
|
||||
<< cond_shape_size << ", x size: " << x_shape_size << ", y size: " << y_shape_size << ".";
|
||||
} else {
|
||||
for (size_t i = 0; i < x_shape_size; i++) {
|
||||
if ((x_shape[i] > 0 && cond_shape[i] > 0 && x_shape[i] != cond_shape[i]) ||
|
||||
(x_shape[i] > 0 && y_shape[i] > 0 && x_shape[i] != y_shape[i])) {
|
||||
MS_EXCEPTION(ValueError)
|
||||
<< "For 'Select', shape of tensor condition, x and y must be the same. But got condition shape: "
|
||||
<< cond_shape << ", x shape: " << x_shape << ", y shape: " << y_shape << ".";
|
||||
}
|
||||
}
|
||||
}
|
||||
SelectInferShapeCheck(x_shape, y_shape, cond_shape, x_shape_size);
|
||||
} else {
|
||||
if (!(CheckScalarOrTensor(cond_shape) && CheckScalarOrTensor(x_shape) && CheckScalarOrTensor(y_shape))) {
|
||||
MS_EXCEPTION(ValueError) << "For 'Select', when any of cond, x, y is of scalar type, "
|
||||
|
|
Loading…
Reference in New Issue