fix code check

This commit is contained in:
huanghui 2021-05-20 20:38:12 +08:00
parent 620ba53725
commit 427337abac
7 changed files with 72 additions and 44 deletions

View File

@ -46,12 +46,12 @@ std::optional<std::string> Common::GetRealPath(const std::string &input_path) {
return std::nullopt;
}
#if defined(SYSTEM_ENV_POSIX)
if (nullptr == realpath(prefix_path.c_str(), real_path)) {
if (realpath(prefix_path.c_str(), real_path) == nullptr) {
MS_LOG(ERROR) << "dir " << prefix_path << " does not exist.";
return std::nullopt;
}
#elif defined(SYSTEM_ENV_WINDOWS)
if (nullptr == _fullpath(real_path, prefix_path.c_str(), PATH_MAX)) {
if (_fullpath(real_path, prefix_path.c_str(), PATH_MAX) == nullptr) {
MS_LOG(ERROR) << "dir " << prefix_path << " does not exist.";
return std::nullopt;
}
@ -273,12 +273,13 @@ std::string Common::AddId(const std::string &filename, const std::string &suffix
static size_t g_id = 0;
std::ostringstream s;
auto i = filename.rfind(suffix);
int spaces = 4;
if (i >= filename.size()) {
s << filename;
s << "_" << std::setfill('0') << std::setw(4) << g_id;
s << "_" << std::setfill('0') << std::setw(spaces) << g_id;
} else {
s << filename.substr(0, i);
s << "_" << std::setfill('0') << std::setw(4) << g_id;
s << "_" << std::setfill('0') << std::setw(spaces) << g_id;
if (i + 1 < filename.size()) {
s << filename.substr(i);
}

View File

@ -48,7 +48,28 @@ DataType InferType(const AnyPtrList &list) {
return DataType::kUnknown;
}
enum OpType { ADD, SUB, MUL, DIV, MOD };
template <typename T>
bool IsAddOverflow(const T &x, const T &y, const T &max, const T &min) {
return (y > 0 && (max - y) < x) || (y < 0 && (min - y) > x);
}
template <typename T>
bool IsSubOverflow(const T &x, const T &y, const T &max, const T &min) {
return (y < 0 && (max + y) < x) || (y > 0 && (min + y) > x);
}
template <typename T>
bool IsMulOverflow(const T &x, const T &y, const T &max, const T &min) {
return (x > 0 && y > 0 && (max / y) < x) || (x < 0 && y < 0 && (max / y) > x) || (x > 0 && y < 0 && (min / y) < x) ||
(x < 0 && y > 0 && (min / y) > x);
}
template <typename T>
bool IsDivOverflow(const T &x, const T &y, const T &max, const T &min) {
return (x == min && static_cast<int64_t>(y) == -1);
}
enum class OpType { ADD, SUB, MUL, DIV, MOD };
template <typename T>
bool IsSignedIntOverflow(T x, T y, OpType opType) {
@ -56,20 +77,19 @@ bool IsSignedIntOverflow(T x, T y, OpType opType) {
auto min = std::numeric_limits<T>::min();
if (opType == OpType::ADD) {
return (y > 0 && (max - y) < x) || (y < 0 && (min - y) > x);
return IsAddOverflow<T>(x, y, max, min);
}
if (opType == OpType::SUB) {
return (y < 0 && (max + y) < x) || (y > 0 && (min + y) > x);
return IsSubOverflow<T>(x, y, max, min);
}
if (opType == OpType::MUL) {
return (x > 0 && y > 0 && (max / y) < x) || (x < 0 && y < 0 && (max / y) > x) ||
(x > 0 && y < 0 && (min / y) < x) || (x < 0 && y > 0 && (min / y) > x);
return IsMulOverflow<T>(x, y, max, min);
}
if (opType == OpType::DIV || opType == OpType::MOD) {
return x == min && static_cast<int64_t>(y) == -1;
return IsDivOverflow<T>(x, y, max, min);
}
MS_LOG(EXCEPTION) << "Unsupported operation type.";

View File

@ -199,7 +199,7 @@ AnfNodePtr DoCast(const AnfNodePtr &param, const TypeId &type_id, const FuncGrap
void DoAutoCast(const std::string &func_name, const std::vector<Signature> &signature,
const std::vector<TypePtr> &input_types, const FuncGraphPtr &graph,
std::vector<AnfNodePtr> *const op_inputs, const std::set<size_t> &write_indices) {
const std::set<size_t> &write_indices, std::vector<AnfNodePtr> *const op_inputs) {
std::vector<SignatureEnumDType> dtypes;
(void)std::transform(signature.begin(), signature.end(), std::back_inserter(dtypes),
[](const Signature &sig) { return sig.dtype; });
@ -244,12 +244,8 @@ void DoAutoCast(const std::string &func_name, const std::vector<Signature> &sign
}
}
AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func_name, const ValuePtr &function,
const AbstractBasePtrList &args_spec_list, const std::vector<AnfNodePtr> &params_list) {
// args: original inputs
auto &signature = GetSignature(function);
std::size_t sig_size = signature.size();
auto has_var = (sig_size > 0 && signature[sig_size - 1].kind == SignatureEnumKind::kKindVarPositional);
void CheckSigSize(const size_t &sig_size, const bool &has_var, const AbstractBasePtrList &args_spec_list,
const std::string &func_name) {
if (sig_size > 0) {
if (has_var) {
if (sig_size - 1 > args_spec_list.size()) {
@ -260,6 +256,15 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func
MS_LOG(EXCEPTION) << "Function " << func_name << "'s input length is not equal to Signature length.";
}
}
}
AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func_name, const ValuePtr &function,
const AbstractBasePtrList &args_spec_list, const std::vector<AnfNodePtr> &params_list) {
// args: original inputs
auto &signature = GetSignature(function);
std::size_t sig_size = signature.size();
auto has_var = (sig_size > 0 && signature[sig_size - 1].kind == SignatureEnumKind::kKindVarPositional);
CheckSigSize(sig_size, has_var, args_spec_list, func_name);
std::vector<AnfNodePtr> op_inputs;
std::set<size_t> write_indices;
std::vector<TypePtr> input_types;
@ -308,7 +313,7 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func
}
// process default
ProcessDefault(func_name, args_spec_list.size(), signature, has_var, &op_inputs);
DoAutoCast(func_name, signature, input_types, func_graph, &op_inputs, write_indices);
DoAutoCast(func_name, signature, input_types, func_graph, write_indices, &op_inputs);
return func_graph->NewCNodeInOrder(op_inputs);
}
} // namespace

View File

@ -22,7 +22,7 @@ namespace irpass {
#define UPPER_FLT_LIMIT (FLT_MAX / 2.0)
#define LOWER_FLT_LIMIT (-FLT_MAX / 2.0)
// Define the checking mode
enum ScalarCheckingMode : int64_t { GREATER_EQUAL = 0, LESS };
enum class ScalarCheckingMode : int64_t { GREATER_EQUAL = 0, LESS };
bool IsNodeScalarTrueWith(const AnfNodePtr &node, const ScalarCheckingMode &checking_mode, const float &check_value) {
auto value_node = node->cast<ValueNodePtr>();
@ -38,7 +38,7 @@ bool IsNodeScalarTrueWith(const AnfNodePtr &node, const ScalarCheckingMode &chec
auto scalar = value->cast<ScalarPtr>();
if (scalar != nullptr) {
if (scalar->isa<FloatImm>()) {
if (checking_mode == GREATER_EQUAL) {
if (checking_mode == ScalarCheckingMode::GREATER_EQUAL) {
return GetValue<float>(scalar) >= check_value;
}
return GetValue<float>(scalar) < check_value;
@ -56,7 +56,7 @@ bool IsNodeScalarTrueWith(const AnfNodePtr &node, const ScalarCheckingMode &chec
TypeId tensor_type = tensor_ptr->Dtype()->type_id();
if ((tensor_type == TypeId::kNumberTypeFloat32) || (tensor_type == TypeId::kNumberTypeFloat)) {
float *data = reinterpret_cast<float *>(tensor_ptr->data_c());
if (checking_mode == GREATER_EQUAL) {
if (checking_mode == ScalarCheckingMode::GREATER_EQUAL) {
return data[0] >= check_value;
}
return data[0] < check_value;
@ -66,7 +66,9 @@ bool IsNodeScalarTrueWith(const AnfNodePtr &node, const ScalarCheckingMode &chec
}
// check if a value is greater or equal 0.0
bool IsNodeScalarPositive(const AnfNodePtr &node) { return IsNodeScalarTrueWith(node, GREATER_EQUAL, 0.0); }
bool IsNodeScalarPositive(const AnfNodePtr &node) {
return IsNodeScalarTrueWith(node, ScalarCheckingMode::GREATER_EQUAL, 0.0);
}
bool IsCNodePositive(const AnfNodePtr &node) {
if (IsPrimitiveCNode(node, prim::kPrimReduceSum) || IsPrimitiveCNode(node, prim::kPrimSqueeze)) {
@ -87,10 +89,14 @@ bool IsCNodePositive(const AnfNodePtr &node) {
}
// check if a value is greater or equal UPPER_FLT_LIMIT
bool IsNodeScalarMaxFLT(const AnfNodePtr &node) { return IsNodeScalarTrueWith(node, GREATER_EQUAL, UPPER_FLT_LIMIT); }
bool IsNodeScalarMaxFLT(const AnfNodePtr &node) {
return IsNodeScalarTrueWith(node, ScalarCheckingMode::GREATER_EQUAL, UPPER_FLT_LIMIT);
}
// check if a value is smaller than LOWER_FLT_LIMIT
bool IsNodeScalarMinFLT(const AnfNodePtr &node) { return IsNodeScalarTrueWith(node, LESS, LOWER_FLT_LIMIT); }
bool IsNodeScalarMinFLT(const AnfNodePtr &node) {
return IsNodeScalarTrueWith(node, ScalarCheckingMode::LESS, LOWER_FLT_LIMIT);
}
AnfNodePtr ValueBasedEliminate::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
PatternNode x, y, z;

View File

@ -57,13 +57,8 @@ using CompileGraphs = compile::CompileGraphs;
using abstract::AnalysisResult;
using mindspore::abstract::AnalysisContextPtr;
using mindspore::validator::Validate;
bool SimplifyDataStructuresPass(const ResourcePtr &res) {
MS_EXCEPTION_IF_NULL(res->func_graph());
FuncGraphPtr func_graph = res->func_graph();
bool changed = opt::SimplifyDataStructures(func_graph, res->manager());
namespace {
void DoRenormalize(const bool &changed, const FuncGraphPtr &func_graph, const ResourcePtr &res) {
abstract::AbstractBasePtrList args_spec;
auto parameters = func_graph->parameters();
(void)std::transform(parameters.begin(), parameters.end(), std::back_inserter(args_spec),
@ -73,6 +68,15 @@ bool SimplifyDataStructuresPass(const ResourcePtr &res) {
res->set_func_graph(new_fg);
}
res->set_args_spec(args_spec);
}
} // namespace
bool SimplifyDataStructuresPass(const ResourcePtr &res) {
MS_EXCEPTION_IF_NULL(res->func_graph());
FuncGraphPtr func_graph = res->func_graph();
bool changed = opt::SimplifyDataStructures(func_graph, res->manager());
DoRenormalize(changed, func_graph, res);
return true;
}
@ -99,16 +103,7 @@ bool CleanAfterOptAPass(const ResourcePtr &res) {
FuncGraphPtr func_graph = res->func_graph();
bool changed = opt::CleanAfterOptA(func_graph, res->manager());
abstract::AbstractBasePtrList args_spec;
auto parameters = func_graph->parameters();
(void)std::transform(parameters.begin(), parameters.end(), std::back_inserter(args_spec),
[](const AnfNodePtr &p) -> AbstractBasePtr { return p->abstract(); });
if (changed) {
FuncGraphPtr new_fg = Renormalize(res, func_graph, args_spec);
res->set_func_graph(new_fg);
}
res->set_args_spec(args_spec);
DoRenormalize(changed, func_graph, res);
return true;
}

View File

@ -101,7 +101,8 @@ std::unordered_map<abstract::AbstractBasePtrList, int64_t, abstract::AbstractBas
namespace {
std::string GetBaseNameForIR(int64_t stage_idx, const std::string &action_name) {
std::ostringstream oss;
oss << std::setfill('0') << std::setw(2) << stage_idx << "_" << action_name;
int spaces = 2;
oss << std::setfill('0') << std::setw(spaces) << stage_idx << "_" << action_name;
return oss.str();
}

View File

@ -138,8 +138,8 @@ class NoneOf(NoneOf_):
def __init__(self, patterns=None):
r"""
Args:
patterns(Union[list[:class:`mindspore.graph_utils.graph_pattern`]]: list of forbidden patterns, each element
should be one of the exposed Pattern instance.
patterns(Union[list[:class:`mindspore.graph_utils.graph_pattern`]]: list of forbidden patterns, each
element should be one of the exposed Pattern instance.
Raises:
TypeError: raise type error for invalid argument.