forked from mindspore-Ecosystem/mindspore
!16682 Fix code check of ME module
From: @irmo Reviewed-by: @ginfung,@zh_qh Signed-off-by: @zh_qh
This commit is contained in:
commit
f88d193423
|
@ -46,12 +46,12 @@ std::optional<std::string> Common::GetRealPath(const std::string &input_path) {
|
|||
return std::nullopt;
|
||||
}
|
||||
#if defined(SYSTEM_ENV_POSIX)
|
||||
if (nullptr == realpath(prefix_path.c_str(), real_path)) {
|
||||
if (realpath(prefix_path.c_str(), real_path) == nullptr) {
|
||||
MS_LOG(ERROR) << "dir " << prefix_path << " does not exist.";
|
||||
return std::nullopt;
|
||||
}
|
||||
#elif defined(SYSTEM_ENV_WINDOWS)
|
||||
if (nullptr == _fullpath(real_path, prefix_path.c_str(), PATH_MAX)) {
|
||||
if (_fullpath(real_path, prefix_path.c_str(), PATH_MAX) == nullptr) {
|
||||
MS_LOG(ERROR) << "dir " << prefix_path << " does not exist.";
|
||||
return std::nullopt;
|
||||
}
|
||||
|
@ -273,12 +273,13 @@ std::string Common::AddId(const std::string &filename, const std::string &suffix
|
|||
static size_t g_id = 0;
|
||||
std::ostringstream s;
|
||||
auto i = filename.rfind(suffix);
|
||||
int spaces = 4;
|
||||
if (i >= filename.size()) {
|
||||
s << filename;
|
||||
s << "_" << std::setfill('0') << std::setw(4) << g_id;
|
||||
s << "_" << std::setfill('0') << std::setw(spaces) << g_id;
|
||||
} else {
|
||||
s << filename.substr(0, i);
|
||||
s << "_" << std::setfill('0') << std::setw(4) << g_id;
|
||||
s << "_" << std::setfill('0') << std::setw(spaces) << g_id;
|
||||
if (i + 1 < filename.size()) {
|
||||
s << filename.substr(i);
|
||||
}
|
||||
|
|
|
@ -48,7 +48,28 @@ DataType InferType(const AnyPtrList &list) {
|
|||
return DataType::kUnknown;
|
||||
}
|
||||
|
||||
enum OpType { ADD, SUB, MUL, DIV, MOD };
|
||||
template <typename T>
|
||||
bool IsAddOverflow(const T &x, const T &y, const T &max, const T &min) {
|
||||
return (y > 0 && (max - y) < x) || (y < 0 && (min - y) > x);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool IsSubOverflow(const T &x, const T &y, const T &max, const T &min) {
|
||||
return (y < 0 && (max + y) < x) || (y > 0 && (min + y) > x);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool IsMulOverflow(const T &x, const T &y, const T &max, const T &min) {
|
||||
return (x > 0 && y > 0 && (max / y) < x) || (x < 0 && y < 0 && (max / y) > x) || (x > 0 && y < 0 && (min / y) < x) ||
|
||||
(x < 0 && y > 0 && (min / y) > x);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool IsDivOverflow(const T &x, const T &y, const T &max, const T &min) {
|
||||
return (x == min && static_cast<int64_t>(y) == -1);
|
||||
}
|
||||
|
||||
enum class OpType { ADD, SUB, MUL, DIV, MOD };
|
||||
|
||||
template <typename T>
|
||||
bool IsSignedIntOverflow(T x, T y, OpType opType) {
|
||||
|
@ -56,20 +77,19 @@ bool IsSignedIntOverflow(T x, T y, OpType opType) {
|
|||
auto min = std::numeric_limits<T>::min();
|
||||
|
||||
if (opType == OpType::ADD) {
|
||||
return (y > 0 && (max - y) < x) || (y < 0 && (min - y) > x);
|
||||
return IsAddOverflow<T>(x, y, max, min);
|
||||
}
|
||||
|
||||
if (opType == OpType::SUB) {
|
||||
return (y < 0 && (max + y) < x) || (y > 0 && (min + y) > x);
|
||||
return IsSubOverflow<T>(x, y, max, min);
|
||||
}
|
||||
|
||||
if (opType == OpType::MUL) {
|
||||
return (x > 0 && y > 0 && (max / y) < x) || (x < 0 && y < 0 && (max / y) > x) ||
|
||||
(x > 0 && y < 0 && (min / y) < x) || (x < 0 && y > 0 && (min / y) > x);
|
||||
return IsMulOverflow<T>(x, y, max, min);
|
||||
}
|
||||
|
||||
if (opType == OpType::DIV || opType == OpType::MOD) {
|
||||
return x == min && static_cast<int64_t>(y) == -1;
|
||||
return IsDivOverflow<T>(x, y, max, min);
|
||||
}
|
||||
|
||||
MS_LOG(EXCEPTION) << "Unsupported operation type.";
|
||||
|
|
|
@ -199,7 +199,7 @@ AnfNodePtr DoCast(const AnfNodePtr ¶m, const TypeId &type_id, const FuncGrap
|
|||
|
||||
void DoAutoCast(const std::string &func_name, const std::vector<Signature> &signature,
|
||||
const std::vector<TypePtr> &input_types, const FuncGraphPtr &graph,
|
||||
std::vector<AnfNodePtr> *const op_inputs, const std::set<size_t> &write_indices) {
|
||||
const std::set<size_t> &write_indices, std::vector<AnfNodePtr> *const op_inputs) {
|
||||
std::vector<SignatureEnumDType> dtypes;
|
||||
(void)std::transform(signature.begin(), signature.end(), std::back_inserter(dtypes),
|
||||
[](const Signature &sig) { return sig.dtype; });
|
||||
|
@ -244,12 +244,8 @@ void DoAutoCast(const std::string &func_name, const std::vector<Signature> &sign
|
|||
}
|
||||
}
|
||||
|
||||
AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func_name, const ValuePtr &function,
|
||||
const AbstractBasePtrList &args_spec_list, const std::vector<AnfNodePtr> ¶ms_list) {
|
||||
// args: original inputs
|
||||
auto &signature = GetSignature(function);
|
||||
std::size_t sig_size = signature.size();
|
||||
auto has_var = (sig_size > 0 && signature[sig_size - 1].kind == SignatureEnumKind::kKindVarPositional);
|
||||
void CheckSigSize(const size_t &sig_size, const bool &has_var, const AbstractBasePtrList &args_spec_list,
|
||||
const std::string &func_name) {
|
||||
if (sig_size > 0) {
|
||||
if (has_var) {
|
||||
if (sig_size - 1 > args_spec_list.size()) {
|
||||
|
@ -260,6 +256,15 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func
|
|||
MS_LOG(EXCEPTION) << "Function " << func_name << "'s input length is not equal to Signature length.";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func_name, const ValuePtr &function,
|
||||
const AbstractBasePtrList &args_spec_list, const std::vector<AnfNodePtr> ¶ms_list) {
|
||||
// args: original inputs
|
||||
auto &signature = GetSignature(function);
|
||||
std::size_t sig_size = signature.size();
|
||||
auto has_var = (sig_size > 0 && signature[sig_size - 1].kind == SignatureEnumKind::kKindVarPositional);
|
||||
CheckSigSize(sig_size, has_var, args_spec_list, func_name);
|
||||
std::vector<AnfNodePtr> op_inputs;
|
||||
std::set<size_t> write_indices;
|
||||
std::vector<TypePtr> input_types;
|
||||
|
@ -308,7 +313,7 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func
|
|||
}
|
||||
// process default
|
||||
ProcessDefault(func_name, args_spec_list.size(), signature, has_var, &op_inputs);
|
||||
DoAutoCast(func_name, signature, input_types, func_graph, &op_inputs, write_indices);
|
||||
DoAutoCast(func_name, signature, input_types, func_graph, write_indices, &op_inputs);
|
||||
return func_graph->NewCNodeInOrder(op_inputs);
|
||||
}
|
||||
} // namespace
|
||||
|
|
|
@ -22,7 +22,7 @@ namespace irpass {
|
|||
#define UPPER_FLT_LIMIT (FLT_MAX / 2.0)
|
||||
#define LOWER_FLT_LIMIT (-FLT_MAX / 2.0)
|
||||
// Define the checking mode
|
||||
enum ScalarCheckingMode : int64_t { GREATER_EQUAL = 0, LESS };
|
||||
enum class ScalarCheckingMode : int64_t { GREATER_EQUAL = 0, LESS };
|
||||
|
||||
bool IsNodeScalarTrueWith(const AnfNodePtr &node, const ScalarCheckingMode &checking_mode, const float &check_value) {
|
||||
auto value_node = node->cast<ValueNodePtr>();
|
||||
|
@ -38,7 +38,7 @@ bool IsNodeScalarTrueWith(const AnfNodePtr &node, const ScalarCheckingMode &chec
|
|||
auto scalar = value->cast<ScalarPtr>();
|
||||
if (scalar != nullptr) {
|
||||
if (scalar->isa<FloatImm>()) {
|
||||
if (checking_mode == GREATER_EQUAL) {
|
||||
if (checking_mode == ScalarCheckingMode::GREATER_EQUAL) {
|
||||
return GetValue<float>(scalar) >= check_value;
|
||||
}
|
||||
return GetValue<float>(scalar) < check_value;
|
||||
|
@ -56,7 +56,7 @@ bool IsNodeScalarTrueWith(const AnfNodePtr &node, const ScalarCheckingMode &chec
|
|||
TypeId tensor_type = tensor_ptr->Dtype()->type_id();
|
||||
if ((tensor_type == TypeId::kNumberTypeFloat32) || (tensor_type == TypeId::kNumberTypeFloat)) {
|
||||
float *data = reinterpret_cast<float *>(tensor_ptr->data_c());
|
||||
if (checking_mode == GREATER_EQUAL) {
|
||||
if (checking_mode == ScalarCheckingMode::GREATER_EQUAL) {
|
||||
return data[0] >= check_value;
|
||||
}
|
||||
return data[0] < check_value;
|
||||
|
@ -66,7 +66,9 @@ bool IsNodeScalarTrueWith(const AnfNodePtr &node, const ScalarCheckingMode &chec
|
|||
}
|
||||
|
||||
// check if a value is greater or equal 0.0
|
||||
bool IsNodeScalarPositive(const AnfNodePtr &node) { return IsNodeScalarTrueWith(node, GREATER_EQUAL, 0.0); }
|
||||
bool IsNodeScalarPositive(const AnfNodePtr &node) {
|
||||
return IsNodeScalarTrueWith(node, ScalarCheckingMode::GREATER_EQUAL, 0.0);
|
||||
}
|
||||
|
||||
bool IsCNodePositive(const AnfNodePtr &node) {
|
||||
if (IsPrimitiveCNode(node, prim::kPrimReduceSum) || IsPrimitiveCNode(node, prim::kPrimSqueeze)) {
|
||||
|
@ -87,10 +89,14 @@ bool IsCNodePositive(const AnfNodePtr &node) {
|
|||
}
|
||||
|
||||
// check if a value is greater or equal UPPER_FLT_LIMIT
|
||||
bool IsNodeScalarMaxFLT(const AnfNodePtr &node) { return IsNodeScalarTrueWith(node, GREATER_EQUAL, UPPER_FLT_LIMIT); }
|
||||
bool IsNodeScalarMaxFLT(const AnfNodePtr &node) {
|
||||
return IsNodeScalarTrueWith(node, ScalarCheckingMode::GREATER_EQUAL, UPPER_FLT_LIMIT);
|
||||
}
|
||||
|
||||
// check if a value is smaller than LOWER_FLT_LIMIT
|
||||
bool IsNodeScalarMinFLT(const AnfNodePtr &node) { return IsNodeScalarTrueWith(node, LESS, LOWER_FLT_LIMIT); }
|
||||
bool IsNodeScalarMinFLT(const AnfNodePtr &node) {
|
||||
return IsNodeScalarTrueWith(node, ScalarCheckingMode::LESS, LOWER_FLT_LIMIT);
|
||||
}
|
||||
|
||||
AnfNodePtr ValueBasedEliminate::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
|
||||
PatternNode x, y, z;
|
||||
|
|
|
@ -57,13 +57,8 @@ using CompileGraphs = compile::CompileGraphs;
|
|||
using abstract::AnalysisResult;
|
||||
using mindspore::abstract::AnalysisContextPtr;
|
||||
using mindspore::validator::Validate;
|
||||
|
||||
bool SimplifyDataStructuresPass(const ResourcePtr &res) {
|
||||
MS_EXCEPTION_IF_NULL(res->func_graph());
|
||||
|
||||
FuncGraphPtr func_graph = res->func_graph();
|
||||
bool changed = opt::SimplifyDataStructures(func_graph, res->manager());
|
||||
|
||||
namespace {
|
||||
void DoRenormalize(const bool &changed, const FuncGraphPtr &func_graph, const ResourcePtr &res) {
|
||||
abstract::AbstractBasePtrList args_spec;
|
||||
auto parameters = func_graph->parameters();
|
||||
(void)std::transform(parameters.begin(), parameters.end(), std::back_inserter(args_spec),
|
||||
|
@ -73,6 +68,15 @@ bool SimplifyDataStructuresPass(const ResourcePtr &res) {
|
|||
res->set_func_graph(new_fg);
|
||||
}
|
||||
res->set_args_spec(args_spec);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
bool SimplifyDataStructuresPass(const ResourcePtr &res) {
|
||||
MS_EXCEPTION_IF_NULL(res->func_graph());
|
||||
|
||||
FuncGraphPtr func_graph = res->func_graph();
|
||||
bool changed = opt::SimplifyDataStructures(func_graph, res->manager());
|
||||
DoRenormalize(changed, func_graph, res);
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -99,16 +103,7 @@ bool CleanAfterOptAPass(const ResourcePtr &res) {
|
|||
|
||||
FuncGraphPtr func_graph = res->func_graph();
|
||||
bool changed = opt::CleanAfterOptA(func_graph, res->manager());
|
||||
|
||||
abstract::AbstractBasePtrList args_spec;
|
||||
auto parameters = func_graph->parameters();
|
||||
(void)std::transform(parameters.begin(), parameters.end(), std::back_inserter(args_spec),
|
||||
[](const AnfNodePtr &p) -> AbstractBasePtr { return p->abstract(); });
|
||||
if (changed) {
|
||||
FuncGraphPtr new_fg = Renormalize(res, func_graph, args_spec);
|
||||
res->set_func_graph(new_fg);
|
||||
}
|
||||
res->set_args_spec(args_spec);
|
||||
DoRenormalize(changed, func_graph, res);
|
||||
return true;
|
||||
}
|
||||
|
||||
|
|
|
@ -101,7 +101,8 @@ std::unordered_map<abstract::AbstractBasePtrList, int64_t, abstract::AbstractBas
|
|||
namespace {
|
||||
std::string GetBaseNameForIR(int64_t stage_idx, const std::string &action_name) {
|
||||
std::ostringstream oss;
|
||||
oss << std::setfill('0') << std::setw(2) << stage_idx << "_" << action_name;
|
||||
int spaces = 2;
|
||||
oss << std::setfill('0') << std::setw(spaces) << stage_idx << "_" << action_name;
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
|
|
|
@ -138,8 +138,8 @@ class NoneOf(NoneOf_):
|
|||
def __init__(self, patterns=None):
|
||||
r"""
|
||||
Args:
|
||||
patterns(Union[list[:class:`mindspore.graph_utils.graph_pattern`]]: list of forbidden patterns, each element
|
||||
should be one of the exposed Pattern instance.
|
||||
patterns(Union[list[:class:`mindspore.graph_utils.graph_pattern`]]: list of forbidden patterns, each
|
||||
element should be one of the exposed Pattern instance.
|
||||
|
||||
Raises:
|
||||
TypeError: raise type error for invalid argument.
|
||||
|
|
Loading…
Reference in New Issue