update clang format rule

This commit is contained in:
zhoufeng 2020-04-21 21:21:19 +08:00
parent 31a12009dd
commit c2b3360d69
278 changed files with 4447 additions and 4441 deletions

View File

@ -94,7 +94,7 @@ PenaltyBreakString: 1000
PenaltyBreakTemplateDeclaration: 10
PenaltyExcessCharacter: 1000000
PenaltyReturnTypeOnItsOwnLine: 200
PointerAlignment: Left
PointerAlignment: Right
RawStringFormats:
- Language: Cpp
Delimiters:

View File

@ -23,7 +23,7 @@ namespace common {
const int CACHED_STR_NUM = 1 << 8;
const int CACHED_STR_MASK = CACHED_STR_NUM - 1;
std::vector<std::string> STR_HOLDER(CACHED_STR_NUM);
const char* SafeCStr(const std::string&& str) {
const char *SafeCStr(const std::string &&str) {
static std::atomic<uint32_t> index{0};
uint32_t cur_index = index++;
cur_index = cur_index & CACHED_STR_MASK;

View File

@ -21,16 +21,16 @@
#include <string>
#define DISABLE_COPY_AND_ASSIGN(ClassType) \
ClassType(const ClassType&) = delete; \
ClassType& operator=(const ClassType&) = delete;
ClassType(const ClassType &) = delete; \
ClassType &operator=(const ClassType &) = delete;
namespace mindspore {
namespace common {
inline const char* SafeCStr(const std::string& str) { return str.c_str(); }
const char* SafeCStr(const std::string&& str);
inline const char *SafeCStr(const std::string &str) { return str.c_str(); }
const char *SafeCStr(const std::string &&str);
static inline std::string GetEnv(const std::string& envvar) {
const char* value = ::getenv(envvar.c_str());
static inline std::string GetEnv(const std::string &envvar) {
const char *value = ::getenv(envvar.c_str());
if (value == nullptr) {
return std::string();

View File

@ -34,11 +34,11 @@ class DecodeOp : public TensorOp {
~DecodeOp() = default;
Status Compute(const std::shared_ptr<Tensor>& input, std::shared_ptr<Tensor>* output) override;
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
void Print(std::ostream& out) const override { out << "DecodeOp"; }
Status OutputShape(const std::vector<TensorShape>& inputs, std::vector<TensorShape>& outputs) override;
Status OutputType(const std::vector<DataType>& inputs, std::vector<DataType>& outputs) override;
void Print(std::ostream &out) const override { out << "DecodeOp"; }
Status OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) override;
Status OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) override;
private:
bool is_rgb_format_ = true;

View File

@ -37,8 +37,8 @@ DistortBoundingBoxCropOp::DistortBoundingBoxCropOp(float aspect_ratio, float int
rnd_.seed(seed_);
}
Status DistortBoundingBoxCropOp::Compute(const std::vector<std::shared_ptr<Tensor>>& input,
std::vector<std::shared_ptr<Tensor>>* output) {
Status DistortBoundingBoxCropOp::Compute(const std::vector<std::shared_ptr<Tensor>> &input,
std::vector<std::shared_ptr<Tensor>> *output) {
IO_CHECK_VECTOR(input, output);
if (input.size() != NumInput())
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Number of inputs is not 5");
@ -98,8 +98,8 @@ Status DistortBoundingBoxCropOp::Compute(const std::vector<std::shared_ptr<Tenso
return Status::OK();
}
Status DistortBoundingBoxCropOp::OutputShape(const std::vector<TensorShape>& inputs,
std::vector<TensorShape>& outputs) {
Status DistortBoundingBoxCropOp::OutputShape(const std::vector<TensorShape> &inputs,
std::vector<TensorShape> &outputs) {
RETURN_IF_NOT_OK(TensorOp::OutputShape(inputs, outputs));
outputs.clear();
TensorShape out = TensorShape{-1, -1};
@ -108,7 +108,7 @@ Status DistortBoundingBoxCropOp::OutputShape(const std::vector<TensorShape>& inp
if (!outputs.empty()) return Status::OK();
return Status(StatusCode::kUnexpectedError, "Input has a wrong shape");
}
Status DistortBoundingBoxCropOp::OutputType(const std::vector<DataType>& inputs, std::vector<DataType>& outputs) {
Status DistortBoundingBoxCropOp::OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) {
RETURN_IF_NOT_OK(TensorOp::OutputType(inputs, outputs));
outputs[0] = inputs[0];
return Status::OK();

View File

@ -45,16 +45,16 @@ class DistortBoundingBoxCropOp : public TensorOp {
~DistortBoundingBoxCropOp() override = default;
void Print(std::ostream& out) const override {
void Print(std::ostream &out) const override {
out << "DistortBoundingBoxCropOp: " << max_attempts_ << " " << intersect_ratio_;
}
Status Compute(const std::vector<std::shared_ptr<Tensor>>& input,
std::vector<std::shared_ptr<Tensor>>* output) override;
Status Compute(const std::vector<std::shared_ptr<Tensor>> &input,
std::vector<std::shared_ptr<Tensor>> *output) override;
uint32_t NumInput() override { return 5; }
Status OutputShape(const std::vector<TensorShape>& inputs, std::vector<TensorShape>& outputs) override;
Status OutputType(const std::vector<DataType>& inputs, std::vector<DataType>& outputs) override;
Status OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) override;
Status OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) override;
private:
int32_t max_attempts_;

View File

@ -41,7 +41,7 @@ RandomCropAndResizeOp::RandomCropAndResizeOp(int32_t target_height, int32_t targ
rnd_.seed(GetSeed());
}
Status RandomCropAndResizeOp::Compute(const std::shared_ptr<Tensor>& input, std::shared_ptr<Tensor>* output) {
Status RandomCropAndResizeOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
IO_CHECK(input, output);
CHECK_FAIL_RETURN_UNEXPECTED(input->shape().Size() >= 2, "The shape of input is abnormal");
@ -54,7 +54,7 @@ Status RandomCropAndResizeOp::Compute(const std::shared_ptr<Tensor>& input, std:
(void)GetCropBox(h_in, w_in, &x, &y, &crop_height, &crop_width);
return CropAndResize(input, output, x, y, crop_height, crop_width, target_height_, target_width_, interpolation_);
}
Status RandomCropAndResizeOp::OutputShape(const std::vector<TensorShape>& inputs, std::vector<TensorShape>& outputs) {
Status RandomCropAndResizeOp::OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) {
RETURN_IF_NOT_OK(TensorOp::OutputShape(inputs, outputs));
outputs.clear();
TensorShape out = TensorShape{target_height_, target_width_};
@ -63,7 +63,7 @@ Status RandomCropAndResizeOp::OutputShape(const std::vector<TensorShape>& inputs
if (!outputs.empty()) return Status::OK();
return Status(StatusCode::kUnexpectedError, "Input has a wrong shape");
}
Status RandomCropAndResizeOp::GetCropBox(int h_in, int w_in, int* x, int* y, int* crop_height, int* crop_width) {
Status RandomCropAndResizeOp::GetCropBox(int h_in, int w_in, int *x, int *y, int *crop_height, int *crop_width) {
double scale, aspect;
*crop_width = w_in;
*crop_height = h_in;

View File

@ -22,7 +22,7 @@
namespace mindspore {
constexpr char PARALLEL_STRATEGY[] = "strategy";
void DumpIR(const std::string& filename, const FuncGraphPtr& func_graph, bool dump_full_name = false);
void DumpIR(const std::string &filename, const FuncGraphPtr &func_graph, bool dump_full_name = false);
} // namespace mindspore

View File

@ -44,7 +44,7 @@ const int NUM_MAX_SEQUENCE_ELEMS = 0x00FFFFFF;
// get MindSpore Intermediate Representation Path
std::string GetMsIrPath(void) {
std::string path;
const char* path_ptr = getenv("MS_IR_PATH");
const char *path_ptr = getenv("MS_IR_PATH");
if (path_ptr != nullptr) {
path = path_ptr;
char real_path[PATH_MAX] = {0};
@ -62,13 +62,13 @@ std::string GetMsIrPath(void) {
return path;
}
std::string dump_obj(const py::object& obj, const std::string& path) {
std::string dump_obj(const py::object &obj, const std::string &path) {
py::module mod = parse::python_adapter::GetPyModule(parse::PYTHON_MOD_PARSE_MODULE);
py::object name = parse::python_adapter::CallPyModFn(mod, "dump_obj", obj, py::str(path));
return py::str(name);
}
py::object load_obj(const std::string& path) {
py::object load_obj(const std::string &path) {
py::module mod = parse::python_adapter::GetPyModule(parse::PYTHON_MOD_PARSE_MODULE);
py::object obj = parse::python_adapter::CallPyModFn(mod, "load_obj", py::str(path));
return obj;
@ -76,7 +76,7 @@ py::object load_obj(const std::string& path) {
// ============================================= MindSpore IR Exporter =============================================
std::string AnfExporter::GetNodeType(const AnfNodePtr& nd) {
std::string AnfExporter::GetNodeType(const AnfNodePtr &nd) {
abstract::ShapePtr shape = nd->Shape() == nullptr ? nullptr : dyn_cast<abstract::Shape>(nd->Shape());
TypePtr type = dyn_cast<Type>(nd->Type());
std::ostringstream oss;
@ -90,7 +90,7 @@ std::string AnfExporter::GetNodeType(const AnfNodePtr& nd) {
return oss.str();
}
std::string AnfExporter::DumpObject(const py::object& obj, const std::string& category) const {
std::string AnfExporter::DumpObject(const py::object &obj, const std::string &category) const {
std::string pkl_path = GetMsIrPath();
// if not specified env 'MS_IR_PATH', do not create any files
if (pkl_path.empty() || (getenv("MS_IR_FILE") != nullptr)) {
@ -101,7 +101,7 @@ std::string AnfExporter::DumpObject(const py::object& obj, const std::string& ca
return file_prefix + file_name;
}
int AnfExporter::GetParamIndex(const FuncGraphPtr& func_graph, const AnfNodePtr& param, bool throw_excp) {
int AnfExporter::GetParamIndex(const FuncGraphPtr &func_graph, const AnfNodePtr &param, bool throw_excp) {
if (func_graph == nullptr || param == nullptr) {
return -1;
}
@ -129,13 +129,13 @@ int AnfExporter::GetParamIndex(const FuncGraphPtr& func_graph, const AnfNodePtr&
// try to find index of parameter for SymbolicKeyInstance from all exported graphs
// NOTICE: Suppose name of all parameters in SymbolicKeyInstance are different
int AnfExporter::GetParamIndexFromExported(const AnfNodePtr& param) {
int AnfExporter::GetParamIndexFromExported(const AnfNodePtr &param) {
if (param == nullptr) {
return -1;
}
int ret = -1;
for (const auto& item : exported) {
for (const auto &item : exported) {
auto pram_iter = item.second.find(param);
if (pram_iter != item.second.end()) {
return pram_iter->second;
@ -144,12 +144,12 @@ int AnfExporter::GetParamIndexFromExported(const AnfNodePtr& param) {
return ret;
}
std::string AnfExporter::GetValueNodeText(const FuncGraphPtr& fg, const ValueNodePtr& node) {
std::string AnfExporter::GetValueNodeText(const FuncGraphPtr &fg, const ValueNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
return GetValueText(fg, node->value());
}
std::string AnfExporter::GetMultitypeFuncGraphText(const prim::MultitypeFuncGraphPtr& mt_func_graph) {
std::string AnfExporter::GetMultitypeFuncGraphText(const prim::MultitypeFuncGraphPtr &mt_func_graph) {
auto py_funcs = mt_func_graph->GetPyFunctions();
if (py_funcs.empty()) {
return "";
@ -159,7 +159,7 @@ std::string AnfExporter::GetMultitypeFuncGraphText(const prim::MultitypeFuncGrap
oss << "{";
bool is_first = true;
for (const auto& py_func : py_funcs) {
for (const auto &py_func : py_funcs) {
if (is_first) {
is_first = false;
} else {
@ -193,7 +193,7 @@ std::string AnfExporter::GetMultitypeFuncGraphText(const prim::MultitypeFuncGrap
* GradOperation
* TupleAdd
*/
std::string AnfExporter::GetMetaFuncGraphText(const MetaFuncGraphPtr& meta_func_graph) {
std::string AnfExporter::GetMetaFuncGraphText(const MetaFuncGraphPtr &meta_func_graph) {
if (meta_func_graph == nullptr) {
return "";
}
@ -244,7 +244,7 @@ std::string AnfExporter::GetMetaFuncGraphText(const MetaFuncGraphPtr& meta_func_
return oss.str();
}
std::string AnfExporter::GetPrimitiveText(const PrimitivePtr& prim) {
std::string AnfExporter::GetPrimitiveText(const PrimitivePtr &prim) {
std::ostringstream oss;
if (prim == nullptr) {
return oss.str();
@ -266,7 +266,7 @@ std::string AnfExporter::GetPrimitiveText(const PrimitivePtr& prim) {
if (prim->isa<prim::DoSignaturePrimitive>()) {
auto do_signature = dyn_cast<prim::DoSignaturePrimitive>(prim);
auto& func = do_signature->function();
auto &func = do_signature->function();
if (func->isa<Primitive>()) {
auto sig_prim = dyn_cast<Primitive>(func);
oss << sig_prim->GetAttrsText();
@ -276,7 +276,7 @@ std::string AnfExporter::GetPrimitiveText(const PrimitivePtr& prim) {
return oss.str();
}
std::string AnfExporter::GetNameSpaceText(const parse::NameSpacePtr& ns) {
std::string AnfExporter::GetNameSpaceText(const parse::NameSpacePtr &ns) {
std::ostringstream oss;
if (ns == nullptr) {
return oss.str();
@ -288,8 +288,8 @@ std::string AnfExporter::GetNameSpaceText(const parse::NameSpacePtr& ns) {
return oss.str();
}
std::string AnfExporter::GetSymbolicKeyInstanceText(const FuncGraphPtr& func_graph,
const SymbolicKeyInstancePtr& sym_inst) {
std::string AnfExporter::GetSymbolicKeyInstanceText(const FuncGraphPtr &func_graph,
const SymbolicKeyInstancePtr &sym_inst) {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(sym_inst);
AnfNodePtr sym_node = sym_inst->node();
@ -317,7 +317,7 @@ std::string AnfExporter::GetSymbolicKeyInstanceText(const FuncGraphPtr& func_gra
return oss.str();
}
std::string AnfExporter::GetSequenceText(const FuncGraphPtr& func_graph, const ValuePtr& value) {
std::string AnfExporter::GetSequenceText(const FuncGraphPtr &func_graph, const ValuePtr &value) {
std::ostringstream oss;
// output ValueList, ValueTuple
ValueSequeuePtr seq = dyn_cast<ValueSequeue>(value);
@ -338,12 +338,12 @@ std::string AnfExporter::GetSequenceText(const FuncGraphPtr& func_graph, const V
return oss.str();
}
std::string AnfExporter::GetDictText(const FuncGraphPtr& func_graph, const ValuePtr& value) {
std::string AnfExporter::GetDictText(const FuncGraphPtr &func_graph, const ValuePtr &value) {
std::ostringstream oss;
ValueDictionaryPtr dict = value->cast<ValueDictionaryPtr>();
oss << "{";
bool first_flag = true;
for (const auto& elem : dict->value()) {
for (const auto &elem : dict->value()) {
if (first_flag) {
first_flag = false;
} else {
@ -355,7 +355,7 @@ std::string AnfExporter::GetDictText(const FuncGraphPtr& func_graph, const Value
return oss.str();
}
std::string AnfExporter::GetOtherValueText(const FuncGraphPtr&, const ValuePtr& value) {
std::string AnfExporter::GetOtherValueText(const FuncGraphPtr &, const ValuePtr &value) {
std::ostringstream oss;
if (check_integrity_) {
@ -366,7 +366,7 @@ std::string AnfExporter::GetOtherValueText(const FuncGraphPtr&, const ValuePtr&
return oss.str();
}
std::string AnfExporter::GetValueText(const FuncGraphPtr& func_graph, const ValuePtr& value) {
std::string AnfExporter::GetValueText(const FuncGraphPtr &func_graph, const ValuePtr &value) {
std::ostringstream oss;
bool is_null_ptr = (func_graph == nullptr || value == nullptr);
if (is_null_ptr) {
@ -413,8 +413,8 @@ std::string AnfExporter::GetValueText(const FuncGraphPtr& func_graph, const Valu
}
// this function is used to output node in CNode's inputs
std::string AnfExporter::GetAnfNodeText(const FuncGraphPtr& func_graph, const AnfNodePtr& node,
const std::map<AnfNodePtr, int>& apply_map) {
std::string AnfExporter::GetAnfNodeText(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const std::map<AnfNodePtr, int> &apply_map) {
std::ostringstream oss;
if (func_graph == nullptr || node == nullptr) {
return oss.str();
@ -444,10 +444,10 @@ std::string AnfExporter::GetAnfNodeText(const FuncGraphPtr& func_graph, const An
return oss.str();
}
void AnfExporter::OutputParameters(std::ofstream& ofs, const std::vector<AnfNodePtr>& parameters,
OrderedMap<AnfNodePtr, int, ParamPtrHasher, ParamPtrEqual>* param_map) {
void AnfExporter::OutputParameters(std::ofstream &ofs, const std::vector<AnfNodePtr> &parameters,
OrderedMap<AnfNodePtr, int, ParamPtrHasher, ParamPtrEqual> *param_map) {
bool first_flag = true;
for (const AnfNodePtr& param : parameters) {
for (const AnfNodePtr &param : parameters) {
if (first_flag) {
first_flag = false;
ofs << " ";
@ -479,13 +479,13 @@ void AnfExporter::OutputParameters(std::ofstream& ofs, const std::vector<AnfNode
}
}
void AnfExporter::OutputStatementComment(std::ofstream& ofs, const CNodePtr& node) {
void AnfExporter::OutputStatementComment(std::ofstream &ofs, const CNodePtr &node) {
if (node == nullptr) {
return;
}
// output type of each input argument
auto& inputs = node->inputs();
auto &inputs = node->inputs();
if (inputs.size() > 1) {
ofs << " #(";
for (size_t i = 1; i < inputs.size(); ++i) {
@ -521,15 +521,15 @@ void AnfExporter::OutputStatementComment(std::ofstream& ofs, const CNodePtr& nod
ofs << " #scope: " << node->scope()->name();
}
void AnfExporter::OutputCNodes(std::ofstream& ofs, const std::vector<AnfNodePtr>& nodes,
const FuncGraphPtr& func_graph) {
void AnfExporter::OutputCNodes(std::ofstream &ofs, const std::vector<AnfNodePtr> &nodes,
const FuncGraphPtr &func_graph) {
if (func_graph == nullptr) {
return;
}
int idx = 1;
std::map<AnfNodePtr, int> apply_map;
for (const AnfNodePtr& node : nodes) {
for (const AnfNodePtr &node : nodes) {
MS_EXCEPTION_IF_NULL(node);
if (!node->isa<CNode>()) {
continue;
@ -541,7 +541,7 @@ void AnfExporter::OutputCNodes(std::ofstream& ofs, const std::vector<AnfNodePtr>
}
auto cnode = node->cast<CNodePtr>();
auto& inputs = cnode->inputs();
auto &inputs = cnode->inputs();
std::string op_text = GetAnfNodeText(func_graph, inputs[0], apply_map);
// non-return node
if (node != func_graph->get_return()) {
@ -578,7 +578,7 @@ void AnfExporter::OutputCNodes(std::ofstream& ofs, const std::vector<AnfNodePtr>
}
}
void AnfExporter::ExportOneFuncGraph(std::ofstream& ofs, const FuncGraphPtr& func_graph) {
void AnfExporter::ExportOneFuncGraph(std::ofstream &ofs, const FuncGraphPtr &func_graph) {
if (func_graph == nullptr) {
return;
}
@ -612,7 +612,7 @@ void AnfExporter::ExportOneFuncGraph(std::ofstream& ofs, const FuncGraphPtr& fun
ofs << "}\n";
}
void AnfExporter::ExportFuncGraph(const std::string& filename, const FuncGraphPtr& func_graph) {
void AnfExporter::ExportFuncGraph(const std::string &filename, const FuncGraphPtr &func_graph) {
if (func_graph == nullptr) {
return;
}
@ -637,7 +637,7 @@ void AnfExporter::ExportFuncGraph(const std::string& filename, const FuncGraphPt
ofs.close();
}
void AnfExporter::ExportFuncGraph(const std::string& filename, const std::vector<TaggedGraph>& graphs) {
void AnfExporter::ExportFuncGraph(const std::string &filename, const std::vector<TaggedGraph> &graphs) {
if (graphs.empty()) {
return;
}
@ -650,7 +650,7 @@ void AnfExporter::ExportFuncGraph(const std::string& filename, const std::vector
param_index = 1;
for (const auto& tagged_graph : graphs) {
for (const auto &tagged_graph : graphs) {
tagged_cnodes_ = tagged_graph.second;
ExportOneFuncGraph(ofs, tagged_graph.first);
tagged_cnodes_.clear();
@ -663,7 +663,7 @@ void AnfExporter::ExportFuncGraph(const std::string& filename, const std::vector
}
#ifdef ENABLE_DUMP_IR
void ExportIR(const std::string& filename, const std::string& id, const FuncGraphPtr& func_graph) {
void ExportIR(const std::string &filename, const std::string &id, const FuncGraphPtr &func_graph) {
if (func_graph == nullptr) {
return;
}
@ -675,7 +675,7 @@ void ExportIR(const std::string& filename, const std::string& id, const FuncGrap
ChangeFileMode(filename, S_IRUSR);
}
void ExportIR(const std::string& filename, const std::vector<TaggedGraph>& graphs) {
void ExportIR(const std::string &filename, const std::vector<TaggedGraph> &graphs) {
AnfExporter exporter("", false);
ChangeFileMode(filename, S_IRWXU);
exporter.ExportFuncGraph(filename, graphs);
@ -683,7 +683,7 @@ void ExportIR(const std::string& filename, const std::vector<TaggedGraph>& graph
ChangeFileMode(filename, S_IRUSR);
}
#else
void ExportIR(const std::string&, const std::string&, const FuncGraphPtr&) {
void ExportIR(const std::string &, const std::string &, const FuncGraphPtr &) {
static bool already_printed = false;
if (already_printed) {
return;
@ -693,7 +693,7 @@ void ExportIR(const std::string&, const std::string&, const FuncGraphPtr&) {
<< "please recompile source to enable it. See help of building script.";
}
void ExportIR(const std::string& filename, const std::vector<TaggedGraph>& graphs) {
void ExportIR(const std::string &filename, const std::vector<TaggedGraph> &graphs) {
static bool already_printed = false;
if (already_printed) {
return;
@ -732,7 +732,7 @@ enum Token : int {
TOK_ERROR // file read error
};
std::map<Token, const char*> token_text = {
std::map<Token, const char *> token_text = {
{TOK_INVALID, "invalid"}, // invalid token
{TOK_LPARENTHESIS, "("}, // ( left parenthesis
{TOK_RPARENTHESIS, ")"}, // ) right parenthesis
@ -761,14 +761,14 @@ std::map<Token, const char*> token_text = {
class Lexer {
public:
// filename is checked in ImportIR;
explicit Lexer(const char* filename) : fin(filename) {}
explicit Lexer(const char *filename) : fin(filename) {}
~Lexer() {
try {
if (fin.is_open()) {
fin.close();
}
} catch (const std::exception& e) {
} catch (const std::exception &e) {
MS_LOG(ERROR) << "Exception when closing file";
} catch (...) {
std::string exName(abi::__cxa_current_exception_type()->name());
@ -776,7 +776,7 @@ class Lexer {
}
}
bool IsSingleCharToken(char ch, Token* token_ptr) {
bool IsSingleCharToken(char ch, Token *token_ptr) {
// clang-format off
std::unordered_map<char, Token> char_to_token = {
{'(', TOK_LPARENTHESIS},
@ -806,7 +806,7 @@ class Lexer {
Token GetNextToken() {
#ifdef DEBUG
Token token = GetNextTokenInner();
const char* str = token_text[token];
const char *str = token_text[token];
std::string text = (str == nullptr ? GetTokenText() : str);
MS_LOG(DEBUG) << "------Parse token] " << text;
return token;
@ -1064,11 +1064,11 @@ const unsigned Lexer::BUF_SIZE;
class IrParser {
public:
explicit IrParser(const char* filename) : lexer_(filename) {}
explicit IrParser(const char *filename) : lexer_(filename) {}
~IrParser() {}
py::object LoadObject(const std::string& file_name) const {
py::object LoadObject(const std::string &file_name) const {
std::string pkl_path = GetMsIrPath();
py::object default_obj = load_obj(pkl_path + "/" + file_name);
return default_obj;
@ -1087,7 +1087,7 @@ class IrParser {
MS_LOG(INFO) << "Total graphs: " << func_graphs_.size();
}
Token ParseParent(FuncGraphPtr* const parent_ptr) {
Token ParseParent(FuncGraphPtr *const parent_ptr) {
if (lexer_.GetNextToken() != TOK_IDENTIFIER) {
return TOK_ERROR;
}
@ -1168,7 +1168,7 @@ class IrParser {
return func_graph;
}
FuncGraphPtr ParseStatements(const FuncGraphPtr& func_graph) {
FuncGraphPtr ParseStatements(const FuncGraphPtr &func_graph) {
Token tok = lexer_.SkipWhiteToken();
while (tok == TOK_VARIABLE) {
if (ParseStatement(func_graph) == nullptr) {
@ -1264,56 +1264,56 @@ class IrParser {
return func_graph;
}
void SetBasicType(TypePtr* ptr, const TypePtr& dtype) const {
void SetBasicType(TypePtr *ptr, const TypePtr &dtype) const {
if (ptr == nullptr) {
return;
}
*ptr = dtype;
}
void SetTupleType(TypePtr* ptr) {
void SetTupleType(TypePtr *ptr) {
if (ptr == nullptr) {
return;
}
*ptr = std::make_shared<Tuple>();
}
void SetTupleType(TypePtr* ptr, const TypePtrList& elems) {
void SetTupleType(TypePtr *ptr, const TypePtrList &elems) {
if (ptr == nullptr) {
return;
}
*ptr = std::make_shared<Tuple>(elems);
}
void SetArrayType(TypePtr* const ptr, const TypePtr& elem_type, const std::vector<int>&) {
void SetArrayType(TypePtr *const ptr, const TypePtr &elem_type, const std::vector<int> &) {
if (ptr == nullptr) {
return;
}
*ptr = std::make_shared<TensorType>(elem_type);
}
void SetListType(TypePtr* ptr) {
void SetListType(TypePtr *ptr) {
if (ptr == nullptr) {
return;
}
*ptr = std::make_shared<List>();
}
void SetListType(TypePtr* ptr, const TypePtrList& elems) {
void SetListType(TypePtr *ptr, const TypePtrList &elems) {
if (ptr == nullptr) {
return;
}
*ptr = std::make_shared<List>(elems);
}
void SetJTaggedType(TypePtr* ptr, const TypePtr& elem) {
void SetJTaggedType(TypePtr *ptr, const TypePtr &elem) {
if (ptr == nullptr) {
return;
}
*ptr = std::make_shared<JTagged>(elem);
}
void SetBasicType(AbstractBasePtr* ptr, const TypePtr& dtype) const {
void SetBasicType(AbstractBasePtr *ptr, const TypePtr &dtype) const {
if (ptr == nullptr) {
return;
}
@ -1321,45 +1321,45 @@ class IrParser {
}
// void SetBasicType(AbstractBasePtr *ptr, const SymbolicKeyTypePtr& dtype) {}
void SetBasicType(AbstractBasePtr* const ptr, const TypeNonePtr&) const {
void SetBasicType(AbstractBasePtr *const ptr, const TypeNonePtr &) const {
if (ptr == nullptr) {
return;
}
*ptr = std::make_shared<abstract::AbstractNone>();
}
void SetBasicType(AbstractBasePtr*, const FunctionPtr&) const {}
void SetBasicType(AbstractBasePtr*, const TensorTypePtr&) const {}
void SetBasicType(AbstractBasePtr *, const FunctionPtr &) const {}
void SetBasicType(AbstractBasePtr *, const TensorTypePtr &) const {}
void SetTupleType(AbstractBasePtr* const ptr, const AbstractBasePtrList& elems) {
void SetTupleType(AbstractBasePtr *const ptr, const AbstractBasePtrList &elems) {
if (ptr == nullptr) {
return;
}
// if one of elems is nullptr, just return
if (std::any_of(std::begin(elems), std::end(elems), [](const AbstractBasePtr& elem) { return elem == nullptr; })) {
if (std::any_of(std::begin(elems), std::end(elems), [](const AbstractBasePtr &elem) { return elem == nullptr; })) {
return;
}
*ptr = std::make_shared<abstract::AbstractTuple>(elems);
}
void SetArrayType(AbstractBasePtr* const ptr, const TypePtr& elem_type, const std::vector<int>& shape) {
void SetArrayType(AbstractBasePtr *const ptr, const TypePtr &elem_type, const std::vector<int> &shape) {
if (ptr == nullptr) {
return;
}
*ptr = std::make_shared<abstract::AbstractTensor>(elem_type, shape);
}
void SetListType(AbstractBasePtr* const ptr, const AbstractBasePtrList& elems) {
void SetListType(AbstractBasePtr *const ptr, const AbstractBasePtrList &elems) {
if (ptr == nullptr) {
return;
}
if (std::any_of(std::begin(elems), std::end(elems), [](const AbstractBasePtr& elem) { return elem == nullptr; })) {
if (std::any_of(std::begin(elems), std::end(elems), [](const AbstractBasePtr &elem) { return elem == nullptr; })) {
return;
}
*ptr = std::make_shared<abstract::AbstractList>(elems);
}
void SetJTaggedType(AbstractBasePtr* const ptr, const AbstractBasePtr& elem) {
void SetJTaggedType(AbstractBasePtr *const ptr, const AbstractBasePtr &elem) {
if (ptr == nullptr) {
return;
}
@ -1367,7 +1367,7 @@ class IrParser {
}
template <typename T>
Token ParseTypeVector(const FuncGraphPtr& func_graph, Token tok, const std::string& type, T* const ptr = nullptr) {
Token ParseTypeVector(const FuncGraphPtr &func_graph, Token tok, const std::string &type, T *const ptr = nullptr) {
if (tok != TOK_LBRACKET) {
MS_LOG(EXCEPTION) << "Illegal case, , wrong token start symbol.";
return tok;
@ -1415,7 +1415,7 @@ class IrParser {
}
template <typename T>
Token ParseTypeArray(const FuncGraphPtr& func_graph, Token tok, T* const ptr = nullptr) {
Token ParseTypeArray(const FuncGraphPtr &func_graph, Token tok, T *const ptr = nullptr) {
if (tok != TOK_LPARENTHESIS) {
if (ptr != nullptr) {
SetBasicType(ptr, std::make_shared<TensorType>());
@ -1454,7 +1454,7 @@ class IrParser {
return lexer_.GetNextToken();
}
bool IsNumberType(const std::string& type, TypeId* typeid_ptr) {
bool IsNumberType(const std::string &type, TypeId *typeid_ptr) {
// clang-format off
static std::unordered_map<std::string, TypeId> basic_types = {
{"Bool", kNumberTypeBool},
@ -1486,7 +1486,7 @@ class IrParser {
}
template <typename T>
void ParseNumberType(const std::string& type, TypeId typeId, T* const ptr = nullptr) {
void ParseNumberType(const std::string &type, TypeId typeId, T *const ptr = nullptr) {
TypePtr dtype = nullptr;
std::unordered_map<int, TypePtr> type_map = {
@ -1519,7 +1519,7 @@ class IrParser {
}
template <typename T>
Token ParseTrivalType(const std::string& type, T* const ptr = nullptr) {
Token ParseTrivalType(const std::string &type, T *const ptr = nullptr) {
if (type == "NoneType") {
SetBasicType(ptr, std::make_shared<TypeNone>());
return lexer_.GetNextToken();
@ -1541,7 +1541,7 @@ class IrParser {
}
template <typename T>
Token ParseOneType(const FuncGraphPtr& func_graph, Token tok, T* const ptr = nullptr) {
Token ParseOneType(const FuncGraphPtr &func_graph, Token tok, T *const ptr = nullptr) {
if (tok != TOK_IDENTIFIER) {
return TOK_ERROR;
}
@ -1588,11 +1588,11 @@ class IrParser {
}
}
Token ParseType(const FuncGraphPtr& func_graph, AbstractBasePtr* const abstract = nullptr) {
Token ParseType(const FuncGraphPtr &func_graph, AbstractBasePtr *const abstract = nullptr) {
return ParseOneType(func_graph, lexer_.GetNextToken(), abstract);
}
Token ParseAttributes(const FuncGraphPtr& func_graph, const PrimitivePtr& prim) {
Token ParseAttributes(const FuncGraphPtr &func_graph, const PrimitivePtr &prim) {
Token tok = ParseAttribute(func_graph, prim);
while (tok == TOK_COMMA) {
tok = ParseAttribute(func_graph, prim);
@ -1603,7 +1603,7 @@ class IrParser {
return lexer_.GetNextToken();
}
Token ParseAttribute(const FuncGraphPtr& func_graph, const PrimitivePtr& prim) {
Token ParseAttribute(const FuncGraphPtr &func_graph, const PrimitivePtr &prim) {
Token tok = lexer_.GetNextToken();
if (tok != TOK_IDENTIFIER) {
return TOK_ERROR;
@ -1670,7 +1670,7 @@ class IrParser {
return tok == TOK_RPARENTHESIS ? func_graph : nullptr;
}
FuncGraphPtr ParseArguments(FuncGraphPtr func_graph, std::vector<AnfNodePtr>* const inputs_ptr) {
FuncGraphPtr ParseArguments(FuncGraphPtr func_graph, std::vector<AnfNodePtr> *const inputs_ptr) {
Token tok = ParseArgument(func_graph, inputs_ptr);
while (tok == TOK_COMMA) {
tok = ParseArgument(func_graph, inputs_ptr);
@ -1681,9 +1681,9 @@ class IrParser {
return func_graph;
}
AnfNodePtr FindParameter(FuncGraphPtr func_graph, const std::string& param_name) {
AnfNodePtr FindParameter(FuncGraphPtr func_graph, const std::string &param_name) {
while (func_graph != nullptr) {
for (auto& ptr : func_graph->parameters()) {
for (auto &ptr : func_graph->parameters()) {
MS_EXCEPTION_IF_NULL(ptr);
ParameterPtr param = ptr->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(param);
@ -1701,12 +1701,12 @@ class IrParser {
return nullptr;
}
bool Match(const std::string& str, const std::string& pattern) const {
bool Match(const std::string &str, const std::string &pattern) const {
return strncmp(str.c_str(), pattern.c_str(), pattern.length()) == 0;
}
template <typename T, typename V>
Token ParseScalar(ValuePtr* const val_ptr) {
Token ParseScalar(ValuePtr *const val_ptr) {
if (lexer_.GetNextToken() != TOK_NUMBER) {
return TOK_ERROR;
}
@ -1725,7 +1725,7 @@ class IrParser {
}
template <typename VT, typename V, typename T>
Token ParseScalar(ValuePtr* const val_ptr, Token tok) {
Token ParseScalar(ValuePtr *const val_ptr, Token tok) {
if (tok != TOK_LPARENTHESIS) {
*val_ptr = std::make_shared<T>();
return tok;
@ -1735,7 +1735,7 @@ class IrParser {
}
template <typename VT, typename V, typename T, const unsigned nbits>
Token ParseScalar(ValuePtr* const val_ptr, Token tok) {
Token ParseScalar(ValuePtr *const val_ptr, Token tok) {
if (tok != TOK_LPARENTHESIS) {
*val_ptr = std::make_shared<T>(nbits);
return tok;
@ -1745,7 +1745,7 @@ class IrParser {
}
template <typename T>
T StringToScalar(const std::string& text) {
T StringToScalar(const std::string &text) {
std::stringstream ss;
T value;
ss << text;
@ -1753,7 +1753,7 @@ class IrParser {
return value;
}
Token ParseTensor(ValuePtr* const val_ptr) {
Token ParseTensor(ValuePtr *const val_ptr) {
// parse type
TypeId type;
if (lexer_.GetNextToken() != TOK_LPARENTHESIS) {
@ -1803,7 +1803,7 @@ class IrParser {
return lexer_.GetNextToken();
}
Token ParsePrimType(Token tok, PrimType* prim_type_ptr) {
Token ParsePrimType(Token tok, PrimType *prim_type_ptr) {
if (tok != TOK_LBRACE) {
return tok;
}
@ -1830,7 +1830,7 @@ class IrParser {
return lexer_.GetNextToken();
}
Token ParseMultitypeFuncGraphItem(const prim::MultitypeFuncGraphPtr& mt_func_graph, Token tok) {
Token ParseMultitypeFuncGraphItem(const prim::MultitypeFuncGraphPtr &mt_func_graph, Token tok) {
if (tok != TOK_LPARENTHESIS) {
return TOK_ERROR;
}
@ -1855,7 +1855,7 @@ class IrParser {
return lexer_.GetNextToken();
}
Token ParseMultitypeFuncGraph(const prim::MultitypeFuncGraphPtr& mt_func_graph, Token tok) {
Token ParseMultitypeFuncGraph(const prim::MultitypeFuncGraphPtr &mt_func_graph, Token tok) {
if (tok != TOK_LBRACE) {
return tok;
}
@ -1868,7 +1868,7 @@ class IrParser {
return lexer_.GetNextToken();
}
Token ParseBoolValue(const std::string& key, bool* val_ptr) {
Token ParseBoolValue(const std::string &key, bool *val_ptr) {
if (lexer_.GetNextToken() != TOK_IDENTIFIER || lexer_.GetTokenText() != key) {
return TOK_ERROR;
}
@ -1892,7 +1892,7 @@ class IrParser {
return lexer_.GetNextToken();
}
Token ParseValueGradOperation(const std::string& name, ValuePtr* const val_ptr) {
Token ParseValueGradOperation(const std::string &name, ValuePtr *const val_ptr) {
if (lexer_.GetNextToken() != TOK_LBRACE) {
return TOK_ERROR;
}
@ -1920,7 +1920,7 @@ class IrParser {
return lexer_.GetNextToken();
}
Token ParseSymbolicKeyInstance(const FuncGraphPtr& func_graph, AnfNodePtr* const node_ptr = nullptr) {
Token ParseSymbolicKeyInstance(const FuncGraphPtr &func_graph, AnfNodePtr *const node_ptr = nullptr) {
if (lexer_.GetNextToken() != TOK_LPARENTHESIS) {
return TOK_ERROR;
}
@ -1951,7 +1951,7 @@ class IrParser {
return lexer_.GetNextToken();
}
Token ParsePrimitivePy(const FuncGraphPtr& func_graph, const std::string& id, ValuePtr* const val_ptr) {
Token ParsePrimitivePy(const FuncGraphPtr &func_graph, const std::string &id, ValuePtr *const val_ptr) {
if (lexer_.GetNextToken() != TOK_AT_FILE) {
return TOK_ERROR;
}
@ -1984,7 +1984,7 @@ class IrParser {
return next;
}
Token ParseValueGraphAndNamespace(const std::string& id, ValuePtr* val_ptr) {
Token ParseValueGraphAndNamespace(const std::string &id, ValuePtr *val_ptr) {
if (Match(id, "MultitypeFuncGraph::")) {
std::string name = id.substr(strlen("MultitypeFuncGraph::"));
auto mt_func_graph = std::make_shared<prim::MultitypeFuncGraph>(name);
@ -2024,8 +2024,8 @@ class IrParser {
}
}
Token ParseValueBasic(const FuncGraphPtr& func_graph, const std::string& id, ValuePtr* val_ptr,
AnfNodePtr* const node_ptr = nullptr) {
Token ParseValueBasic(const FuncGraphPtr &func_graph, const std::string &id, ValuePtr *val_ptr,
AnfNodePtr *const node_ptr = nullptr) {
if (id == "None") {
*val_ptr = std::make_shared<None>();
return lexer_.GetNextToken();
@ -2075,9 +2075,9 @@ class IrParser {
}
}
Token SetListOrTupleValue(const FuncGraphPtr& func_graph, Token left_tok, Token next, bool node_is_valid,
const std::vector<ValuePtr>& elems, const std::vector<AnfNodePtr>& nodes,
ValuePtr* const val_ptr, AnfNodePtr* node_ptr) {
Token SetListOrTupleValue(const FuncGraphPtr &func_graph, Token left_tok, Token next, bool node_is_valid,
const std::vector<ValuePtr> &elems, const std::vector<AnfNodePtr> &nodes,
ValuePtr *const val_ptr, AnfNodePtr *node_ptr) {
if (left_tok == TOK_LPARENTHESIS && next == TOK_RPARENTHESIS) {
if (node_is_valid && node_ptr != nullptr) {
MS_EXCEPTION_IF_NULL(func_graph);
@ -2097,8 +2097,8 @@ class IrParser {
}
}
Token ParseListOrTupleValue(const FuncGraphPtr& func_graph, Token tok, ValuePtr* const val_ptr,
AnfNodePtr* node_ptr = nullptr) {
Token ParseListOrTupleValue(const FuncGraphPtr &func_graph, Token tok, ValuePtr *const val_ptr,
AnfNodePtr *node_ptr = nullptr) {
Token left_tok = tok;
std::vector<ValuePtr> elems;
@ -2138,7 +2138,7 @@ class IrParser {
return SetListOrTupleValue(func_graph, left_tok, next, node_is_valid, elems, nodes, val_ptr, node_ptr);
}
Token ParseValue(const FuncGraphPtr& func_graph, Token tok, ValuePtr* const val_ptr, AnfNodePtr* node_ptr = nullptr) {
Token ParseValue(const FuncGraphPtr &func_graph, Token tok, ValuePtr *const val_ptr, AnfNodePtr *node_ptr = nullptr) {
// tuple or list
if (tok == TOK_LPARENTHESIS || tok == TOK_LBRACKET) {
return ParseListOrTupleValue(func_graph, tok, val_ptr, node_ptr);
@ -2152,7 +2152,7 @@ class IrParser {
return TOK_ERROR;
}
Token ParseItem(const FuncGraphPtr& func_graph, AnfNodePtr* node_ptr, ValuePtr* const val_ptr,
Token ParseItem(const FuncGraphPtr &func_graph, AnfNodePtr *node_ptr, ValuePtr *const val_ptr,
Token tok = TOK_INVALID) {
if (tok == TOK_INVALID) {
tok = lexer_.GetNextToken();
@ -2193,7 +2193,7 @@ class IrParser {
return lexer_.GetNextToken();
}
Token ParseArgument(const FuncGraphPtr& func_graph, std::vector<AnfNodePtr>* const inputs_ptr) {
Token ParseArgument(const FuncGraphPtr &func_graph, std::vector<AnfNodePtr> *const inputs_ptr) {
Token tok = lexer_.GetNextToken();
if (tok == TOK_RPARENTHESIS) {
return tok;
@ -2208,7 +2208,7 @@ class IrParser {
return tok;
}
const std::vector<FuncGraphPtr>& GetFuncGraphs() const { return func_graphs_; }
const std::vector<FuncGraphPtr> &GetFuncGraphs() const { return func_graphs_; }
private:
Lexer lexer_;
@ -2226,14 +2226,14 @@ class IrParser {
std::map<std::string, ParameterPtr> param_nodes_; // map parameter name to parameter
};
std::vector<FuncGraphPtr> ImportIR(const std::string& filename) {
std::vector<FuncGraphPtr> ImportIR(const std::string &filename) {
IrParser parser(filename.c_str());
parser.ParseFile();
return parser.GetFuncGraphs();
}
#ifdef ENABLE_DUMP_IR
void DumpIRProto(const FuncGraphPtr& func_graph, const std::string& suffix) {
void DumpIRProto(const FuncGraphPtr &func_graph, const std::string &suffix) {
if (func_graph == nullptr) {
MS_LOG(ERROR) << "Func graph is nullptr";
return;
@ -2253,7 +2253,7 @@ void DumpIRProto(const FuncGraphPtr& func_graph, const std::string& suffix) {
return;
}
char real_path[PATH_MAX] = {0};
char* real_path_ret = nullptr;
char *real_path_ret = nullptr;
#if defined(_WIN32) || defined(_WIN64)
real_path_ret = _fullpath(real_path, file_path.c_str(), PATH_MAX);
#else
@ -2281,7 +2281,7 @@ void DumpIRProto(const FuncGraphPtr& func_graph, const std::string& suffix) {
ChangeFileMode(file_path, S_IRUSR);
}
#else
void DumpIRProto(const FuncGraphPtr&, const std::string&) {
void DumpIRProto(const FuncGraphPtr &, const std::string &) {
static bool already_printed = false;
if (already_printed) {
return;

View File

@ -39,7 +39,7 @@
namespace mindspore {
struct ParamPtrEqual {
bool operator()(AnfNodePtr const& t1, AnfNodePtr const& t2) const {
bool operator()(AnfNodePtr const &t1, AnfNodePtr const &t2) const {
const ParameterPtr param1 = dyn_cast<Parameter>(t1);
const ParameterPtr param2 = dyn_cast<Parameter>(t2);
@ -52,7 +52,7 @@ struct ParamPtrEqual {
};
struct ParamPtrHasher {
std::size_t operator()(AnfNodePtr const& param) const {
std::size_t operator()(AnfNodePtr const &param) const {
const ParameterPtr parameter = dyn_cast<Parameter>(param);
if (parameter == nullptr) {
return 0;
@ -64,39 +64,39 @@ struct ParamPtrHasher {
class AnfExporter {
public:
explicit AnfExporter(const std::string& id, bool export_used = true, bool check_integrity = false)
explicit AnfExporter(const std::string &id, bool export_used = true, bool check_integrity = false)
: param_index(-1), id_(id), export_used_(export_used), check_integrity_(check_integrity) {
func_graph_set.clear();
exported.clear();
}
virtual ~AnfExporter() {}
void ExportFuncGraph(const std::string& filename, const FuncGraphPtr& func_graph);
void ExportFuncGraph(const std::string& filename, const std::vector<TaggedGraph>& graphs);
void ExportFuncGraph(const std::string &filename, const FuncGraphPtr &func_graph);
void ExportFuncGraph(const std::string &filename, const std::vector<TaggedGraph> &graphs);
protected:
virtual std::string GetNodeType(const AnfNodePtr& nd);
int GetParamIndex(const FuncGraphPtr& func_graph, const AnfNodePtr& param, bool throw_excp = true);
int GetParamIndexFromExported(const AnfNodePtr& param);
std::string DumpObject(const py::object& obj, const std::string& category) const;
std::string GetValueNodeText(const FuncGraphPtr& func_graph, const ValueNodePtr& node);
std::string GetMultitypeFuncGraphText(const prim::MultitypeFuncGraphPtr& mt_func_graph);
std::string GetSymbolicKeyInstanceText(const FuncGraphPtr& func_graph, const SymbolicKeyInstancePtr& sym_inst);
std::string GetSequenceText(const FuncGraphPtr& func_graph, const ValuePtr& value);
std::string GetValueText(const FuncGraphPtr& func_graph, const ValuePtr& value);
std::string GetOtherValueText(const FuncGraphPtr& func_graph, const ValuePtr& value);
std::string GetPrimitiveText(const PrimitivePtr& prim);
std::string GetDictText(const FuncGraphPtr& func_graph, const ValuePtr& value);
std::string GetNameSpaceText(const parse::NameSpacePtr& ns);
std::string GetMetaFuncGraphText(const MetaFuncGraphPtr& meta_func_graph);
std::string GetAnfNodeText(const FuncGraphPtr& func_graph, const AnfNodePtr& node,
const std::map<AnfNodePtr, int>& apply_map);
void ExportOneFuncGraph(std::ofstream& ofs, const FuncGraphPtr& func_graph);
void OutputParameters(std::ofstream& ofs, const std::vector<AnfNodePtr>& parameters,
OrderedMap<AnfNodePtr, int, ParamPtrHasher, ParamPtrEqual>* param_map);
virtual std::string GetNodeType(const AnfNodePtr &nd);
int GetParamIndex(const FuncGraphPtr &func_graph, const AnfNodePtr &param, bool throw_excp = true);
int GetParamIndexFromExported(const AnfNodePtr &param);
std::string DumpObject(const py::object &obj, const std::string &category) const;
std::string GetValueNodeText(const FuncGraphPtr &func_graph, const ValueNodePtr &node);
std::string GetMultitypeFuncGraphText(const prim::MultitypeFuncGraphPtr &mt_func_graph);
std::string GetSymbolicKeyInstanceText(const FuncGraphPtr &func_graph, const SymbolicKeyInstancePtr &sym_inst);
std::string GetSequenceText(const FuncGraphPtr &func_graph, const ValuePtr &value);
std::string GetValueText(const FuncGraphPtr &func_graph, const ValuePtr &value);
std::string GetOtherValueText(const FuncGraphPtr &func_graph, const ValuePtr &value);
std::string GetPrimitiveText(const PrimitivePtr &prim);
std::string GetDictText(const FuncGraphPtr &func_graph, const ValuePtr &value);
std::string GetNameSpaceText(const parse::NameSpacePtr &ns);
std::string GetMetaFuncGraphText(const MetaFuncGraphPtr &meta_func_graph);
std::string GetAnfNodeText(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const std::map<AnfNodePtr, int> &apply_map);
void ExportOneFuncGraph(std::ofstream &ofs, const FuncGraphPtr &func_graph);
void OutputParameters(std::ofstream &ofs, const std::vector<AnfNodePtr> &parameters,
OrderedMap<AnfNodePtr, int, ParamPtrHasher, ParamPtrEqual> *param_map);
void OutputStatementComment(std::ofstream& ofs, const CNodePtr& node);
void OutputCNodes(std::ofstream& ofs, const std::vector<AnfNodePtr>& nodes, const FuncGraphPtr& func_graph);
void OutputStatementComment(std::ofstream &ofs, const CNodePtr &node);
void OutputCNodes(std::ofstream &ofs, const std::vector<AnfNodePtr> &nodes, const FuncGraphPtr &func_graph);
int param_index;
OrderedSet<FuncGraphPtr> func_graph_set{};
@ -108,16 +108,16 @@ class AnfExporter {
abstract::AnfNodeConfigPtr node_cfg_ = nullptr;
};
void ExportIR(const std::string& filename, const std::string& id, const FuncGraphPtr& func_graph);
void ExportIR(const std::string& filename, const std::vector<TaggedGraph>& graphs);
void ExportIR(const std::string &filename, const std::string &id, const FuncGraphPtr &func_graph);
void ExportIR(const std::string &filename, const std::vector<TaggedGraph> &graphs);
std::vector<FuncGraphPtr> ImportIR(const std::string& filename);
std::vector<FuncGraphPtr> ImportIR(const std::string &filename);
std::string GetFuncGraphProtoString(const FuncGraphPtr& func_graph);
std::string GetFuncGraphProtoString(const FuncGraphPtr &func_graph);
void DumpIRProto(const FuncGraphPtr& func_graph, const std::string& suffix);
void DumpIRProto(const FuncGraphPtr &func_graph, const std::string &suffix);
std::string GetOnnxProtoString(const FuncGraphPtr& func_graph);
std::string GetOnnxProtoString(const FuncGraphPtr &func_graph);
} // namespace mindspore
#endif // MINDSPORE_CCSRC_DEBUG_ANF_IR_UTILS_H_

View File

@ -34,7 +34,7 @@ namespace draw {
namespace {
// Only for ValueNode
std::string ValueType(const ValueNodePtr& node) {
std::string ValueType(const ValueNodePtr &node) {
if (node == nullptr) {
return "";
}
@ -43,7 +43,7 @@ std::string ValueType(const ValueNodePtr& node) {
return v->type_name();
}
std::string ReplaceSpecialChar(const std::string& str) {
std::string ReplaceSpecialChar(const std::string &str) {
std::ostringstream oss;
for (size_t i = 0; i < str.size(); i++) {
if (str[i] == '<') {
@ -59,12 +59,12 @@ std::string ReplaceSpecialChar(const std::string& str) {
} // namespace
// API of debug utils
void DrawNodes(const std::vector<AnfNodePtr>& nodes, OrderedMap<FuncGraphPtr, std::shared_ptr<BaseDigraph>>* sub_graphs,
void DrawNodes(const std::vector<AnfNodePtr> &nodes, OrderedMap<FuncGraphPtr, std::shared_ptr<BaseDigraph>> *sub_graphs,
bool is_user) {
if (sub_graphs == nullptr) {
return;
}
for (auto& nd : nodes) {
for (auto &nd : nodes) {
MS_EXCEPTION_IF_NULL(nd);
auto sub_graph = nd->func_graph();
if (sub_graph != nullptr) {
@ -84,16 +84,16 @@ void DrawNodes(const std::vector<AnfNodePtr>& nodes, OrderedMap<FuncGraphPtr, st
}
}
void DrawValueNodes(const std::vector<AnfNodePtr>& nodes,
OrderedMap<FuncGraphPtr, std::shared_ptr<BaseDigraph>>* sub_graphs) {
void DrawValueNodes(const std::vector<AnfNodePtr> &nodes,
OrderedMap<FuncGraphPtr, std::shared_ptr<BaseDigraph>> *sub_graphs) {
if (sub_graphs == nullptr) {
return;
}
int dup_idx = 0;
for (auto& nd : nodes) {
for (auto& t : SuccIncoming(nd)) {
for (auto &nd : nodes) {
for (auto &t : SuccIncoming(nd)) {
MS_EXCEPTION_IF_NULL(t);
MS_EXCEPTION_IF_NULL(nd);
if (t->isa<ValueNode>() && (*sub_graphs).find(nd->func_graph()) != (*sub_graphs).end()) {
@ -107,7 +107,7 @@ void DrawValueNodes(const std::vector<AnfNodePtr>& nodes,
}
}
void DrawEdges(const std::vector<AnfNodePtr>& nodes, const std::shared_ptr<BaseDigraph>& digraph, bool is_user) {
void DrawEdges(const std::vector<AnfNodePtr> &nodes, const std::shared_ptr<BaseDigraph> &digraph, bool is_user) {
if (digraph == nullptr) {
return;
}
@ -120,11 +120,11 @@ void DrawEdges(const std::vector<AnfNodePtr>& nodes, const std::shared_ptr<BaseD
}
// Draw edge
for (auto& nd : nodes) {
for (auto &nd : nodes) {
auto succs = SuccIncoming(nd);
auto num = succs.size();
for (size_t i = 0; i < num; i++) {
auto& t = succs.at(i);
auto &t = succs.at(i);
MS_EXCEPTION_IF_NULL(t);
if (t->isa<ValueNode>() || t->isa<Parameter>()) {
if ((!is_user) || (i != 0)) {
@ -143,7 +143,7 @@ void DrawEdges(const std::vector<AnfNodePtr>& nodes, const std::shared_ptr<BaseD
}
}
void DrawByOpt(std::string filename, const FuncGraphPtr& func_graph, bool is_user) {
void DrawByOpt(std::string filename, const FuncGraphPtr &func_graph, bool is_user) {
if (func_graph == nullptr) {
return;
}
@ -169,7 +169,7 @@ void DrawByOpt(std::string filename, const FuncGraphPtr& func_graph, bool is_use
DrawValueNodes(nodes, &sub_graphs);
// Draw subgraph
for (const auto& gsub : sub_graphs) {
for (const auto &gsub : sub_graphs) {
digraph->SubGraph(gsub.first, gsub.second);
}
@ -182,18 +182,18 @@ void DrawByOpt(std::string filename, const FuncGraphPtr& func_graph, bool is_use
}
#ifdef ENABLE_DUMP_IR
void Draw(const std::string& filename, const FuncGraphPtr& func_graph) {
void Draw(const std::string &filename, const FuncGraphPtr &func_graph) {
const std::string dot_suffix = ".dot";
std::string filename_with_suffix =
(filename.rfind(dot_suffix) != (filename.size() - dot_suffix.size())) ? (filename + dot_suffix) : filename;
DrawByOpt(filename_with_suffix, func_graph, false);
}
void DrawUserFuncGraph(const std::string& filename, const FuncGraphPtr& func_graph) {
void DrawUserFuncGraph(const std::string &filename, const FuncGraphPtr &func_graph) {
DrawByOpt(filename, func_graph, true);
}
#else
void Draw(const std::string&, const FuncGraphPtr&) {
void Draw(const std::string &, const FuncGraphPtr &) {
static bool already_printed = false;
if (already_printed) {
return;
@ -203,7 +203,7 @@ void Draw(const std::string&, const FuncGraphPtr&) {
<< "please recompile source to enable it. See help of building script.";
}
void DrawUserFuncGraph(const std::string&, const FuncGraphPtr&) {
void DrawUserFuncGraph(const std::string &, const FuncGraphPtr &) {
static bool already_printed = false;
if (already_printed) {
return;
@ -234,7 +234,7 @@ std::string Graphviz::Shape(AnfNodePtr node) {
return "plaintext";
}
std::string Graphviz::Color(const AnfNodePtr& node) {
std::string Graphviz::Color(const AnfNodePtr &node) {
if (node == nullptr) {
return "";
}
@ -259,7 +259,7 @@ void BaseDigraph::Start() {
buffer_ << "compound=true" << std::endl;
}
void BaseDigraph::Head(const AnfNodePtr& node, int id) {
void BaseDigraph::Head(const AnfNodePtr &node, int id) {
if (node == nullptr) {
return;
}
@ -270,7 +270,7 @@ void BaseDigraph::Head(const AnfNodePtr& node, int id) {
}
}
void BaseDigraph::Tail(const AnfNodePtr& node, int idx, int id) {
void BaseDigraph::Tail(const AnfNodePtr &node, int idx, int id) {
if (node == nullptr) {
return;
}
@ -279,7 +279,7 @@ void BaseDigraph::Tail(const AnfNodePtr& node, int idx, int id) {
buffer_ << ":" << idx;
}
void BaseDigraph::Tail(const FuncGraphPtr& func_graph) {
void BaseDigraph::Tail(const FuncGraphPtr &func_graph) {
if (func_graph == nullptr) {
return;
}
@ -304,12 +304,12 @@ void BaseDigraph::End() {
}
}
void BaseDigraph::FuncGraphParameters(const FuncGraphPtr& key) {
void BaseDigraph::FuncGraphParameters(const FuncGraphPtr &key) {
buffer_ << "parameters_" << key << "[shape=plaintext ";
buffer_ << "label=<<table bgcolor='paleturquoise' cellspacing='0' cellborder='1' border='0'>";
buffer_ << "<tr><td>parameters</td></tr>";
int count = 0;
for (auto& parameter : key->parameters()) {
for (auto &parameter : key->parameters()) {
buffer_ << "<tr><td>";
buffer_ << parameter->ToString();
auto py_p = dyn_cast<Parameter>(parameter)->default_param();
@ -331,7 +331,7 @@ void BaseDigraph::FuncGraphParameters(const FuncGraphPtr& key) {
buffer_ << "</table>>,];";
}
void BaseDigraph::SubGraph(const FuncGraphPtr& key, const std::shared_ptr<BaseDigraph>& gsub) {
void BaseDigraph::SubGraph(const FuncGraphPtr &key, const std::shared_ptr<BaseDigraph> &gsub) {
if (key == nullptr || gsub == nullptr) {
return;
}
@ -361,12 +361,12 @@ Digraph::~Digraph() {
if (fout_.is_open()) {
fout_.close();
}
} catch (const std::exception& e) {
} catch (const std::exception &e) {
MS_LOG(ERROR) << "Exception when closing file " << filename_;
}
}
static std::string ReplaceAll(std::string str, const std::string& from, const std::string& to) {
static std::string ReplaceAll(std::string str, const std::string &from, const std::string &to) {
size_t start_pos = 0;
while ((start_pos = str.find(from, start_pos)) != std::string::npos) {
(void)str.replace(start_pos, from.length(), to);
@ -375,7 +375,7 @@ static std::string ReplaceAll(std::string str, const std::string& from, const st
return str;
}
static void DrawValueNode(Graphviz* const graph_obj, const ValueNodePtr& node) {
static void DrawValueNode(Graphviz *const graph_obj, const ValueNodePtr &node) {
MS_EXCEPTION_IF_NULL(graph_obj);
graph_obj->buffer() << "label=<<table port='core' cellborder='0' cellspacing='2' bgcolor='" << graph_obj->Color(node)
<< "'>";
@ -410,7 +410,7 @@ static void DrawValueNode(Graphviz* const graph_obj, const ValueNodePtr& node) {
graph_obj->buffer() << "</td></tr>";
graph_obj->buffer() << "<tr><td align='left'>";
int i = 0;
for (const auto& attr : attrs) {
for (const auto &attr : attrs) {
if (i != 0) {
graph_obj->buffer() << "<br/>";
}
@ -425,7 +425,7 @@ static void DrawValueNode(Graphviz* const graph_obj, const ValueNodePtr& node) {
graph_obj->buffer() << "</table>>,";
}
static void DrawParallelInfo(Graphviz* const graph_obj, const CNodePtr& node) {
static void DrawParallelInfo(Graphviz *const graph_obj, const CNodePtr &node) {
if (graph_obj == nullptr || node == nullptr) {
return;
}
@ -444,7 +444,7 @@ static void DrawParallelInfo(Graphviz* const graph_obj, const CNodePtr& node) {
}
}
static void DrawCNode(Graphviz* const graph_obj, const CNodePtr& node) {
static void DrawCNode(Graphviz *const graph_obj, const CNodePtr &node) {
if (graph_obj == nullptr || node == nullptr || node->size() == 0) {
return;
}
@ -484,7 +484,7 @@ static void DrawCNode(Graphviz* const graph_obj, const CNodePtr& node) {
}
graph_obj->buffer() << ">";
int i = 0;
for (auto& attr : attrs) {
for (auto &attr : attrs) {
if (i != 0) {
graph_obj->buffer() << "<br/>";
}
@ -567,7 +567,7 @@ ModelDigraph::~ModelDigraph() {
if (fout_.is_open()) {
fout_.close();
}
} catch (const std::exception& e) {
} catch (const std::exception &e) {
MS_LOG(ERROR) << "exception when closing file " << filename_;
}
}

View File

@ -31,9 +31,9 @@ namespace parse = mindspore::parse;
class Graphviz {
public:
Graphviz(const std::string& name, const std::string& filename) : name_(name), filename_(filename), fout_(filename_) {}
Graphviz(const std::string &name, const std::string &filename) : name_(name), filename_(filename), fout_(filename_) {}
explicit Graphviz(const std::string& name) : name_(name) {}
explicit Graphviz(const std::string &name) : name_(name) {}
virtual ~Graphviz() {}
@ -41,8 +41,8 @@ class Graphviz {
virtual void End() {}
virtual std::string Shape(AnfNodePtr node);
std::string Color(const AnfNodePtr& node);
std::ostringstream& buffer() { return buffer_; }
std::string Color(const AnfNodePtr &node);
std::ostringstream &buffer() { return buffer_; }
std::ostringstream buffer_;
protected:
@ -53,8 +53,8 @@ class Graphviz {
class BaseDigraph : public Graphviz {
public:
BaseDigraph(const std::string& name, const std::string& filename) : Graphviz(name, filename) {}
explicit BaseDigraph(const std::string& name) : Graphviz(name) {}
BaseDigraph(const std::string &name, const std::string &filename) : Graphviz(name, filename) {}
explicit BaseDigraph(const std::string &name) : Graphviz(name) {}
~BaseDigraph() override = default;
virtual void Node(AnfNodePtr node, int id = 0) = 0;
@ -63,21 +63,21 @@ class BaseDigraph : public Graphviz {
void Start() override;
void End() override;
virtual void Edge(AnfNodePtr start, FuncGraphPtr end, int id_start);
void FuncGraphParameters(const FuncGraphPtr& key);
void SubGraph(const FuncGraphPtr& key, const std::shared_ptr<BaseDigraph>& gsub);
void FuncGraphParameters(const FuncGraphPtr &key);
void SubGraph(const FuncGraphPtr &key, const std::shared_ptr<BaseDigraph> &gsub);
const std::string& name() const { return name_; }
const std::string &name() const { return name_; }
protected:
void Head(const AnfNodePtr& node, int id = 0);
void Tail(const AnfNodePtr& node, int idx, int id = 0);
void Tail(const FuncGraphPtr& func_graph);
void Head(const AnfNodePtr &node, int id = 0);
void Tail(const AnfNodePtr &node, int idx, int id = 0);
void Tail(const FuncGraphPtr &func_graph);
};
class Digraph : public BaseDigraph {
public:
Digraph(const std::string& name, const std::string& filename) : BaseDigraph(name, filename) {}
explicit Digraph(const std::string& name) : BaseDigraph(name) {}
Digraph(const std::string &name, const std::string &filename) : BaseDigraph(name, filename) {}
explicit Digraph(const std::string &name) : BaseDigraph(name) {}
~Digraph() override;
void Node(AnfNodePtr node, int id = 0) override;
@ -86,8 +86,8 @@ class Digraph : public BaseDigraph {
class ModelDigraph : public BaseDigraph {
public:
ModelDigraph(const std::string& name, const std::string& filename) : BaseDigraph(name, filename) {}
explicit ModelDigraph(const std::string& name) : BaseDigraph(name) {}
ModelDigraph(const std::string &name, const std::string &filename) : BaseDigraph(name, filename) {}
explicit ModelDigraph(const std::string &name) : BaseDigraph(name) {}
~ModelDigraph() override;
std::string Shape(AnfNodePtr node) override;
@ -96,8 +96,8 @@ class ModelDigraph : public BaseDigraph {
};
// API to draw
void Draw(const std::string& filename, const FuncGraphPtr& func_graph);
void DrawUserFuncGraph(const std::string& filename, const FuncGraphPtr& func_graph);
void Draw(const std::string &filename, const FuncGraphPtr &func_graph);
void DrawUserFuncGraph(const std::string &filename, const FuncGraphPtr &func_graph);
} // namespace draw
} // namespace mindspore

View File

@ -33,38 +33,38 @@ class ProtoExporter {
ProtoExporter() {}
~ProtoExporter() {}
std::string GetFuncGraphProtoString(const FuncGraphPtr& func_graph);
std::string GetFuncGraphProtoString(const FuncGraphPtr &func_graph);
private:
void InitModelInfo();
void GetOpNodeTypeAndAttrs(const FuncGraphPtr& func_graph, const AnfNodePtr& node, irpb::NodeProto* node_proto);
std::string GetOpNodeInputId(const FuncGraphPtr& func_graph, const AnfNodePtr& node,
const std::map<AnfNodePtr, size_t>& apply_map,
std::map<AnfNodePtr, size_t>* const_map_ptr);
void SetValueToProto(const ValuePtr& attr_value, irpb::ValueProto* value_proto);
void SetScalarToProto(const ScalarPtr& val, irpb::ValueProto* value_proto);
void SetSequenceToProto(const ValueSequeuePtr& val, irpb::ValueProto* value_proto);
void SetDictionaryToProto(const ValueDictionaryPtr& val, irpb::ValueProto* value_proto);
void SetNodeOutputType(const AnfNodePtr& node, irpb::TypeProto* type_proto);
void SetNodeOutputType(const TypePtr& node, const BaseShapePtr& shape, irpb::TypeProto* type_proto);
void GetOpNodeTypeAndAttrs(const FuncGraphPtr &func_graph, const AnfNodePtr &node, irpb::NodeProto *node_proto);
std::string GetOpNodeInputId(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const std::map<AnfNodePtr, size_t> &apply_map,
std::map<AnfNodePtr, size_t> *const_map_ptr);
void SetValueToProto(const ValuePtr &attr_value, irpb::ValueProto *value_proto);
void SetScalarToProto(const ScalarPtr &val, irpb::ValueProto *value_proto);
void SetSequenceToProto(const ValueSequeuePtr &val, irpb::ValueProto *value_proto);
void SetDictionaryToProto(const ValueDictionaryPtr &val, irpb::ValueProto *value_proto);
void SetNodeOutputType(const AnfNodePtr &node, irpb::TypeProto *type_proto);
void SetNodeOutputType(const TypePtr &node, const BaseShapePtr &shape, irpb::TypeProto *type_proto);
void ExportFuncGraph(const FuncGraphPtr& func_graph, irpb::GraphProto* graph_proto);
void ExportParameters(const FuncGraphPtr& func_graph, irpb::GraphProto* graph_proto);
void ExportCNodes(const FuncGraphPtr& func_graph, irpb::GraphProto* graph_proto,
std::map<AnfNodePtr, size_t>* const_map_ptr);
void ExportCNode(const FuncGraphPtr& func_graph, const CNodePtr& node, std::map<AnfNodePtr, size_t>* apply_map_ptr,
std::map<AnfNodePtr, size_t>* const_map_ptr, irpb::GraphProto* graph_proto);
void ExportFuncGraphOutput(const FuncGraphPtr& func_graph, const CNodePtr& ret_node,
const std::map<AnfNodePtr, size_t>& apply_map, std::map<AnfNodePtr, size_t>* const_map_ptr,
irpb::GraphProto* graph_proto);
void ExportValueNodes(const std::map<AnfNodePtr, size_t>& const_map, irpb::GraphProto* graph_proto);
void ExportFuncGraph(const FuncGraphPtr &func_graph, irpb::GraphProto *graph_proto);
void ExportParameters(const FuncGraphPtr &func_graph, irpb::GraphProto *graph_proto);
void ExportCNodes(const FuncGraphPtr &func_graph, irpb::GraphProto *graph_proto,
std::map<AnfNodePtr, size_t> *const_map_ptr);
void ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map<AnfNodePtr, size_t> *apply_map_ptr,
std::map<AnfNodePtr, size_t> *const_map_ptr, irpb::GraphProto *graph_proto);
void ExportFuncGraphOutput(const FuncGraphPtr &func_graph, const CNodePtr &ret_node,
const std::map<AnfNodePtr, size_t> &apply_map, std::map<AnfNodePtr, size_t> *const_map_ptr,
irpb::GraphProto *graph_proto);
void ExportValueNodes(const std::map<AnfNodePtr, size_t> &const_map, irpb::GraphProto *graph_proto);
static std::string GetConstNodeId(size_t idx) { return std::string("cst") + std::to_string(idx); }
irpb::ModelProto model_;
};
static irpb::DataType GetNumberDataType(const TypePtr& type) {
static irpb::DataType GetNumberDataType(const TypePtr &type) {
switch (type->type_id()) {
case kNumberTypeBool:
return irpb::DT_BOOL;
@ -101,7 +101,7 @@ static irpb::DataType GetNumberDataType(const TypePtr& type) {
}
}
void ProtoExporter::SetNodeOutputType(const TypePtr& type, const BaseShapePtr& shape, irpb::TypeProto* type_proto) {
void ProtoExporter::SetNodeOutputType(const TypePtr &type, const BaseShapePtr &shape, irpb::TypeProto *type_proto) {
if (type_proto == nullptr) {
return;
}
@ -116,14 +116,14 @@ void ProtoExporter::SetNodeOutputType(const TypePtr& type, const BaseShapePtr& s
type_proto->set_data_type(irpb::DT_TENSOR);
if (shape != nullptr && shape->isa<abstract::Shape>()) {
abstract::ShapePtr shape_info = dyn_cast<abstract::Shape>(shape);
for (const auto& elem : shape_info->shape()) {
for (const auto &elem : shape_info->shape()) {
type_proto->mutable_tensor_type()->mutable_shape()->add_dim()->set_size(elem);
}
}
} else if (type->isa<Tuple>()) {
TuplePtr tuple_type = dyn_cast<Tuple>(type);
type_proto->set_data_type(irpb::DT_TUPLE);
for (const auto& elem_type : tuple_type->elements()) {
for (const auto &elem_type : tuple_type->elements()) {
SetNodeOutputType(elem_type, nullptr, type_proto->mutable_sequence_type()->add_elem_types());
}
} else if (type->isa<TypeType>()) {
@ -131,7 +131,7 @@ void ProtoExporter::SetNodeOutputType(const TypePtr& type, const BaseShapePtr& s
} else if (type->isa<List>()) {
ListPtr list_type = dyn_cast<List>(type);
type_proto->set_data_type(irpb::DT_LIST);
for (const auto& elem_type : list_type->elements()) {
for (const auto &elem_type : list_type->elements()) {
SetNodeOutputType(elem_type, nullptr, type_proto->mutable_sequence_type()->add_elem_types());
}
} else if (type->isa<TypeAnything>()) {
@ -153,20 +153,20 @@ void ProtoExporter::SetNodeOutputType(const TypePtr& type, const BaseShapePtr& s
}
}
void ProtoExporter::SetNodeOutputType(const AnfNodePtr& node, irpb::TypeProto* type_proto) {
void ProtoExporter::SetNodeOutputType(const AnfNodePtr &node, irpb::TypeProto *type_proto) {
if (node == nullptr || type_proto == nullptr) {
return;
}
SetNodeOutputType(node->Type(), node->Shape(), type_proto);
}
void ProtoExporter::SetValueToProto(const ValuePtr& val, irpb::ValueProto* value_proto) {
void ProtoExporter::SetValueToProto(const ValuePtr &val, irpb::ValueProto *value_proto) {
if (val == nullptr || value_proto == nullptr) {
return;
}
if (val->isa<StringImm>()) {
const StringImmPtr& value = dyn_cast<StringImm>(val);
const StringImmPtr &value = dyn_cast<StringImm>(val);
value_proto->set_dtype(irpb::DT_STRING);
value_proto->set_str_val(value->value());
} else if (val->isa<Scalar>()) {
@ -195,15 +195,15 @@ void ProtoExporter::SetValueToProto(const ValuePtr& val, irpb::ValueProto* value
} else if (val->isa<tensor::Tensor>()) {
tensor::TensorPtr tensor_ptr = dyn_cast<tensor::Tensor>(val);
value_proto->set_dtype(irpb::DT_TENSOR);
irpb::TensorProto* tensor_proto = value_proto->mutable_tensor_val();
irpb::TensorProto *tensor_proto = value_proto->mutable_tensor_val();
tensor_proto->set_data_type(GetNumberDataType(tensor_ptr->Dtype()));
for (auto& elem : tensor_ptr->shape()) {
for (auto &elem : tensor_ptr->shape()) {
tensor_proto->add_dims(elem);
}
} else if (val->isa<TensorType>()) {
value_proto->set_dtype(irpb::DT_TYPE);
irpb::TypeProto* type_proto = value_proto->mutable_type_val();
irpb::TypeProto *type_proto = value_proto->mutable_type_val();
type_proto->set_data_type(irpb::DT_TENSOR);
TypePtr elem_type = dyn_cast<TensorType>(val)->element();
type_proto->mutable_tensor_type()->set_elem_type(GetNumberDataType(elem_type));
@ -212,53 +212,53 @@ void ProtoExporter::SetValueToProto(const ValuePtr& val, irpb::ValueProto* value
}
}
void ProtoExporter::SetScalarToProto(const ScalarPtr& val, irpb::ValueProto* value_proto) {
void ProtoExporter::SetScalarToProto(const ScalarPtr &val, irpb::ValueProto *value_proto) {
if (val == nullptr || value_proto == nullptr) {
return;
}
if (val->isa<BoolImm>()) {
const BoolImmPtr& value = dyn_cast<BoolImm>(val);
const BoolImmPtr &value = dyn_cast<BoolImm>(val);
value_proto->set_dtype(irpb::DT_BOOL);
value_proto->set_bool_val(value->value());
} else if (val->isa<Int8Imm>()) {
const Int8ImmPtr& value = dyn_cast<Int8Imm>(val);
const Int8ImmPtr &value = dyn_cast<Int8Imm>(val);
value_proto->set_dtype(irpb::DT_INT8);
value_proto->set_int_val(value->value());
} else if (val->isa<Int16Imm>()) {
const Int16ImmPtr& value = dyn_cast<Int16Imm>(val);
const Int16ImmPtr &value = dyn_cast<Int16Imm>(val);
value_proto->set_dtype(irpb::DT_INT16);
value_proto->set_int_val(value->value());
} else if (val->isa<Int32Imm>()) {
const Int32ImmPtr& value = dyn_cast<Int32Imm>(val);
const Int32ImmPtr &value = dyn_cast<Int32Imm>(val);
value_proto->set_dtype(irpb::DT_INT32);
value_proto->set_int_val(value->value());
} else if (val->isa<Int64Imm>()) {
const Int64ImmPtr& value = dyn_cast<Int64Imm>(val);
const Int64ImmPtr &value = dyn_cast<Int64Imm>(val);
value_proto->set_dtype(irpb::DT_INT64);
value_proto->set_int_val(value->value());
} else if (val->isa<UInt8Imm>()) {
const UInt8ImmPtr& value = dyn_cast<UInt8Imm>(val);
const UInt8ImmPtr &value = dyn_cast<UInt8Imm>(val);
value_proto->set_dtype(irpb::DT_UINT8);
value_proto->set_uint_val(value->value());
} else if (val->isa<UInt16Imm>()) {
const UInt16ImmPtr& value = dyn_cast<UInt16Imm>(val);
const UInt16ImmPtr &value = dyn_cast<UInt16Imm>(val);
value_proto->set_dtype(irpb::DT_UINT16);
value_proto->set_uint_val(value->value());
} else if (val->isa<UInt32Imm>()) {
const UInt32ImmPtr& value = dyn_cast<UInt32Imm>(val);
const UInt32ImmPtr &value = dyn_cast<UInt32Imm>(val);
value_proto->set_dtype(irpb::DT_UINT32);
value_proto->set_uint_val(value->value());
} else if (val->isa<UInt64Imm>()) {
const UInt64ImmPtr& value = dyn_cast<UInt64Imm>(val);
const UInt64ImmPtr &value = dyn_cast<UInt64Imm>(val);
value_proto->set_dtype(irpb::DT_UINT64);
value_proto->set_uint_val(value->value());
} else if (val->isa<FP32Imm>()) {
const FP32ImmPtr& value = dyn_cast<FP32Imm>(val);
const FP32ImmPtr &value = dyn_cast<FP32Imm>(val);
value_proto->set_dtype(irpb::DT_FLOAT32);
value_proto->set_float_val(value->value());
} else if (val->isa<FP64Imm>()) {
const FP64ImmPtr& value = dyn_cast<FP64Imm>(val);
const FP64ImmPtr &value = dyn_cast<FP64Imm>(val);
value_proto->set_dtype(irpb::DT_FLOAT64);
value_proto->set_double_val(value->value());
} else {
@ -266,40 +266,40 @@ void ProtoExporter::SetScalarToProto(const ScalarPtr& val, irpb::ValueProto* val
}
}
void ProtoExporter::SetSequenceToProto(const ValueSequeuePtr& val, irpb::ValueProto* value_proto) {
void ProtoExporter::SetSequenceToProto(const ValueSequeuePtr &val, irpb::ValueProto *value_proto) {
if (val == nullptr || value_proto == nullptr) {
return;
}
if (val->isa<ValueTuple>()) {
const ValueTuplePtr& value = dyn_cast<ValueTuple>(val);
const ValueTuplePtr &value = dyn_cast<ValueTuple>(val);
value_proto->set_dtype(irpb::DT_TUPLE);
for (const auto& item : value->value()) {
for (const auto &item : value->value()) {
SetValueToProto(item, value_proto->add_values());
}
} else if (val->isa<ValueList>()) {
const ValueListPtr& value = dyn_cast<ValueList>(val);
const ValueListPtr &value = dyn_cast<ValueList>(val);
value_proto->set_dtype(irpb::DT_LIST);
for (const auto& item : value->value()) {
for (const auto &item : value->value()) {
SetValueToProto(item, value_proto->add_values());
}
}
}
void ProtoExporter::SetDictionaryToProto(const ValueDictionaryPtr& val, irpb::ValueProto* value_proto) {
void ProtoExporter::SetDictionaryToProto(const ValueDictionaryPtr &val, irpb::ValueProto *value_proto) {
if (val == nullptr || value_proto == nullptr) {
return;
}
value_proto->set_dtype(irpb::DT_DICT);
for (const auto& item : val->value()) {
irpb::NamedValueProto* named_val = value_proto->add_dict_val();
for (const auto &item : val->value()) {
irpb::NamedValueProto *named_val = value_proto->add_dict_val();
named_val->set_key(item.first);
SetValueToProto(item.second, named_val->mutable_value());
}
}
void ProtoExporter::GetOpNodeTypeAndAttrs(const FuncGraphPtr&, const AnfNodePtr& node, irpb::NodeProto* node_proto) {
void ProtoExporter::GetOpNodeTypeAndAttrs(const FuncGraphPtr &, const AnfNodePtr &node, irpb::NodeProto *node_proto) {
if (node == nullptr || node_proto == nullptr) {
return;
}
@ -312,19 +312,19 @@ void ProtoExporter::GetOpNodeTypeAndAttrs(const FuncGraphPtr&, const AnfNodePtr&
MS_LOG(EXCEPTION) << "Op node is not primitive: " << node->ToString();
}
const PrimitivePtr& prim = GetValueNode<PrimitivePtr>(node);
const PrimitivePtr &prim = GetValueNode<PrimitivePtr>(node);
node_proto->set_op_type(prim->name());
for (const auto& attr : prim->attrs()) {
irpb::AttributeProto* attr_proto = node_proto->add_attribute();
for (const auto &attr : prim->attrs()) {
irpb::AttributeProto *attr_proto = node_proto->add_attribute();
attr_proto->set_name(attr.first);
SetValueToProto(attr.second, attr_proto->mutable_value());
}
node_proto->set_scope(node->scope()->name());
}
std::string ProtoExporter::GetOpNodeInputId(const FuncGraphPtr&, const AnfNodePtr& node,
const std::map<AnfNodePtr, size_t>& apply_map,
std::map<AnfNodePtr, size_t>* const_map_ptr) {
std::string ProtoExporter::GetOpNodeInputId(const FuncGraphPtr &, const AnfNodePtr &node,
const std::map<AnfNodePtr, size_t> &apply_map,
std::map<AnfNodePtr, size_t> *const_map_ptr) {
if (node == nullptr || const_map_ptr == nullptr) {
return "";
}
@ -354,18 +354,18 @@ std::string ProtoExporter::GetOpNodeInputId(const FuncGraphPtr&, const AnfNodePt
MS_LOG(EXCEPTION) << "Unknown node type. node is '" << node->ToString() << "'";
}
std::string ProtoExporter::GetFuncGraphProtoString(const FuncGraphPtr& func_graph) {
std::string ProtoExporter::GetFuncGraphProtoString(const FuncGraphPtr &func_graph) {
if (func_graph == nullptr) {
return "";
}
InitModelInfo();
irpb::GraphProto* graph_proto = model_.mutable_graph();
irpb::GraphProto *graph_proto = model_.mutable_graph();
ExportFuncGraph(func_graph, graph_proto);
return model_.SerializeAsString();
}
void ProtoExporter::ExportFuncGraph(const FuncGraphPtr& func_graph, irpb::GraphProto* graph_proto) {
void ProtoExporter::ExportFuncGraph(const FuncGraphPtr &func_graph, irpb::GraphProto *graph_proto) {
if (func_graph == nullptr || graph_proto == nullptr) {
return;
}
@ -383,14 +383,14 @@ void ProtoExporter::ExportFuncGraph(const FuncGraphPtr& func_graph, irpb::GraphP
ExportValueNodes(const_map, graph_proto);
}
void ProtoExporter::ExportParameters(const FuncGraphPtr& func_graph, irpb::GraphProto* graph_proto) {
void ProtoExporter::ExportParameters(const FuncGraphPtr &func_graph, irpb::GraphProto *graph_proto) {
if (func_graph == nullptr || graph_proto == nullptr) {
return;
}
std::vector<AnfNodePtr> parameters = func_graph->parameters();
for (auto& param : parameters) {
irpb::ParameterProto* param_proto = graph_proto->add_parameters();
for (auto &param : parameters) {
irpb::ParameterProto *param_proto = graph_proto->add_parameters();
param_proto->set_name(param->ToString());
SetNodeOutputType(param, param_proto->mutable_type());
@ -402,15 +402,15 @@ void ProtoExporter::ExportParameters(const FuncGraphPtr& func_graph, irpb::Graph
}
}
void ProtoExporter::ExportCNodes(const FuncGraphPtr& func_graph, irpb::GraphProto* graph_proto,
std::map<AnfNodePtr, size_t>* const_map_ptr) {
void ProtoExporter::ExportCNodes(const FuncGraphPtr &func_graph, irpb::GraphProto *graph_proto,
std::map<AnfNodePtr, size_t> *const_map_ptr) {
if (func_graph == nullptr || graph_proto == nullptr || const_map_ptr == nullptr) {
return;
}
// topo sort nodes
std::vector<AnfNodePtr> nodes = TopoSort(func_graph->get_return(), SuccIncoming, AlwaysInclude);
std::map<AnfNodePtr, size_t> apply_map;
for (const AnfNodePtr& node : nodes) {
for (const AnfNodePtr &node : nodes) {
MS_EXCEPTION_IF_NULL(node);
if (!node->isa<CNode>()) {
continue;
@ -424,9 +424,9 @@ void ProtoExporter::ExportCNodes(const FuncGraphPtr& func_graph, irpb::GraphProt
}
}
void ProtoExporter::ExportCNode(const FuncGraphPtr& func_graph, const CNodePtr& node,
std::map<AnfNodePtr, size_t>* apply_map_ptr,
std::map<AnfNodePtr, size_t>* const_map_ptr, irpb::GraphProto* graph_proto) {
void ProtoExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &node,
std::map<AnfNodePtr, size_t> *apply_map_ptr,
std::map<AnfNodePtr, size_t> *const_map_ptr, irpb::GraphProto *graph_proto) {
if (func_graph == nullptr || node == nullptr || apply_map_ptr == nullptr || const_map_ptr == nullptr ||
graph_proto == nullptr) {
return;
@ -435,12 +435,12 @@ void ProtoExporter::ExportCNode(const FuncGraphPtr& func_graph, const CNodePtr&
auto apply_idx = apply_map_ptr->size() + 1;
(*apply_map_ptr)[node] = apply_idx;
auto& inputs = node->inputs();
auto &inputs = node->inputs();
if (inputs.size() < 1) {
MS_LOG(EXCEPTION) << "Inputs of apply node is empty";
}
AnfNodePtr op = inputs[0];
irpb::NodeProto* node_proto = graph_proto->add_node();
irpb::NodeProto *node_proto = graph_proto->add_node();
// CNode/ConstGraph/Const/Parameter
if (op->isa<CNode>() || IsValueNode<FuncGraph>(op) || op->isa<Parameter>()) {
@ -452,7 +452,7 @@ void ProtoExporter::ExportCNode(const FuncGraphPtr& func_graph, const CNodePtr&
// process OP inputs
for (size_t i = 1; i < inputs.size(); ++i) {
irpb::InputProto* input_proto = node_proto->add_input();
irpb::InputProto *input_proto = node_proto->add_input();
input_proto->set_type(irpb::InputProto_EdgeType_DATA_EDGE);
std::string id = GetOpNodeInputId(func_graph, inputs[i], *apply_map_ptr, const_map_ptr);
input_proto->set_name(id);
@ -463,9 +463,9 @@ void ProtoExporter::ExportCNode(const FuncGraphPtr& func_graph, const CNodePtr&
}
}
void ProtoExporter::ExportFuncGraphOutput(const FuncGraphPtr& func_graph, const CNodePtr& ret_node,
const std::map<AnfNodePtr, size_t>& apply_map,
std::map<AnfNodePtr, size_t>* const_map_ptr, irpb::GraphProto* graph_proto) {
void ProtoExporter::ExportFuncGraphOutput(const FuncGraphPtr &func_graph, const CNodePtr &ret_node,
const std::map<AnfNodePtr, size_t> &apply_map,
std::map<AnfNodePtr, size_t> *const_map_ptr, irpb::GraphProto *graph_proto) {
if (ret_node == nullptr || !ret_node->isa<CNode>()) {
MS_LOG(EXCEPTION) << "Graph return node is illegal";
}
@ -473,7 +473,7 @@ void ProtoExporter::ExportFuncGraphOutput(const FuncGraphPtr& func_graph, const
if (graph_proto == nullptr) {
MS_LOG(EXCEPTION) << "graph_proto is nullptr";
}
irpb::OutputProto* output_proto = graph_proto->add_outputs();
irpb::OutputProto *output_proto = graph_proto->add_outputs();
if (output_proto == nullptr) {
MS_LOG(EXCEPTION) << "output_proto is nullptr";
}
@ -482,22 +482,22 @@ void ProtoExporter::ExportFuncGraphOutput(const FuncGraphPtr& func_graph, const
SetNodeOutputType(arg, output_proto->mutable_type());
}
static bool CompareValue(const std::pair<AnfNodePtr, size_t>& x, const std::pair<AnfNodePtr, size_t>& y) {
static bool CompareValue(const std::pair<AnfNodePtr, size_t> &x, const std::pair<AnfNodePtr, size_t> &y) {
return x.second < y.second;
}
void ProtoExporter::ExportValueNodes(const std::map<AnfNodePtr, size_t>& const_map, irpb::GraphProto* graph_proto) {
void ProtoExporter::ExportValueNodes(const std::map<AnfNodePtr, size_t> &const_map, irpb::GraphProto *graph_proto) {
std::vector<std::pair<AnfNodePtr, size_t>> nodes;
(void)std::transform(const_map.cbegin(), const_map.cend(), std::back_inserter(nodes),
[](const std::pair<AnfNodePtr, size_t>& item) { return item; });
[](const std::pair<AnfNodePtr, size_t> &item) { return item; });
sort(nodes.begin(), nodes.end(), CompareValue);
for (auto& item : nodes) {
for (auto &item : nodes) {
if (graph_proto == nullptr) {
MS_LOG(EXCEPTION) << "graph_proto is nullptr";
}
irpb::NamedValueProto* named_value = graph_proto->add_const_vals();
irpb::NamedValueProto *named_value = graph_proto->add_const_vals();
MS_EXCEPTION_IF_NULL(named_value);
named_value->set_key(GetConstNodeId(item.second));
SetValueToProto(GetValueNode(item.first), named_value->mutable_value());
@ -506,7 +506,7 @@ void ProtoExporter::ExportValueNodes(const std::map<AnfNodePtr, size_t>& const_m
void ProtoExporter::InitModelInfo() { model_.set_ir_version(irpb::IR_VERSION); }
std::string GetFuncGraphProtoString(const FuncGraphPtr& func_graph) {
std::string GetFuncGraphProtoString(const FuncGraphPtr &func_graph) {
ProtoExporter exporter;
return exporter.GetFuncGraphProtoString(func_graph);
}

View File

@ -36,7 +36,7 @@ Dump::Dump()
dump_iter_(0),
cur_iter_(0) {}
bool Dump::IsKernelNeedDump(const std::string& kernel_name) {
bool Dump::IsKernelNeedDump(const std::string &kernel_name) {
if (dump_mode_ == 0) {
// Dump All Kernels mode
return true;
@ -49,7 +49,7 @@ bool Dump::IsKernelNeedDump(const std::string& kernel_name) {
return false;
}
bool Dump::ParseDumpConfig(const std::string& dump_config_file) {
bool Dump::ParseDumpConfig(const std::string &dump_config_file) {
std::ifstream jsonFile(dump_config_file);
if (!jsonFile.is_open()) {
MS_LOG(ERROR) << dump_config_file << " open failed.";
@ -79,7 +79,7 @@ bool Dump::ParseDumpConfig(const std::string& dump_config_file) {
return true;
}
bool Dump::IsConfigExist(const nlohmann::json& dumpSettings) {
bool Dump::IsConfigExist(const nlohmann::json &dumpSettings) {
if (dumpSettings.find("trans_flag") == dumpSettings.end() || dumpSettings.find("enable") == dumpSettings.end() ||
dumpSettings.find("mode") == dumpSettings.end() || dumpSettings.find("path") == dumpSettings.end() ||
dumpSettings.find("net_name") == dumpSettings.end() || dumpSettings.find("iteration") == dumpSettings.end() ||
@ -91,7 +91,7 @@ bool Dump::IsConfigExist(const nlohmann::json& dumpSettings) {
return true;
}
bool Dump::IsConfigValid(const nlohmann::json& dumpSettings) {
bool Dump::IsConfigValid(const nlohmann::json &dumpSettings) {
auto trans_flag = dumpSettings.at("trans_flag");
auto enable = dumpSettings.at("enable");
auto mode = dumpSettings.at("mode");
@ -112,14 +112,14 @@ bool Dump::IsConfigValid(const nlohmann::json& dumpSettings) {
dump_path_ = path;
dump_net_name_ = net_name;
dump_iter_ = iteration;
for (const auto& kernel : kernels) {
for (const auto &kernel : kernels) {
dump_kernels_.push_back(kernel);
}
return true;
}
bool Dump::SetDumpConfFromJsonFile() {
const char* config_path_str = std::getenv("MINDSPORE_CONFIG_PATH");
const char *config_path_str = std::getenv("MINDSPORE_CONFIG_PATH");
if (config_path_str != nullptr) {
MS_LOG(INFO) << "Getenv MINDSPORE_CONFIG_PATH :" << config_path_str;
} else {
@ -148,7 +148,7 @@ bool Dump::SetDumpConfFromJsonFile() {
return ParseDumpConfig(dump_config_file);
}
bool Dump::DumpToFile(const std::string& filename, const void* data, size_t len) {
bool Dump::DumpToFile(const std::string &filename, const void *data, size_t len) {
if (filename.empty() || data == nullptr || len == 0) {
MS_LOG(ERROR) << "Incorrect parameter.";
return false;
@ -166,12 +166,12 @@ bool Dump::DumpToFile(const std::string& filename, const void* data, size_t len)
MS_LOG(ERROR) << "Open file " << realpath << " fail.";
return false;
}
(void)fd.write(reinterpret_cast<const char*>(data), SizeToLong(len));
(void)fd.write(reinterpret_cast<const char *>(data), SizeToLong(len));
fd.close();
return true;
}
bool Dump::GetRealPath(const std::string& inpath, std::string* outpath) {
bool Dump::GetRealPath(const std::string &inpath, std::string *outpath) {
MS_EXCEPTION_IF_NULL(outpath);
auto path_split_pos = inpath.find_last_of('/');
if (path_split_pos == std::string::npos) {
@ -213,7 +213,7 @@ bool Dump::GetRealPath(const std::string& inpath, std::string* outpath) {
return true;
}
bool Dump::CreateNotExistDirs(const std::string& path) {
bool Dump::CreateNotExistDirs(const std::string &path) {
std::shared_ptr<system::FileSystem> fs = system::Env::GetFileSystem();
MS_EXCEPTION_IF_NULL(fs);
char temp_path[PATH_MAX] = {0};

View File

@ -43,11 +43,11 @@ class Dump {
uint32_t cur_iter() const { return cur_iter_; }
bool IsKernelNeedDump(const std::string& kernel_name);
bool IsKernelNeedDump(const std::string &kernel_name);
bool SetDumpConfFromJsonFile();
static bool DumpToFile(const std::string& filename, const void* data, size_t len);
static bool DumpToFile(const std::string &filename, const void *data, size_t len);
protected:
bool dump_enable_;
@ -59,14 +59,14 @@ class Dump {
uint32_t cur_iter_;
std::vector<std::string> dump_kernels_;
static bool GetRealPath(const std::string& inpath, std::string* outpath);
static bool GetRealPath(const std::string &inpath, std::string *outpath);
static bool CreateNotExistDirs(const std::string& path);
static bool CreateNotExistDirs(const std::string &path);
private:
bool ParseDumpConfig(const std::string& dump_config_file);
bool IsConfigExist(const nlohmann::json& dumpSettings);
bool IsConfigValid(const nlohmann::json& dumpSettings);
bool ParseDumpConfig(const std::string &dump_config_file);
bool IsConfigExist(const nlohmann::json &dumpSettings);
bool IsConfigValid(const nlohmann::json &dumpSettings);
};
using DumpConfPtr = std::shared_ptr<Dump>;

View File

@ -23,7 +23,7 @@
#include "pipeline/parse/python_adapter.h"
namespace mindspore {
std::string HighLightLine(const std::string& line, int col_begin, int col_end, SourceLineTip tip) {
std::string HighLightLine(const std::string &line, int col_begin, int col_end, SourceLineTip tip) {
std::string temp_line = line;
if (col_begin < col_end && col_begin != -1 && col_end <= SizeToInt(temp_line.length()) &&
tip != kSourceLineTipDiscard) {
@ -101,14 +101,14 @@ DebugInfo::DebugInfo() {
name_ = "";
}
DebugInfo::DebugInfo(const std::string& name) {
DebugInfo::DebugInfo(const std::string &name) {
InitValueFromContext();
unique_id_ = gen_unique_id();
debug_id_ = -1;
name_ = name;
}
DebugInfo::DebugInfo(const LocationPtr& loc) {
DebugInfo::DebugInfo(const LocationPtr &loc) {
InitValueFromContext();
unique_id_ = gen_unique_id();
debug_id_ = -1;
@ -126,7 +126,7 @@ int64_t DebugInfo::debug_id() {
}
int64_t DebugInfo::unique_id_through_copy() const {
TraceInfoPtr trace_info = const_cast<DebugInfo*>(this)->trace_info();
TraceInfoPtr trace_info = const_cast<DebugInfo *>(this)->trace_info();
if (trace_info != nullptr) {
if (trace_info->isa<TraceCopy>() && trace_info->debug_info() != nullptr) {
return trace_info->debug_info()->unique_id_through_copy();
@ -172,7 +172,7 @@ LocationPtr GraphDebugInfo::location() {
}
return DebugInfo::location();
}
void GraphDebugInfo::set_deco_location(const LocationPtr& deco_list_loc) { deco_loc_ = deco_list_loc; }
void GraphDebugInfo::set_deco_location(const LocationPtr &deco_list_loc) { deco_loc_ = deco_list_loc; }
TraceContextPtr TraceManager::CurrentContextInfo() {
if (!TraceManager::trace_context_stack_.empty()) {
@ -181,18 +181,18 @@ TraceContextPtr TraceManager::CurrentContextInfo() {
return nullptr;
}
void TraceManager::DebugTrace(const std::string& func_name, const LocationPtr& location) {
void TraceManager::DebugTrace(const std::string &func_name, const LocationPtr &location) {
TraceContextPtr context = std::make_shared<TraceContext>(location);
context->set_func_name(func_name);
TraceManager::trace_context_stack_.push(context);
}
void TraceManager::DebugTrace(const LocationPtr& location) {
void TraceManager::DebugTrace(const LocationPtr &location) {
TraceContextPtr context = std::make_shared<TraceContext>(location);
TraceManager::trace_context_stack_.push(context);
}
void TraceManager::DebugTrace(const TraceInfoPtr& trace_info) {
void TraceManager::DebugTrace(const TraceInfoPtr &trace_info) {
if (trace_info == nullptr) {
MS_LOG(EXCEPTION) << "DebugTrace wrong traced info is null";
}
@ -203,7 +203,7 @@ void TraceManager::DebugTrace(const TraceInfoPtr& trace_info) {
TraceManager::trace_context_stack_.push(context);
}
void TraceManager::DebugTrace(const DebugInfoPtr& debug_info, const TraceInfoPtr& trace_info) {
void TraceManager::DebugTrace(const DebugInfoPtr &debug_info, const TraceInfoPtr &trace_info) {
if (trace_info == nullptr) {
MS_LOG(EXCEPTION) << "DebugTrace wrong traced info is null";
}

View File

@ -37,9 +37,9 @@ enum SourceLineTip { kSourceLineTipDiscard = 0, kSourceLineTipNextLine = 1, kSou
// Location class record the location in source code.
class Location {
public:
Location(const std::string& file_name, int line, int column, int line_end, int column_end)
Location(const std::string &file_name, int line, int column, int line_end, int column_end)
: file_name_(file_name), line_(line), column_(column), line_end_(line_end), column_end_(column_end) {}
Location(const Location& loc)
Location(const Location &loc)
: file_name_(loc.file_name_),
line_(loc.line_),
column_(loc.column_),
@ -77,21 +77,21 @@ class TraceManager {
TraceManager() = default;
~TraceManager() = default;
static TraceContextPtr CurrentContextInfo();
static void DebugTrace(const std::string& func_name, const LocationPtr& location);
static void DebugTrace(const LocationPtr& location);
static void DebugTrace(const TraceInfoPtr& trace_info);
static void DebugTrace(const std::string &func_name, const LocationPtr &location);
static void DebugTrace(const LocationPtr &location);
static void DebugTrace(const TraceInfoPtr &trace_info);
// debug trace with a cloned trace info with debug_info
static void DebugTrace(const DebugInfoPtr& debug_info, const TraceInfoPtr& trace_info);
static void DebugTrace(const DebugInfoPtr &debug_info, const TraceInfoPtr &trace_info);
static void EndTrace();
static std::stack<TraceContextPtr> trace_context_stack_;
};
class TraceGuard {
public:
explicit TraceGuard(const std::string func_name, const LocationPtr& location) {
explicit TraceGuard(const std::string func_name, const LocationPtr &location) {
TraceManager::DebugTrace(func_name, location);
}
explicit TraceGuard(const LocationPtr& location) { TraceManager::DebugTrace(location); }
explicit TraceGuard(const LocationPtr &location) { TraceManager::DebugTrace(location); }
~TraceGuard() { TraceManager::EndTrace(); }
};
@ -106,23 +106,23 @@ class TraceContext {
public:
~TraceContext() = default;
explicit TraceContext(const LocationPtr& loc) {
explicit TraceContext(const LocationPtr &loc) {
ProcessAttributeFromContext();
location_ = loc;
}
explicit TraceContext(const std::string& func_name) {
explicit TraceContext(const std::string &func_name) {
ProcessAttributeFromContext();
func_name_ = func_name;
}
explicit TraceContext(const TraceInfoPtr& trace_info) {
explicit TraceContext(const TraceInfoPtr &trace_info) {
ProcessAttributeFromContext();
trace_info_ = trace_info;
}
void set_location(const LocationPtr& loc) { location_ = loc; }
void set_location(const LocationPtr &loc) { location_ = loc; }
LocationPtr location() { return location_; }
void set_trace_info(const TraceInfoPtr& trace_info) { trace_info_ = trace_info; }
void set_trace_info(const TraceInfoPtr &trace_info) { trace_info_ = trace_info; }
TraceInfoPtr trace_info() { return trace_info_; }
void set_func_name(const std::string& func_name) { func_name_ = func_name; }
void set_func_name(const std::string &func_name) { func_name_ = func_name; }
std::string func_name() { return func_name_; }
};
@ -130,9 +130,9 @@ class DebugInfo : public Base {
public:
DebugInfo();
explicit DebugInfo(const std::string& name);
explicit DebugInfo(const std::string &name);
explicit DebugInfo(const LocationPtr& loc);
explicit DebugInfo(const LocationPtr &loc);
virtual ~DebugInfo() = default;
MS_DECLARE_PARENT(DebugInfo, Base);
@ -141,12 +141,12 @@ class DebugInfo : public Base {
int64_t unique_id_through_copy() const;
std::string get_id() { return std::to_string(debug_id()); }
void set_trace_info(const TraceInfoPtr& trace_info) { trace_info_ = trace_info; }
void set_trace_info(const TraceInfoPtr &trace_info) { trace_info_ = trace_info; }
TraceInfoPtr trace_info() { return trace_info_; }
void set_location(const LocationPtr& loc) { location_ = loc; }
void set_location(const LocationPtr &loc) { location_ = loc; }
virtual LocationPtr location() { return location_; }
std::string name() { return name_; }
void set_name(const std::string& name) { name_ = name; }
void set_name(const std::string &name) { name_ = name; }
virtual std::string debug_name();
virtual std::string get_python_func_belonged() { return ""; }
@ -186,7 +186,7 @@ class NodeDebugInfo : public DebugInfo {
py_func_belonged_ = context_info->func_name();
}
}
explicit NodeDebugInfo(const std::string& name) : DebugInfo(name) {
explicit NodeDebugInfo(const std::string &name) : DebugInfo(name) {
if (TraceManager::CurrentContextInfo() != nullptr) {
auto context_info = TraceManager::CurrentContextInfo();
py_func_belonged_ = context_info->func_name();
@ -195,9 +195,9 @@ class NodeDebugInfo : public DebugInfo {
~NodeDebugInfo() override = default;
std::string debug_name() override;
void set_node(const std::shared_ptr<AnfNode>& node) { node_ = AnfNodeWeakPtr(node); }
void set_node(const std::shared_ptr<AnfNode> &node) { node_ = AnfNodeWeakPtr(node); }
std::shared_ptr<AnfNode> get_node() const { return node_.lock(); }
void set_py_func_belonged(const std::string& name) { py_func_belonged_ = name; }
void set_py_func_belonged(const std::string &name) { py_func_belonged_ = name; }
std::string get_python_func_belonged() override { return py_func_belonged_; }
AnfNodeWeakPtr node_;
std::string py_func_belonged_;
@ -214,7 +214,7 @@ class GraphDebugInfo : public DebugInfo {
}
}
explicit GraphDebugInfo(const std::string& name) : DebugInfo(name) {
explicit GraphDebugInfo(const std::string &name) : DebugInfo(name) {
if (TraceManager::CurrentContextInfo() != nullptr) {
auto context_info = TraceManager::CurrentContextInfo();
py_func_name_ = context_info->func_name();
@ -225,11 +225,11 @@ class GraphDebugInfo : public DebugInfo {
std::string debug_name() override;
LocationPtr location() override;
LocationPtr deco_location() { return deco_loc_; }
void set_graph(const FuncGraphPtr& func_graph) { func_graph_ = FuncGraphWeakPtr(func_graph); }
void set_graph(const FuncGraphPtr &func_graph) { func_graph_ = FuncGraphWeakPtr(func_graph); }
FuncGraphPtr get_graph() const { return func_graph_.lock(); }
void set_full_name(const std::string& name) { full_name_ = name; }
void set_full_name(const std::string &name) { full_name_ = name; }
std::string get_full_name() { return full_name_; }
void set_deco_location(const LocationPtr& deco_list_loc);
void set_deco_location(const LocationPtr &deco_list_loc);
std::string get_python_func_belonged() override { return py_func_name_; }
FuncGraphWeakPtr func_graph_;
LocationPtr deco_loc_;

View File

@ -31,7 +31,7 @@ struct NameWithTrace {
std::string name;
std::vector<std::string> trace_labels;
};
static std::string GetTraceName(const TraceInfoPtr& trace_info, TraceLabelType trace_label) {
static std::string GetTraceName(const TraceInfoPtr &trace_info, TraceLabelType trace_label) {
switch (trace_label) {
case TraceLabelType::kShortSymbol:
return trace_info->symbol();
@ -42,7 +42,7 @@ static std::string GetTraceName(const TraceInfoPtr& trace_info, TraceLabelType t
}
}
NameWithTrace RootName(const DebugInfoPtr& debug_info, TraceLabelType trace_label) {
NameWithTrace RootName(const DebugInfoPtr &debug_info, TraceLabelType trace_label) {
NameWithTrace trace_name;
// find debug info after Resolve/ExpandJ/GenMetaFuncGraph, it is a new node
auto temp_info = debug_info;
@ -66,9 +66,9 @@ NameWithTrace RootName(const DebugInfoPtr& debug_info, TraceLabelType trace_labe
return trace_name;
}
std::string CombineTraceTypes(const std::string& root_name, const std::vector<std::string>& trace_labels) {
std::string CombineTraceTypes(const std::string &root_name, const std::vector<std::string> &trace_labels) {
std::string tags = "";
for (auto& itr : trace_labels) {
for (auto &itr : trace_labels) {
std::string symbol = itr;
tags = tags + symbol;
}
@ -76,12 +76,12 @@ std::string CombineTraceTypes(const std::string& root_name, const std::vector<st
}
// get the label name of the node debug info
std::string LabelString(const DebugInfoPtr& debug_info, TraceLabelType trace_label) {
std::string LabelString(const DebugInfoPtr &debug_info, TraceLabelType trace_label) {
NameWithTrace trace_name = RootName(debug_info, trace_label);
return CombineTraceTypes(trace_name.name, trace_name.trace_labels);
}
std::string CombineUniqueID(const DebugInfoPtr& debug_info) {
std::string CombineUniqueID(const DebugInfoPtr &debug_info) {
auto temp_info = debug_info;
std::string label = "";
while (temp_info != nullptr) {
@ -103,9 +103,9 @@ std::string CombineUniqueID(const DebugInfoPtr& debug_info) {
}
// get trace with unique id chain
std::string LabelStringUnique(const DebugInfoPtr& debug_info) { return CombineUniqueID(debug_info); }
std::string LabelStringUnique(const DebugInfoPtr &debug_info) { return CombineUniqueID(debug_info); }
std::string Label(const DebugInfoPtr& debug_info, TraceLabelType trace_label) {
std::string Label(const DebugInfoPtr &debug_info, TraceLabelType trace_label) {
if (GetGlobalTraceLabelType() == TraceLabelType::kWithUniqueId) {
return LabelStringUnique(debug_info);
}

View File

@ -29,7 +29,7 @@ namespace label_manage {
enum class TraceLabelType { kShortSymbol, kFullName, kWithUniqueId };
TraceLabelType GetGlobalTraceLabelType();
void SetGlobalTraceLabelType(TraceLabelType label_type);
std::string Label(const DebugInfoPtr& debug_info, TraceLabelType trace_type = TraceLabelType::kShortSymbol);
std::string Label(const DebugInfoPtr &debug_info, TraceLabelType trace_type = TraceLabelType::kShortSymbol);
} // namespace label_manage
} // namespace mindspore

View File

@ -37,7 +37,7 @@
namespace mindspore {
// namespace to support debug trace infomation
namespace trace {
std::string GetAbstractStr(const abstract::AbstractBasePtr& abs) {
std::string GetAbstractStr(const abstract::AbstractBasePtr &abs) {
if (abs == nullptr) {
return "Null Abstract";
}
@ -69,7 +69,7 @@ std::vector<DebugInfoPtr> GetSourceCodeDebugInfoVec(DebugInfoPtr debug_info) {
return debug_with_loc_vec;
}
DebugInfoPtr GetSourceCodeDebugInfo(const DebugInfoPtr& info) {
DebugInfoPtr GetSourceCodeDebugInfo(const DebugInfoPtr &info) {
auto debug_with_loc_vec = GetSourceCodeDebugInfoVec(info);
if (debug_with_loc_vec.size() > 0) {
return debug_with_loc_vec[0];
@ -78,7 +78,7 @@ DebugInfoPtr GetSourceCodeDebugInfo(const DebugInfoPtr& info) {
}
}
std::string GetDebugInfo(const DebugInfoPtr& info, SourceLineTip tip) {
std::string GetDebugInfo(const DebugInfoPtr &info, SourceLineTip tip) {
if (info == nullptr) {
return "";
}
@ -91,7 +91,7 @@ std::string GetDebugInfo(const DebugInfoPtr& info, SourceLineTip tip) {
// a trace info identifies a node transform, so we can trace the node transform through
// a link of trace info and debug info
std::string GetInfoWithAction(const std::vector<DebugInfoPtr>& info_vec, SourceLineTip tip) {
std::string GetInfoWithAction(const std::vector<DebugInfoPtr> &info_vec, SourceLineTip tip) {
if (info_vec.size() < 1) {
return "";
}
@ -109,7 +109,7 @@ std::string GetInfoWithAction(const std::vector<DebugInfoPtr>& info_vec, SourceL
return traced_info;
}
std::string GetTracedDebugInfo(const DebugInfoPtr& info, SourceLineTip tip) {
std::string GetTracedDebugInfo(const DebugInfoPtr &info, SourceLineTip tip) {
if (info == nullptr) {
return "";
}
@ -124,7 +124,7 @@ std::string GetTracedDebugInfo(const DebugInfoPtr& info, SourceLineTip tip) {
return "";
}
std::string GetDebugInfo(const DebugInfoPtr& info, const std::string& prefix, SourceLineTip tip) {
std::string GetDebugInfo(const DebugInfoPtr &info, const std::string &prefix, SourceLineTip tip) {
std::ostringstream oss;
if (info == nullptr) {
return "";
@ -139,7 +139,7 @@ std::string GetDebugInfo(const DebugInfoPtr& info, const std::string& prefix, So
return oss.str();
}
std::string GetGraphParamString(const FuncGraphPtr& graph, abstract::AbstractBasePtrList args_spec_list) {
std::string GetGraphParamString(const FuncGraphPtr &graph, abstract::AbstractBasePtrList args_spec_list) {
std::ostringstream oss;
oss << "graph:" << graph->ToString() << " with args[";
auto params = graph->parameters();
@ -151,8 +151,8 @@ std::string GetGraphParamString(const FuncGraphPtr& graph, abstract::AbstractBas
return oss.str();
}
void DumpInferStack(std::ostringstream& oss) {
auto& infer_stack = GetCurrenGraphInferStack();
void DumpInferStack(std::ostringstream &oss) {
auto &infer_stack = GetCurrenGraphInferStack();
if (infer_stack.empty()) {
return;
}
@ -164,7 +164,7 @@ void DumpInferStack(std::ostringstream& oss) {
}
std::reverse(infer_vec.begin(), infer_vec.end());
int index = 0;
for (auto& item : infer_vec) {
for (auto &item : infer_vec) {
auto graph_infer = std::dynamic_pointer_cast<abstract::BaseFuncGraphEvaluator>(item.first);
if (graph_infer == nullptr) {
MS_LOG(WARNING) << "DumpInferStack failed, got null graph evaluator";
@ -183,7 +183,7 @@ void DumpInferStack(std::ostringstream& oss) {
}
void TraceGraphInfer() {
auto& infer_stack = GetCurrenGraphInferStack();
auto &infer_stack = GetCurrenGraphInferStack();
std::ostringstream oss;
if (infer_stack.empty()) {
return;
@ -200,15 +200,15 @@ class AnalyzedFuncGraphExporter : public AnfExporter {
AnalyzedFuncGraphExporter() : AnfExporter("", true, false) {}
~AnalyzedFuncGraphExporter() override = default;
void ExportFuncGraph(const std::string& filename, const std::vector<abstract::AnfNodeConfigPtr>& node_cfgs);
void ExportFuncGraph(const std::string &filename, const std::vector<abstract::AnfNodeConfigPtr> &node_cfgs);
private:
std::string GetNodeType(const AnfNodePtr& nd) override;
std::string GetNodeType(const AnfNodePtr &nd) override;
};
std::unordered_map<FuncGraphPtr, TaggedNodeMap> CalcTaggedFuncGraphs() {
std::unordered_map<FuncGraphPtr, TaggedNodeMap> tagged_func_graphs;
auto& list = GetCNodeDebugStack();
auto &list = GetCNodeDebugStack();
for (size_t i = 0; i < list.size(); ++i) {
auto node_cfg = list[i];
auto fg = node_cfg->context()->func_graph();
@ -223,7 +223,7 @@ void OutputAnalyzedGraphWithType() {
exporter.ExportFuncGraph("analyze_fail.dat", GetCNodeDebugStack());
}
std::string AnalyzedFuncGraphExporter::GetNodeType(const AnfNodePtr& node) {
std::string AnalyzedFuncGraphExporter::GetNodeType(const AnfNodePtr &node) {
if (node_cfg_ == nullptr) {
return AnfExporter::GetNodeType(node);
}
@ -248,8 +248,8 @@ std::string AnalyzedFuncGraphExporter::GetNodeType(const AnfNodePtr& node) {
return oss.str();
}
void AnalyzedFuncGraphExporter::ExportFuncGraph(const std::string& filename,
const std::vector<abstract::AnfNodeConfigPtr>& node_cfgs) {
void AnalyzedFuncGraphExporter::ExportFuncGraph(const std::string &filename,
const std::vector<abstract::AnfNodeConfigPtr> &node_cfgs) {
if (node_cfgs.empty()) {
MS_LOG(DEBUG) << "Node configs is empty";
return;
@ -265,7 +265,7 @@ void AnalyzedFuncGraphExporter::ExportFuncGraph(const std::string& filename,
auto tagged_func_graphs = CalcTaggedFuncGraphs();
// first output graph on the analysis stack
for (const auto& node_cfg : node_cfgs) {
for (const auto &node_cfg : node_cfgs) {
auto fg = node_cfg->context()->func_graph();
// the graph is already output, skip it
if (exported.find(fg) != exported.end()) {
@ -296,7 +296,7 @@ void AnalyzedFuncGraphExporter::ExportFuncGraph(const std::string& filename,
ofs.close();
}
void GetInferStackInfo(std::ostringstream& oss) {
void GetInferStackInfo(std::ostringstream &oss) {
MS_LOG(INFO) << "Get graph analysis information begin";
auto stack = GetCNodeDebugStack();
if (stack.empty()) {
@ -336,7 +336,7 @@ void GetInferStackInfo(std::ostringstream& oss) {
static std::stack<std::pair<abstract::EvaluatorPtr, abstract::AnfNodeConfigPtr>> graph_infer_stack;
// trace the cnode infer debug info
static std::vector<abstract::AnfNodeConfigPtr> cnode_debug_stack{};
void TraceGraphInferEnter(const abstract::EvaluatorPtr& eval, const abstract::AnfNodeConfigPtr& node) {
void TraceGraphInferEnter(const abstract::EvaluatorPtr &eval, const abstract::AnfNodeConfigPtr &node) {
if (eval == nullptr) {
MS_LOG(EXCEPTION) << "GraphInferEnter got null eval";
}
@ -345,7 +345,7 @@ void TraceGraphInferEnter(const abstract::EvaluatorPtr& eval, const abstract::An
}
}
void TraceGraphInferLeave(const abstract::EvaluatorPtr& eval) {
void TraceGraphInferLeave(const abstract::EvaluatorPtr &eval) {
if (eval == nullptr) {
MS_LOG(EXCEPTION) << "GraphInferEnter got null eval";
}
@ -354,13 +354,13 @@ void TraceGraphInferLeave(const abstract::EvaluatorPtr& eval) {
}
}
void TraceInferCNodeEnter(const abstract::AnfNodeConfigPtr& node_cfg) { cnode_debug_stack.push_back(node_cfg); }
void TraceInferCNodeEnter(const abstract::AnfNodeConfigPtr &node_cfg) { cnode_debug_stack.push_back(node_cfg); }
void TraceInferCNodeLeave() { cnode_debug_stack.pop_back(); }
std::vector<abstract::AnfNodeConfigPtr>& GetCNodeDebugStack() { return cnode_debug_stack; }
std::vector<abstract::AnfNodeConfigPtr> &GetCNodeDebugStack() { return cnode_debug_stack; }
std::stack<std::pair<abstract::EvaluatorPtr, abstract::AnfNodeConfigPtr>>& GetCurrenGraphInferStack() {
std::stack<std::pair<abstract::EvaluatorPtr, abstract::AnfNodeConfigPtr>> &GetCurrenGraphInferStack() {
return graph_infer_stack;
}
void ClearTraceStack() {

View File

@ -31,19 +31,19 @@
namespace mindspore {
namespace trace {
std::string GetDebugInfo(const DebugInfoPtr& info, SourceLineTip tip = kSourceLineTipNextLine);
std::string GetDebugInfo(const DebugInfoPtr& info, const std::string& prefix,
std::string GetDebugInfo(const DebugInfoPtr &info, SourceLineTip tip = kSourceLineTipNextLine);
std::string GetDebugInfo(const DebugInfoPtr &info, const std::string &prefix,
SourceLineTip tip = kSourceLineTipNextLine);
DebugInfoPtr GetSourceCodeDebugInfo(const DebugInfoPtr& info);
DebugInfoPtr GetSourceCodeDebugInfo(const DebugInfoPtr &info);
void TraceGraphInfer();
void GetInferStackInfo(std::ostringstream& oss);
void TraceGraphInferEnter(const abstract::EvaluatorPtr& eval, const abstract::AnfNodeConfigPtr& node);
void TraceGraphInferLeave(const abstract::EvaluatorPtr& eval);
void TraceInferCNodeEnter(const abstract::AnfNodeConfigPtr& node_cfg);
void GetInferStackInfo(std::ostringstream &oss);
void TraceGraphInferEnter(const abstract::EvaluatorPtr &eval, const abstract::AnfNodeConfigPtr &node);
void TraceGraphInferLeave(const abstract::EvaluatorPtr &eval);
void TraceInferCNodeEnter(const abstract::AnfNodeConfigPtr &node_cfg);
void TraceInferCNodeLeave();
std::vector<abstract::AnfNodeConfigPtr>& GetCNodeDebugStack();
std::stack<std::pair<abstract::EvaluatorPtr, abstract::AnfNodeConfigPtr>>& GetCurrenGraphInferStack();
std::string GetAbstractStr(const abstract::AbstractBasePtr& abs);
std::vector<abstract::AnfNodeConfigPtr> &GetCNodeDebugStack();
std::stack<std::pair<abstract::EvaluatorPtr, abstract::AnfNodeConfigPtr>> &GetCurrenGraphInferStack();
std::string GetAbstractStr(const abstract::AbstractBasePtr &abs);
void ClearTraceStack();
} // namespace trace
} // namespace mindspore

View File

@ -23,7 +23,7 @@
#include "pipeline/parse/python_adapter.h"
namespace mindspore {
std::string TraceInfo::GetActionBetweenNode(const DebugInfoPtr& info) {
std::string TraceInfo::GetActionBetweenNode(const DebugInfoPtr &info) {
if (info == nullptr) {
return "";
}

View File

@ -40,13 +40,13 @@ using DebugInfoPtr = std::shared_ptr<DebugInfo>;
// namespace to support intermediate representation definition
class TraceInfo : public Base {
public:
TraceInfo(const DebugInfoPtr& info, const std::string& full_name, const std::string& symbol) {
TraceInfo(const DebugInfoPtr &info, const std::string &full_name, const std::string &symbol) {
symbol_ = symbol;
full_name_ = full_name;
name_ = full_name_;
debug_info_ = info;
}
TraceInfo(const TraceInfo& info)
TraceInfo(const TraceInfo &info)
: Base(), debug_info_(info.debug_info_), symbol_(info.symbol_), full_name_(info.full_name_), name_(info.name_) {}
virtual ~TraceInfo() = default;
MS_DECLARE_PARENT(TraceInfo, Base);
@ -55,8 +55,8 @@ class TraceInfo : public Base {
virtual std::string full_name() { return full_name_; }
virtual TraceInfoPtr clone() { return shared_from_base<TraceInfo>(); }
virtual std::string action_name() { return ""; }
virtual std::string GetActionBetweenNode(const DebugInfoPtr& info);
void set_debug_info(const DebugInfoPtr& info) { debug_info_ = info; }
virtual std::string GetActionBetweenNode(const DebugInfoPtr &info);
void set_debug_info(const DebugInfoPtr &info) { debug_info_ = info; }
DebugInfoPtr debug_info() { return debug_info_; }
DebugInfoPtr DebugInfoHasLoc();
std::vector<std::pair<DebugInfoPtr, TraceInfoPtr>> GetSourceCodeDebugInfo();
@ -70,7 +70,7 @@ class TraceInfo : public Base {
class TracePhi : public TraceInfo {
public:
explicit TracePhi(const DebugInfoPtr& info) : TraceInfo(info, "phi", "Φ") {}
explicit TracePhi(const DebugInfoPtr &info) : TraceInfo(info, "phi", "Φ") {}
MS_DECLARE_PARENT(TracePhi, TraceInfo);
~TracePhi() override = default;
TraceInfoPtr clone() override { return std::make_shared<TracePhi>(*shared_from_base<TracePhi>()); }
@ -78,8 +78,8 @@ class TracePhi : public TraceInfo {
class TraceIfStmtTrueBranch : public TraceInfo {
public:
TraceIfStmtTrueBranch(const TraceIfStmtTrueBranch&) = default;
explicit TraceIfStmtTrueBranch(const DebugInfoPtr& info) : TraceInfo(info, "if_true", "") {}
TraceIfStmtTrueBranch(const TraceIfStmtTrueBranch &) = default;
explicit TraceIfStmtTrueBranch(const DebugInfoPtr &info) : TraceInfo(info, "if_true", "") {}
MS_DECLARE_PARENT(TraceIfStmtTrueBranch, TraceInfo);
~TraceIfStmtTrueBranch() override = default;
TraceInfoPtr clone() override {
@ -89,8 +89,8 @@ class TraceIfStmtTrueBranch : public TraceInfo {
class TraceIfStmtFalseBranch : public TraceInfo {
public:
TraceIfStmtFalseBranch(const TraceIfStmtFalseBranch&) = default;
explicit TraceIfStmtFalseBranch(const DebugInfoPtr& info) : TraceInfo(info, "if_false", "") {}
TraceIfStmtFalseBranch(const TraceIfStmtFalseBranch &) = default;
explicit TraceIfStmtFalseBranch(const DebugInfoPtr &info) : TraceInfo(info, "if_false", "") {}
MS_DECLARE_PARENT(TraceIfStmtFalseBranch, TraceInfo);
~TraceIfStmtFalseBranch() override = default;
TraceInfoPtr clone() override {
@ -100,7 +100,7 @@ class TraceIfStmtFalseBranch : public TraceInfo {
class TraceIfStmtAfterBranch : public TraceInfo {
public:
explicit TraceIfStmtAfterBranch(const DebugInfoPtr& info) : TraceInfo(info, "if_after", "") {}
explicit TraceIfStmtAfterBranch(const DebugInfoPtr &info) : TraceInfo(info, "if_after", "") {}
MS_DECLARE_PARENT(TraceIfStmtAfterBranch, TraceInfo);
~TraceIfStmtAfterBranch() override = default;
TraceInfoPtr clone() override {
@ -110,7 +110,7 @@ class TraceIfStmtAfterBranch : public TraceInfo {
class TraceIfExpTrueBranch : public TraceInfo {
public:
explicit TraceIfExpTrueBranch(const DebugInfoPtr& info) : TraceInfo(info, "ifexp_true", "") {}
explicit TraceIfExpTrueBranch(const DebugInfoPtr &info) : TraceInfo(info, "ifexp_true", "") {}
MS_DECLARE_PARENT(TraceIfExpTrueBranch, TraceInfo);
~TraceIfExpTrueBranch() override = default;
TraceInfoPtr clone() override {
@ -120,7 +120,7 @@ class TraceIfExpTrueBranch : public TraceInfo {
class TraceIfExpFalseBranch : public TraceInfo {
public:
explicit TraceIfExpFalseBranch(const DebugInfoPtr& info) : TraceInfo(info, "ifexp_false", "") {}
explicit TraceIfExpFalseBranch(const DebugInfoPtr &info) : TraceInfo(info, "ifexp_false", "") {}
MS_DECLARE_PARENT(TraceIfExpFalseBranch, TraceInfo);
~TraceIfExpFalseBranch() override = default;
TraceInfoPtr clone() override {
@ -131,7 +131,7 @@ class TraceIfExpFalseBranch : public TraceInfo {
class TraceCopy : public TraceInfo {
public:
TraceCopy() : TraceInfo(nullptr, "copy", "") {}
explicit TraceCopy(const DebugInfoPtr& info) : TraceInfo(info, "copy", "") {}
explicit TraceCopy(const DebugInfoPtr &info) : TraceInfo(info, "copy", "") {}
MS_DECLARE_PARENT(TraceCopy, TraceInfo);
~TraceCopy() override = default;
TraceInfoPtr clone() override { return std::make_shared<TraceCopy>(*shared_from_base<TraceCopy>()); }
@ -139,7 +139,7 @@ class TraceCopy : public TraceInfo {
class TraceIterator : public TraceInfo {
public:
explicit TraceIterator(const DebugInfoPtr& info) : TraceInfo(info, "iterator", "@") {}
explicit TraceIterator(const DebugInfoPtr &info) : TraceInfo(info, "iterator", "@") {}
MS_DECLARE_PARENT(TraceIterator, TraceInfo);
~TraceIterator() override = default;
TraceInfoPtr clone() override { return std::make_shared<TraceIterator>(*shared_from_base<TraceIterator>()); }
@ -147,7 +147,7 @@ class TraceIterator : public TraceInfo {
class TraceWhileHeader : public TraceInfo {
public:
explicit TraceWhileHeader(const DebugInfoPtr& info) : TraceInfo(info, "while_header", "") {}
explicit TraceWhileHeader(const DebugInfoPtr &info) : TraceInfo(info, "while_header", "") {}
MS_DECLARE_PARENT(TraceWhileHeader, TraceInfo);
~TraceWhileHeader() override = default;
TraceInfoPtr clone() override { return std::make_shared<TraceWhileHeader>(*shared_from_base<TraceWhileHeader>()); }
@ -155,7 +155,7 @@ class TraceWhileHeader : public TraceInfo {
class TraceWhileBody : public TraceInfo {
public:
explicit TraceWhileBody(const DebugInfoPtr& info) : TraceInfo(info, "while_body", "") {}
explicit TraceWhileBody(const DebugInfoPtr &info) : TraceInfo(info, "while_body", "") {}
MS_DECLARE_PARENT(TraceWhileBody, TraceInfo);
~TraceWhileBody() override = default;
TraceInfoPtr clone() override { return std::make_shared<TraceWhileBody>(*shared_from_base<TraceWhileBody>()); }
@ -163,7 +163,7 @@ class TraceWhileBody : public TraceInfo {
class TraceWhileAfter : public TraceInfo {
public:
explicit TraceWhileAfter(const DebugInfoPtr& info) : TraceInfo(info, "while_after", "") {}
explicit TraceWhileAfter(const DebugInfoPtr &info) : TraceInfo(info, "while_after", "") {}
MS_DECLARE_PARENT(TraceWhileAfter, TraceInfo);
~TraceWhileAfter() override = default;
TraceInfoPtr clone() override { return std::make_shared<TraceWhileAfter>(*shared_from_base<TraceWhileAfter>()); }
@ -171,7 +171,7 @@ class TraceWhileAfter : public TraceInfo {
class TraceForHeader : public TraceInfo {
public:
explicit TraceForHeader(const DebugInfoPtr& info) : TraceInfo(info, "for_header", "") {}
explicit TraceForHeader(const DebugInfoPtr &info) : TraceInfo(info, "for_header", "") {}
MS_DECLARE_PARENT(TraceForHeader, TraceInfo);
~TraceForHeader() override = default;
TraceInfoPtr clone() override { return std::make_shared<TraceForHeader>(*shared_from_base<TraceForHeader>()); }
@ -179,7 +179,7 @@ class TraceForHeader : public TraceInfo {
class TraceForBody : public TraceInfo {
public:
explicit TraceForBody(const DebugInfoPtr& info) : TraceInfo(info, "for_body", "") {}
explicit TraceForBody(const DebugInfoPtr &info) : TraceInfo(info, "for_body", "") {}
MS_DECLARE_PARENT(TraceForBody, TraceInfo);
~TraceForBody() override = default;
TraceInfoPtr clone() override { return std::make_shared<TraceForBody>(*shared_from_base<TraceForBody>()); }
@ -187,7 +187,7 @@ class TraceForBody : public TraceInfo {
class TraceForAfter : public TraceInfo {
public:
explicit TraceForAfter(const DebugInfoPtr& info) : TraceInfo(info, "for_after", "") {}
explicit TraceForAfter(const DebugInfoPtr &info) : TraceInfo(info, "for_after", "") {}
MS_DECLARE_PARENT(TraceForAfter, TraceInfo);
~TraceForAfter() override = default;
TraceInfoPtr clone() override { return std::make_shared<TraceForAfter>(*shared_from_base<TraceForAfter>()); }
@ -195,7 +195,7 @@ class TraceForAfter : public TraceInfo {
class TraceEquiv : public TraceInfo {
public:
explicit TraceEquiv(const DebugInfoPtr& info) : TraceInfo(info, "equiv", "equiv") {}
explicit TraceEquiv(const DebugInfoPtr &info) : TraceInfo(info, "equiv", "equiv") {}
MS_DECLARE_PARENT(TraceEquiv, TraceInfo);
~TraceEquiv() override = default;
TraceInfoPtr clone() override { return std::make_shared<TraceEquiv>(*shared_from_base<TraceEquiv>()); }
@ -204,7 +204,7 @@ class TraceEquiv : public TraceInfo {
class TraceGradFpropApp : public TraceInfo {
public:
TraceGradFpropApp() : TraceInfo(nullptr, "grad_fprop_app", "") {}
explicit TraceGradFpropApp(const DebugInfoPtr& info) : TraceInfo(info, "grad_fprop_app", "") {}
explicit TraceGradFpropApp(const DebugInfoPtr &info) : TraceInfo(info, "grad_fprop_app", "") {}
MS_DECLARE_PARENT(TraceGradFpropApp, TraceInfo);
~TraceGradFpropApp() override = default;
TraceInfoPtr clone() override { return std::make_shared<TraceGradFpropApp>(*shared_from_base<TraceGradFpropApp>()); }
@ -213,7 +213,7 @@ class TraceGradFpropApp : public TraceInfo {
class TraceGradBpropApp : public TraceInfo {
public:
TraceGradBpropApp() : TraceInfo(nullptr, "grad_bprop_app", "") {}
explicit TraceGradBpropApp(const DebugInfoPtr& info) : TraceInfo(info, "grad_bprop_app", "") {}
explicit TraceGradBpropApp(const DebugInfoPtr &info) : TraceInfo(info, "grad_bprop_app", "") {}
MS_DECLARE_PARENT(TraceGradBpropApp, TraceInfo);
~TraceGradBpropApp() override = default;
TraceInfoPtr clone() override { return std::make_shared<TraceGradBpropApp>(*shared_from_base<TraceGradBpropApp>()); }
@ -222,7 +222,7 @@ class TraceGradBpropApp : public TraceInfo {
class TraceGradFprop : public TraceInfo {
public:
TraceGradFprop() : TraceInfo(nullptr, "grad_fprop", "") {}
explicit TraceGradFprop(const DebugInfoPtr& info) : TraceInfo(info, "grad_fprop", "") {}
explicit TraceGradFprop(const DebugInfoPtr &info) : TraceInfo(info, "grad_fprop", "") {}
MS_DECLARE_PARENT(TraceGradFprop, TraceInfo);
~TraceGradFprop() override = default;
TraceInfoPtr clone() override { return std::make_shared<TraceGradFprop>(*shared_from_base<TraceGradFprop>()); }
@ -231,7 +231,7 @@ class TraceGradFprop : public TraceInfo {
class TraceGradBprop : public TraceInfo {
public:
TraceGradBprop() : TraceInfo(nullptr, "grad_bprop", "") {}
explicit TraceGradBprop(const DebugInfoPtr& info) : TraceInfo(info, "grad_bprop", "") {}
explicit TraceGradBprop(const DebugInfoPtr &info) : TraceInfo(info, "grad_bprop", "") {}
MS_DECLARE_PARENT(TraceGradBprop, TraceInfo);
~TraceGradBprop() override = default;
TraceInfoPtr clone() override { return std::make_shared<TraceGradBprop>(*shared_from_base<TraceGradBprop>()); }
@ -240,7 +240,7 @@ class TraceGradBprop : public TraceInfo {
class TraceGradSens : public TraceInfo {
public:
TraceGradSens() : TraceInfo(nullptr, "grad_sens", "") {}
explicit TraceGradSens(const DebugInfoPtr& info) : TraceInfo(info, "grad_sens", "") {}
explicit TraceGradSens(const DebugInfoPtr &info) : TraceInfo(info, "grad_sens", "") {}
MS_DECLARE_PARENT(TraceGradSens, TraceInfo);
~TraceGradSens() override = default;
TraceInfoPtr clone() override { return std::make_shared<TraceGradSens>(*shared_from_base<TraceGradSens>()); }
@ -248,7 +248,7 @@ class TraceGradSens : public TraceInfo {
class TraceSpecialize : public TraceInfo {
public:
explicit TraceSpecialize(const std::string& counter) : TraceInfo(nullptr, "specialize", "") { counter_ = counter; }
explicit TraceSpecialize(const std::string &counter) : TraceInfo(nullptr, "specialize", "") { counter_ = counter; }
MS_DECLARE_PARENT(TraceSpecialize, TraceInfo);
std::string name() override { return full_name_ + counter_; }
std::string symbol() override { return counter_ + "_"; }
@ -260,7 +260,7 @@ class TraceSpecialize : public TraceInfo {
class TraceGradOperation : public TraceInfo {
public:
explicit TraceGradOperation(const DebugInfoPtr& info) : TraceInfo(info, "grad_ops", "") {}
explicit TraceGradOperation(const DebugInfoPtr &info) : TraceInfo(info, "grad_ops", "") {}
MS_DECLARE_PARENT(TraceGradOperation, TraceInfo);
~TraceGradOperation() override = default;
TraceInfoPtr clone() override {
@ -270,7 +270,7 @@ class TraceGradOperation : public TraceInfo {
class TraceForceBool : public TraceInfo {
public:
explicit TraceForceBool(const DebugInfoPtr& info) : TraceInfo(info, "force_bool", "") {}
explicit TraceForceBool(const DebugInfoPtr &info) : TraceInfo(info, "force_bool", "") {}
MS_DECLARE_PARENT(TraceForceBool, TraceInfo);
~TraceForceBool() override = default;
TraceInfoPtr clone() override { return std::make_shared<TraceForceBool>(*shared_from_base<TraceForceBool>()); }
@ -278,7 +278,7 @@ class TraceForceBool : public TraceInfo {
class TraceExpandJ : public TraceInfo {
public:
explicit TraceExpandJ(const DebugInfoPtr& info) : TraceInfo(info, "expand_j", "") {}
explicit TraceExpandJ(const DebugInfoPtr &info) : TraceInfo(info, "expand_j", "") {}
MS_DECLARE_PARENT(TraceExpandJ, TraceInfo);
~TraceExpandJ() override = default;
TraceInfoPtr clone() override { return std::make_shared<TraceExpandJ>(*shared_from_base<TraceExpandJ>()); }
@ -286,7 +286,7 @@ class TraceExpandJ : public TraceInfo {
class TraceGenMetaFuncGraph : public TraceInfo {
public:
explicit TraceGenMetaFuncGraph(const DebugInfoPtr& info) : TraceInfo(info, "GenMetaFuncGraph", "") {}
explicit TraceGenMetaFuncGraph(const DebugInfoPtr &info) : TraceInfo(info, "GenMetaFuncGraph", "") {}
MS_DECLARE_PARENT(TraceGenMetaFuncGraph, TraceInfo);
~TraceGenMetaFuncGraph() override = default;
TraceInfoPtr clone() override {
@ -296,7 +296,7 @@ class TraceGenMetaFuncGraph : public TraceInfo {
class TraceEvaluatorGenGraph : public TraceInfo {
public:
explicit TraceEvaluatorGenGraph(const DebugInfoPtr& info) : TraceInfo(info, "GenEvaluatorGraph", "") {}
explicit TraceEvaluatorGenGraph(const DebugInfoPtr &info) : TraceInfo(info, "GenEvaluatorGraph", "") {}
MS_DECLARE_PARENT(TraceEvaluatorGenGraph, TraceInfo);
~TraceEvaluatorGenGraph() override = default;
TraceInfoPtr clone() override {
@ -306,7 +306,7 @@ class TraceEvaluatorGenGraph : public TraceInfo {
class TraceResolve : public TraceInfo {
public:
explicit TraceResolve(const DebugInfoPtr& info) : TraceInfo(info, "resolve", "") {}
explicit TraceResolve(const DebugInfoPtr &info) : TraceInfo(info, "resolve", "") {}
MS_DECLARE_PARENT(TraceResolve, TraceInfo);
~TraceResolve() override = default;
TraceInfoPtr clone() override { return std::make_shared<TraceResolve>(*shared_from_base<TraceResolve>()); }
@ -315,7 +315,7 @@ class TraceResolve : public TraceInfo {
class TraceTransform : public TraceInfo {
public:
TraceTransform() : TraceInfo(nullptr, "transform", "") { transform_name_ = ""; }
explicit TraceTransform(const std::string& transform_name) : TraceInfo(nullptr, "transform", "") {
explicit TraceTransform(const std::string &transform_name) : TraceInfo(nullptr, "transform", "") {
transform_name_ = transform_name;
}
@ -335,7 +335,7 @@ class TraceTransform : public TraceInfo {
class TraceGenerateVarArg : public TraceInfo {
public:
explicit TraceGenerateVarArg(const DebugInfoPtr& info) : TraceInfo(info, "GenerateVarArg", "") {}
explicit TraceGenerateVarArg(const DebugInfoPtr &info) : TraceInfo(info, "GenerateVarArg", "") {}
MS_DECLARE_PARENT(TraceGenerateVarArg, TraceInfo);
~TraceGenerateVarArg() override = default;
TraceInfoPtr clone() override {
@ -345,7 +345,7 @@ class TraceGenerateVarArg : public TraceInfo {
class TraceGenerateKwArg : public TraceInfo {
public:
explicit TraceGenerateKwArg(const DebugInfoPtr& info) : TraceInfo(info, "GenerateKwArg", "") {}
explicit TraceGenerateKwArg(const DebugInfoPtr &info) : TraceInfo(info, "GenerateKwArg", "") {}
MS_DECLARE_PARENT(TraceGenerateKwArg, TraceInfo);
~TraceGenerateKwArg() override = default;
TraceInfoPtr clone() override {
@ -355,7 +355,7 @@ class TraceGenerateKwArg : public TraceInfo {
class TraceTrasformK : public TraceInfo {
public:
explicit TraceTrasformK(const DebugInfoPtr& info) : TraceInfo(info, "TraceTrasformK", "") {}
explicit TraceTrasformK(const DebugInfoPtr &info) : TraceInfo(info, "TraceTrasformK", "") {}
MS_DECLARE_PARENT(TraceTrasformK, TraceInfo);
~TraceTrasformK() override = default;
TraceInfoPtr clone() override { return std::make_shared<TraceTrasformK>(*shared_from_base<TraceTrasformK>()); }
@ -363,7 +363,7 @@ class TraceTrasformK : public TraceInfo {
class TracePartialTransform : public TraceInfo {
public:
explicit TracePartialTransform(const DebugInfoPtr& info) : TraceInfo(info, "PartialTransform", "") {}
explicit TracePartialTransform(const DebugInfoPtr &info) : TraceInfo(info, "PartialTransform", "") {}
MS_DECLARE_PARENT(TracePartialTransform, TraceInfo);
~TracePartialTransform() override = default;
TraceInfoPtr clone() override {
@ -373,7 +373,7 @@ class TracePartialTransform : public TraceInfo {
class TraceGetEnv : public TraceInfo {
public:
explicit TraceGetEnv(const DebugInfoPtr& info) : TraceInfo(info, "get_env", "") {}
explicit TraceGetEnv(const DebugInfoPtr &info) : TraceInfo(info, "get_env", "") {}
MS_DECLARE_PARENT(TraceGetEnv, TraceInfo);
~TraceGetEnv() override = default;
TraceInfoPtr clone() override { return std::make_shared<TraceGetEnv>(*shared_from_base<TraceGetEnv>()); }
@ -381,7 +381,7 @@ class TraceGetEnv : public TraceInfo {
class TraceDoSignature : public TraceInfo {
public:
explicit TraceDoSignature(const DebugInfoPtr& info) : TraceInfo(info, "DoSignature", "") {}
explicit TraceDoSignature(const DebugInfoPtr &info) : TraceInfo(info, "DoSignature", "") {}
MS_DECLARE_PARENT(TraceDoSignature, TraceInfo);
~TraceDoSignature() override = default;
TraceInfoPtr clone() override { return std::make_shared<TraceDoSignature>(*shared_from_base<TraceDoSignature>()); }
@ -390,7 +390,7 @@ class TraceDoSignature : public TraceInfo {
class TraceCombileLikeGraphs : public TraceInfo {
public:
TraceCombileLikeGraphs() : TraceInfo(nullptr, "CombileLike", "L-") {}
explicit TraceCombileLikeGraphs(const DebugInfoPtr& info) : TraceInfo(info, "CombileLike", "L-") {}
explicit TraceCombileLikeGraphs(const DebugInfoPtr &info) : TraceInfo(info, "CombileLike", "L-") {}
MS_DECLARE_PARENT(TraceCombileLikeGraphs, TraceInfo);
~TraceCombileLikeGraphs() override = default;
TraceInfoPtr clone() override {

View File

@ -21,7 +21,7 @@
namespace mindspore {
namespace device {
namespace ascend {
size_t AscendMemoryPool::AllocDeviceMem(size_t size, DeviceMemPtr* addr) {
size_t AscendMemoryPool::AllocDeviceMem(size_t size, DeviceMemPtr *addr) {
if (has_malloc_) {
MS_LOG(EXCEPTION) << "Has alloc memory pool memory !";
}
@ -37,7 +37,7 @@ size_t AscendMemoryPool::AllocDeviceMem(size_t size, DeviceMemPtr* addr) {
return size;
}
bool AscendMemoryPool::FreeDeviceMem(const DeviceMemPtr& addr) {
bool AscendMemoryPool::FreeDeviceMem(const DeviceMemPtr &addr) {
MS_EXCEPTION_IF_NULL(addr);
has_malloc_ = false;
free_mem_size_ = total_mem_size_;
@ -53,7 +53,7 @@ size_t AscendMemoryPool::AlignMemorySize(size_t size) const {
size_t AscendMemoryPool::mem_alloc_unit_size() const { return free_mem_size_ - 512; }
void AscendMemoryPool::set_device_mem_pool_base(uint8_t* device_mem_pool_base) {
void AscendMemoryPool::set_device_mem_pool_base(uint8_t *device_mem_pool_base) {
MS_EXCEPTION_IF_NULL(device_mem_pool_base);
device_mem_pool_base_ = device_mem_pool_base;
}

View File

@ -26,12 +26,12 @@ namespace ascend {
class AscendMemoryPool : public DynamicMemPoolBestFit {
public:
~AscendMemoryPool() override = default;
AscendMemoryPool(const AscendMemoryPool&) = delete;
AscendMemoryPool& operator=(const AscendMemoryPool&) = delete;
AscendMemoryPool(const AscendMemoryPool &) = delete;
AscendMemoryPool &operator=(const AscendMemoryPool &) = delete;
size_t AllocDeviceMem(size_t size, DeviceMemPtr* addr) override;
bool FreeDeviceMem(const DeviceMemPtr& addr) override;
void set_device_mem_pool_base(uint8_t* device_mem_pool_base);
size_t AllocDeviceMem(size_t size, DeviceMemPtr *addr) override;
bool FreeDeviceMem(const DeviceMemPtr &addr) override;
void set_device_mem_pool_base(uint8_t *device_mem_pool_base);
void set_device_mem_pool_size(uint64_t device_mem_pool_size) {
device_mem_pool_size_ = device_mem_pool_size;
free_mem_size_ = device_mem_pool_size_;
@ -40,7 +40,7 @@ class AscendMemoryPool : public DynamicMemPoolBestFit {
size_t free_mem_size() override;
size_t total_mem_size() override;
static AscendMemoryPool& GetInstance() {
static AscendMemoryPool &GetInstance() {
static AscendMemoryPool instance;
return instance;
}
@ -54,7 +54,7 @@ class AscendMemoryPool : public DynamicMemPoolBestFit {
private:
AscendMemoryPool() = default;
bool has_malloc_{false};
uint8_t* device_mem_pool_base_{nullptr};
uint8_t *device_mem_pool_base_{nullptr};
uint64_t device_mem_pool_size_{0};
size_t free_mem_size_{0};
size_t total_mem_size_{0};

View File

@ -39,13 +39,13 @@ using std::vector;
class AscendStreamAssign {
public:
static AscendStreamAssign& GetInstance() {
static AscendStreamAssign &GetInstance() {
static AscendStreamAssign instance; // Guaranteed to be destroyed.
return instance;
}
AscendStreamAssign(const AscendStreamAssign&) = delete;
AscendStreamAssign& operator=(const AscendStreamAssign&) = delete;
AscendStreamAssign(const AscendStreamAssign &) = delete;
AscendStreamAssign &operator=(const AscendStreamAssign &) = delete;
uint32_t GetTotalStreamNum() const;
// new stream policy
@ -53,19 +53,19 @@ class AscendStreamAssign {
uint32_t total_independ_stream_num() const { return total_independ_stream_num_; }
uint32_t total_event_num() const { return total_event_num_; }
void InsertActiveNew(const std::shared_ptr<session::KernelGraph>& graph_ptr);
void AssignAllNodesStream(const std::shared_ptr<session::KernelGraph>& graph_ptr);
void InsertActiveNew(const std::shared_ptr<session::KernelGraph> &graph_ptr);
void AssignAllNodesStream(const std::shared_ptr<session::KernelGraph> &graph_ptr);
void ResetNew();
void AssignStreamNew(const std::shared_ptr<session::KernelGraph>& graph_ptr);
bool IsIndependentNode(const CNodePtr& node_ptr);
const std::unordered_map<uint32_t, uint32_t>& logic_to_independent_map() { return logic_to_independent_map_; }
const std::unordered_map<uint32_t, uint32_t>& logic_to_physic_map() { return logic_to_physic_map_; }
const std::vector<std::vector<uint32_t>>& inner_parallel_streams() { return inner_parallel_streams_; }
void GetWaitStreams(vector<uint32_t>* wait_active_stream_list);
const std::vector<uint32_t>& hcom_streams() { return hcom_stream_list_; }
CNodePtr CreateSendApplyKernel(const std::shared_ptr<session::KernelGraph>& graph_ptr, uint32_t event_id,
void AssignStreamNew(const std::shared_ptr<session::KernelGraph> &graph_ptr);
bool IsIndependentNode(const CNodePtr &node_ptr);
const std::unordered_map<uint32_t, uint32_t> &logic_to_independent_map() { return logic_to_independent_map_; }
const std::unordered_map<uint32_t, uint32_t> &logic_to_physic_map() { return logic_to_physic_map_; }
const std::vector<std::vector<uint32_t>> &inner_parallel_streams() { return inner_parallel_streams_; }
void GetWaitStreams(vector<uint32_t> *wait_active_stream_list);
const std::vector<uint32_t> &hcom_streams() { return hcom_stream_list_; }
CNodePtr CreateSendApplyKernel(const std::shared_ptr<session::KernelGraph> &graph_ptr, uint32_t event_id,
uint32_t stream_id);
CNodePtr CreateRecvApplyKernel(const std::shared_ptr<session::KernelGraph>& graph_ptr, uint32_t event_id,
CNodePtr CreateRecvApplyKernel(const std::shared_ptr<session::KernelGraph> &graph_ptr, uint32_t event_id,
uint32_t stream_id);
private:
@ -73,30 +73,30 @@ class AscendStreamAssign {
~AscendStreamAssign() = default;
vector<CNodePtr>::iterator FindTargetOp(vector<CNodePtr>::iterator begin, vector<CNodePtr>::iterator end,
const CNodePtr& node);
const CNodePtr &node);
bool IsHcom(const CNodePtr& apply_kernel);
bool IsHcom(const CNodePtr &apply_kernel);
bool IsProcessed(uint32_t logic_id);
void TransLogicToPhysic(const vector<uint32_t>& logic_ids, vector<uint32_t>* physic_ids);
void AssignCommonStreamId(const CNodePtr& cur_cnode_ptr, CNodePtr* pre_cnode_ptr, uint32_t* cur_index,
uint32_t* cur_stream_id);
void TransLogicToPhysic(const vector<uint32_t> &logic_ids, vector<uint32_t> *physic_ids);
void AssignCommonStreamId(const CNodePtr &cur_cnode_ptr, CNodePtr *pre_cnode_ptr, uint32_t *cur_index,
uint32_t *cur_stream_id);
void RecordIdMap(uint32_t logic_id, uint32_t physic_id);
void UpdateStreamActive(const CNodePtr& active_ptr);
void UpdateStreamSwitch(const CNodePtr& switch_ptr, const CNodePtr& active_ptr);
void UpdateStreamActive(const CNodePtr &active_ptr);
void UpdateStreamSwitch(const CNodePtr &switch_ptr, const CNodePtr &active_ptr);
bool IsTaskSink();
void AssignIndependentStreamId(const CNodePtr& cur_cnode_ptr, uint32_t deal_logic_id);
void UpdateStreamId(const std::shared_ptr<session::KernelGraph>& graph_ptr);
void UpdateEventId(const std::shared_ptr<session::KernelGraph>& graph_ptr);
void PrintGraphExeOrders(const std::shared_ptr<session::KernelGraph>& graph_ptr);
void RecordFirstCommonOp(const CNodePtr& cur_cnode_ptr, uint32_t cur_node_logic_id, uint32_t cur_stream_id);
uint32_t GetLogicId(const CNodePtr& cur_cnode_ptr);
void AssignIndependentStreamId(const CNodePtr &cur_cnode_ptr, uint32_t deal_logic_id);
void UpdateStreamId(const std::shared_ptr<session::KernelGraph> &graph_ptr);
void UpdateEventId(const std::shared_ptr<session::KernelGraph> &graph_ptr);
void PrintGraphExeOrders(const std::shared_ptr<session::KernelGraph> &graph_ptr);
void RecordFirstCommonOp(const CNodePtr &cur_cnode_ptr, uint32_t cur_node_logic_id, uint32_t cur_stream_id);
uint32_t GetLogicId(const CNodePtr &cur_cnode_ptr);
void SetCommonStreamNum(uint32_t cur_stream_id);
void FindAllReduceParallel(const std::shared_ptr<session::KernelGraph>& graph_ptr);
void FindAllReduceParallel(const std::shared_ptr<session::KernelGraph> &graph_ptr);
bool IsProcessedParallelStream(uint32_t stream_id);
void GetParallelStream(uint32_t cur_stream_id, uint32_t stream_acitve_id, std::vector<uint32_t>* parallel_streams);
void InsertSendRecvForIndependent(const std::shared_ptr<session::KernelGraph>& graph_ptr);
void InsertSendRecvForHcomParallel(const std::shared_ptr<session::KernelGraph>& graph_ptr);
void GetNeedActiveStreams(const std::shared_ptr<session::KernelGraph>& graph_ptr);
void GetParallelStream(uint32_t cur_stream_id, uint32_t stream_acitve_id, std::vector<uint32_t> *parallel_streams);
void InsertSendRecvForIndependent(const std::shared_ptr<session::KernelGraph> &graph_ptr);
void InsertSendRecvForHcomParallel(const std::shared_ptr<session::KernelGraph> &graph_ptr);
void GetNeedActiveStreams(const std::shared_ptr<session::KernelGraph> &graph_ptr);
uint32_t total_common_stream_num_{0};
uint32_t total_independ_stream_num_{0};

View File

@ -28,14 +28,14 @@ namespace device {
namespace ascend {
class PluginImpl : public PluginIntf {
public:
explicit PluginImpl(const std::string& module);
explicit PluginImpl(const std::string &module);
~PluginImpl() override = default;
int Init(const Reporter* reporter) override;
int Init(const Reporter *reporter) override;
int UnInit() override;
static Reporter* GetPluginReporter() { return reporter_; }
static Reporter *GetPluginReporter() { return reporter_; }
private:
static Reporter* reporter_;
static Reporter *reporter_;
std::string module_;
};
} // namespace ascend

View File

@ -20,12 +20,12 @@
namespace mindspore {
namespace device {
namespace ascend {
PluginIntf* ProfilingEngineImpl::CreatePlugin() {
PluginIntf *ProfilingEngineImpl::CreatePlugin() {
MS_LOG(INFO) << "Create Plugin.";
return new (std::nothrow) PluginImpl("Framework");
}
int ProfilingEngineImpl::ReleasePlugin(PluginIntf* plugin) {
int ProfilingEngineImpl::ReleasePlugin(PluginIntf *plugin) {
if (plugin != nullptr) {
delete plugin;
}

View File

@ -29,8 +29,8 @@ class ProfilingEngineImpl : public EngineIntf {
ProfilingEngineImpl() = default;
~ProfilingEngineImpl() override = default;
PluginIntf* CreatePlugin() override;
int ReleasePlugin(PluginIntf* plugin) override;
PluginIntf *CreatePlugin() override;
int ReleasePlugin(PluginIntf *plugin) override;
};
} // namespace ascend
} // namespace device

View File

@ -35,7 +35,7 @@ using Json = nlohmann::json;
namespace mindspore {
namespace device {
namespace ascend {
ProfilingManager& ProfilingManager::GetInstance() {
ProfilingManager &ProfilingManager::GetInstance() {
static ProfilingManager inst;
return inst;
}
@ -45,11 +45,11 @@ ProfilingManager::ProfilingManager() : device_id_(0), prof_handle_(nullptr) {
}
uint64_t ProfilingManager::GetJobId() const {
const char* job_id = std::getenv("JOB_ID");
const char *job_id = std::getenv("JOB_ID");
return ((job_id != nullptr) ? std::strtoul(job_id, nullptr, 10) : 0);
}
bool ProfilingManager::ReportProfilingData(const map<uint32_t, string>& op_taskId_map) const {
bool ProfilingManager::ReportProfilingData(const map<uint32_t, string> &op_taskId_map) const {
if (!IsProfiling()) {
MS_LOG(INFO) << "No need profiling. please export PROFILING_MODE and in train mode.";
return false;
@ -66,10 +66,10 @@ bool ProfilingManager::ReportProfilingData(const map<uint32_t, string>& op_taskI
MS_LOG(INFO) << "DistributeTask: op tasId map size = " << op_taskId_map.size();
Msprof::Engine::ReporterData reporter_data = {};
for (const auto& iter : op_taskId_map) {
for (const auto &iter : op_taskId_map) {
auto data = iter.second + ' ' + std::to_string(iter.first) + ';';
reporter_data.deviceId = UintToInt(device_id_);
reporter_data.data = (unsigned char*)(const_cast<char*>(data.c_str()));
reporter_data.data = (unsigned char *)(const_cast<char *>(data.c_str()));
reporter_data.dataLen = data.size();
auto ret = memcpy_s(reporter_data.tag, MSPROF_ENGINE_MAX_TAG_LEN + 1, "framework", sizeof("framework"));
if (ret != 0) {
@ -85,7 +85,7 @@ bool ProfilingManager::ReportProfilingData(const map<uint32_t, string>& op_taskI
return true;
}
static std::vector<std::string> Split(const std::string& str, const char delim) {
static std::vector<std::string> Split(const std::string &str, const char delim) {
std::vector<std::string> elems;
if (str.empty()) {
@ -116,7 +116,7 @@ bool ProfilingManager::StartupProfiling(uint32_t device_id) {
device_id_ = device_id;
// exp: export PROFILING_MODE=true
// export PROFILING_OPTIONS=training_trace
const char* prof_options_str = std::getenv("PROFILING_OPTIONS");
const char *prof_options_str = std::getenv("PROFILING_OPTIONS");
// register Framework to profiling
int result = Msprof::Engine::RegisterEngine("Framework", engine_0_.get());
if (result != 0) {
@ -176,7 +176,7 @@ bool ProfilingManager::StopProfiling() const {
MS_LOG(INFO) << "No need profiling. please export PROFILING_MODE and in train mode.";
return true;
}
Msprof::Engine::Reporter* reporter = PluginImpl::GetPluginReporter();
Msprof::Engine::Reporter *reporter = PluginImpl::GetPluginReporter();
if (reporter != nullptr) {
MS_LOG(INFO) << "report data end, ret = " << reporter->Flush();
}

View File

@ -33,27 +33,27 @@ enum BlockQueueStatus_T : int { SUCCESS = 0, QUEUE_NOT_EXIST, HANDLE_NOT_EXIST,
class GpuQueue {
public:
GpuQueue(void* addr, size_t feature_size, size_t label_size, size_t capacity);
GpuQueue(void *addr, size_t feature_size, size_t label_size, size_t capacity);
virtual ~GpuQueue();
void RegisterRelease(const std::function<void(void*)>& func) { host_release_ = func; }
void RegisterRelease(const std::function<void(void *)> &func) { host_release_ = func; }
inline bool IsEmpty() const { return head_ == tail_; }
inline bool IsFull() const { return head_ == ((tail_ + 1) % (capacity_)); }
BlockQueueStatus_T Push(void* feature_addr, size_t feature_size, void* label_addr, size_t label_size);
BlockQueueStatus_T Front(void** feature_addr, size_t* feature_size, void** label_addr, size_t* label_size) const;
BlockQueueStatus_T Push(void *feature_addr, size_t feature_size, void *label_addr, size_t label_size);
BlockQueueStatus_T Front(void **feature_addr, size_t *feature_size, void **label_addr, size_t *label_size) const;
BlockQueueStatus_T Pop();
bool Destroy();
private:
struct NodeInfo {
std::unique_ptr<cudaEvent_t> event_;
void* host_feature_addr_;
void* host_label_addr_;
void *host_feature_addr_;
void *host_label_addr_;
};
void* buffer_;
void *buffer_;
size_t head_;
size_t tail_;
size_t feature_size_;
@ -61,10 +61,10 @@ class GpuQueue {
size_t capacity_;
cudaStream_t stream_;
std::unique_ptr<NodeInfo[]> node_info_;
std::function<void(void*)> host_release_;
std::function<void(void *)> host_release_;
GpuQueue(const GpuQueue&) = delete;
GpuQueue& operator=(const GpuQueue&) = delete;
GpuQueue(const GpuQueue &) = delete;
GpuQueue &operator=(const GpuQueue &) = delete;
};
class BlockingQueue {
@ -72,11 +72,11 @@ class BlockingQueue {
BlockingQueue() : queue_(nullptr) {}
~BlockingQueue() = default;
BlockQueueStatus_T Create(void* addr, size_t feature_size, size_t label_size, size_t capacity);
void RegisterRelease(const std::function<void(void*)>& func);
BlockQueueStatus_T Push(void* feature_addr, size_t feature_size, void* label_addr, size_t label_size,
BlockQueueStatus_T Create(void *addr, size_t feature_size, size_t label_size, size_t capacity);
void RegisterRelease(const std::function<void(void *)> &func);
BlockQueueStatus_T Push(void *feature_addr, size_t feature_size, void *label_addr, size_t label_size,
unsigned int timeout_in_sec);
BlockQueueStatus_T Front(void** feature_addr, size_t* feature_size, void** label_addr, size_t* label_size);
BlockQueueStatus_T Front(void **feature_addr, size_t *feature_size, void **label_addr, size_t *label_size);
BlockQueueStatus_T Pop();
bool Destroy();

View File

@ -20,17 +20,17 @@
namespace mindspore {
namespace device {
namespace gpu {
CollectiveInitializer& CollectiveInitializer::instance() {
CollectiveInitializer &CollectiveInitializer::instance() {
static CollectiveInitializer instance = {};
return instance;
}
bool CollectiveInitializer::collective_inited() const { return collective_inited_; }
const void* CollectiveInitializer::collective_handle() const { return collective_handle_; }
const void *CollectiveInitializer::collective_handle() const { return collective_handle_; }
void CollectiveInitializer::InitCollective() {
void* handle = dlopen("libgpu_collective.so", RTLD_LAZY);
void *handle = dlopen("libgpu_collective.so", RTLD_LAZY);
if (handle == nullptr) {
MS_LOG(EXCEPTION)
<< "Loading libgpu_collective.so failed. Many reasons could cause this:\n1.libgpu_collective.so is not "

View File

@ -50,13 +50,13 @@ void GPUDeviceManager::ReleaseDevice() {
CHECK_OP_RET_WITH_ERROR(GPUMemoryAllocator::GetInstance().Finalize(), "Failed to destroy gpu memory allocator");
}
bool GPUDeviceManager::CreateStream(DeviceStream* stream) {
bool GPUDeviceManager::CreateStream(DeviceStream *stream) {
CHECK_OP_RET_WITH_EXCEPT(CudaDriver::CreateStream(stream), "Failed to create CUDA stream");
gpu_streams_.emplace_back(*stream);
return true;
}
const DeviceStream& GPUDeviceManager::default_stream() const { return default_stream_; }
const DeviceStream &GPUDeviceManager::default_stream() const { return default_stream_; }
int GPUDeviceManager::device_count() const { return CudaDriver::device_count(); }
@ -76,17 +76,17 @@ uint32_t GPUDeviceManager::cur_device_id() const { return cur_dev_id_; }
bool GPUDeviceManager::is_device_id_init() const { return dev_id_init_; }
const cudnnHandle_t& GPUDeviceManager::GetCudnnHandle() const { return cudnn_handle_; }
const cudnnHandle_t &GPUDeviceManager::GetCudnnHandle() const { return cudnn_handle_; }
const cublasHandle_t& GPUDeviceManager::GetCublasHandle() const { return cublas_handle_; }
const cublasHandle_t &GPUDeviceManager::GetCublasHandle() const { return cublas_handle_; }
bool GPUDeviceManager::SyncStream(const DeviceStream& stream) const { return CudaDriver::SyncStream(stream); }
bool GPUDeviceManager::SyncStream(const DeviceStream &stream) const { return CudaDriver::SyncStream(stream); }
bool GPUDeviceManager::CopyDeviceMemToHost(const HostMemPtr& dst, const DeviceMemPtr& src, size_t size) const {
bool GPUDeviceManager::CopyDeviceMemToHost(const HostMemPtr &dst, const DeviceMemPtr &src, size_t size) const {
return CudaDriver::CopyDeviceMemToHost(dst, src, size);
}
bool GPUDeviceManager::CopyHostMemToDevice(const DeviceMemPtr& dst, const void* src, size_t size) const {
bool GPUDeviceManager::CopyHostMemToDevice(const DeviceMemPtr &dst, const void *src, size_t size) const {
return CudaDriver::CopyHostMemToDevice(dst, src, size);
}
} // namespace gpu

View File

@ -37,17 +37,17 @@ class GPUDeviceManager {
uint32_t cur_device_id() const;
bool is_device_id_init() const;
bool CreateStream(DeviceStream* stream);
bool SyncStream(const DeviceStream& stream) const;
const DeviceStream& default_stream() const;
bool CreateStream(DeviceStream *stream);
bool SyncStream(const DeviceStream &stream) const;
const DeviceStream &default_stream() const;
const cudnnHandle_t& GetCudnnHandle() const;
const cublasHandle_t& GetCublasHandle() const;
const cudnnHandle_t &GetCudnnHandle() const;
const cublasHandle_t &GetCublasHandle() const;
bool CopyDeviceMemToHost(const HostMemPtr& dst, const DeviceMemPtr& src, size_t size) const;
bool CopyHostMemToDevice(const DeviceMemPtr& dst, const void* src, size_t size) const;
bool CopyDeviceMemToHost(const HostMemPtr &dst, const DeviceMemPtr &src, size_t size) const;
bool CopyHostMemToDevice(const DeviceMemPtr &dst, const void *src, size_t size) const;
static GPUDeviceManager& GetInstance() {
static GPUDeviceManager &GetInstance() {
static GPUDeviceManager instance;
return instance;
}
@ -55,8 +55,8 @@ class GPUDeviceManager {
private:
GPUDeviceManager() : dev_id_init_(false), cur_dev_id_(0) {}
~GPUDeviceManager() = default;
GPUDeviceManager(const GPUDeviceManager&) = delete;
GPUDeviceManager& operator=(const GPUDeviceManager&) = delete;
GPUDeviceManager(const GPUDeviceManager &) = delete;
GPUDeviceManager &operator=(const GPUDeviceManager &) = delete;
// default CUDA stream used for all the kernels.
DeviceStream default_stream_{nullptr};

View File

@ -43,14 +43,14 @@ bool GPUMemoryAllocator::Finalize() {
return true;
}
bool GPUMemoryAllocator::AllocBufferQueueMem(size_t size, DeviceMemPtr* addr) {
bool GPUMemoryAllocator::AllocBufferQueueMem(size_t size, DeviceMemPtr *addr) {
auto alloc_size = AllocDeviceMem(size, addr);
buffer_q_addr_ = *addr;
// Buffer queue needs to ensure that the alloc_size and size is equal.
return (alloc_size == size) ? true : false;
}
size_t GPUMemoryAllocator::AllocDeviceMem(size_t size, DeviceMemPtr* addr) {
size_t GPUMemoryAllocator::AllocDeviceMem(size_t size, DeviceMemPtr *addr) {
if (size == 0) {
MS_LOG(EXCEPTION) << "The memory alloc size is 0.";
}
@ -68,7 +68,7 @@ size_t GPUMemoryAllocator::AllocDeviceMem(size_t size, DeviceMemPtr* addr) {
return alloc_size;
}
bool GPUMemoryAllocator::FreeDeviceMem(const DeviceMemPtr& addr) { return CudaDriver::FreeDeviceMem(addr); }
bool GPUMemoryAllocator::FreeDeviceMem(const DeviceMemPtr &addr) { return CudaDriver::FreeDeviceMem(addr); }
size_t GPUMemoryAllocator::free_mem_size() { return CudaDriver::free_mem_size(); }

View File

@ -29,22 +29,22 @@ class GPUMemoryAllocator : public DynamicMemPoolBestFit {
~GPUMemoryAllocator() override = default;
bool Init();
bool Finalize();
bool AllocBufferQueueMem(size_t size, DeviceMemPtr* addr);
bool AllocBufferQueueMem(size_t size, DeviceMemPtr *addr);
size_t AllocDeviceMem(size_t size, DeviceMemPtr* addr) override;
bool FreeDeviceMem(const DeviceMemPtr& addr) override;
size_t AllocDeviceMem(size_t size, DeviceMemPtr *addr) override;
bool FreeDeviceMem(const DeviceMemPtr &addr) override;
size_t free_mem_size() override;
size_t total_mem_size() override;
static GPUMemoryAllocator& GetInstance() {
static GPUMemoryAllocator &GetInstance() {
static GPUMemoryAllocator instance;
return instance;
}
private:
GPUMemoryAllocator() = default;
GPUMemoryAllocator(const GPUMemoryAllocator&) = delete;
GPUMemoryAllocator& operator=(const GPUMemoryAllocator&) = delete;
GPUMemoryAllocator(const GPUMemoryAllocator &) = delete;
GPUMemoryAllocator &operator=(const GPUMemoryAllocator &) = delete;
// Used to track address of data buffer queue.
DeviceMemPtr buffer_q_addr_{nullptr};

View File

@ -33,8 +33,8 @@ namespace gpu {
using AnfAlgo = mindspore::session::AnfRuntimeAlgorithm;
using mindspore::kernel::KernelBuildInfo;
namespace {
bool CheckKernelInfo(const std::shared_ptr<KernelBuildInfo>& alternative_kernel_info,
const std::shared_ptr<KernelBuildInfo>& selected_kernel_info) {
bool CheckKernelInfo(const std::shared_ptr<KernelBuildInfo> &alternative_kernel_info,
const std::shared_ptr<KernelBuildInfo> &selected_kernel_info) {
MS_EXCEPTION_IF_NULL(selected_kernel_info);
MS_EXCEPTION_IF_NULL(alternative_kernel_info);
size_t selected_input_num = selected_kernel_info->GetInputNum();
@ -67,7 +67,7 @@ bool CheckKernelInfo(const std::shared_ptr<KernelBuildInfo>& alternative_kernel_
return true;
}
std::string SupportedTypeList(const CNodePtr& kernel_node) {
std::string SupportedTypeList(const CNodePtr &kernel_node) {
std::string supported_type_lists =
kernel::GpuKernelFactory::GetInstance().SupportedTypeList(AnfAlgo::GetCNodeName(kernel_node));
if (!supported_type_lists.empty()) {
@ -91,7 +91,7 @@ std::string SupportedTypeList(const CNodePtr& kernel_node) {
return supported_type_lists;
}
bool SelectAkgKernel(const CNodePtr& kernel_node, const std::shared_ptr<KernelBuildInfo>& selected_kernel_info) {
bool SelectAkgKernel(const CNodePtr &kernel_node, const std::shared_ptr<KernelBuildInfo> &selected_kernel_info) {
MS_EXCEPTION_IF_NULL(kernel_node);
MS_EXCEPTION_IF_NULL(selected_kernel_info);
std::vector<std::shared_ptr<KernelBuildInfo>> kernel_info_list;
@ -110,7 +110,7 @@ bool SelectAkgKernel(const CNodePtr& kernel_node, const std::shared_ptr<KernelBu
}
bool match = std::any_of(kernel_info_list.begin(), kernel_info_list.end(),
[&](const std::shared_ptr<KernelBuildInfo>& alternative_kernel_info) {
[&](const std::shared_ptr<KernelBuildInfo> &alternative_kernel_info) {
return CheckKernelInfo(alternative_kernel_info, selected_kernel_info);
});
if (!match) {
@ -120,7 +120,7 @@ bool SelectAkgKernel(const CNodePtr& kernel_node, const std::shared_ptr<KernelBu
return true;
}
void SetTensorDeviceInfo(const kernel::KernelBuildInfo& selected_kernel_info, const CNodePtr& kernel_node) {
void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) {
auto input_kernel_node = kernel_node->input(input_index + 1);
@ -153,7 +153,7 @@ void SetTensorDeviceInfo(const kernel::KernelBuildInfo& selected_kernel_info, co
}
} // namespace
void SetKernelInfo(const CNodePtr& kernel_node) {
void SetKernelInfo(const CNodePtr &kernel_node) {
std::vector<std::string> inputs_format;
std::vector<TypeId> inputs_type;
std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> builder =

View File

@ -27,7 +27,7 @@
namespace mindspore {
namespace device {
namespace gpu {
void SetKernelInfo(const CNodePtr& apply_kernel_ptr);
void SetKernelInfo(const CNodePtr &apply_kernel_ptr);
class KernelAttr {
public:
@ -35,24 +35,24 @@ class KernelAttr {
KernelAttr() : all_same_(false) {}
~KernelAttr() = default;
KernelAttr& AddInputAttr(const TypeId& ms_type, const std::string& format = kOpFormat_DEFAULT) {
KernelAttr &AddInputAttr(const TypeId &ms_type, const std::string &format = kOpFormat_DEFAULT) {
input_type_.emplace_back(ms_type, format);
return *this;
}
KernelAttr& AddOutputAttr(const TypeId& ms_type, const std::string& format = kOpFormat_DEFAULT) {
KernelAttr &AddOutputAttr(const TypeId &ms_type, const std::string &format = kOpFormat_DEFAULT) {
output_type_.emplace_back(ms_type, format);
return *this;
}
KernelAttr& AddAllSameAttr(const bool& all_same) {
KernelAttr &AddAllSameAttr(const bool &all_same) {
all_same_ = all_same;
return *this;
}
const DataType& GetInputAttr(const size_t index) const { return input_type_[index]; }
const DataType& GetOutputAttr(const size_t index) const { return output_type_[index]; }
const bool& GetAllSame() const { return all_same_; }
const DataType &GetInputAttr(const size_t index) const { return input_type_[index]; }
const DataType &GetOutputAttr(const size_t index) const { return output_type_[index]; }
const bool &GetAllSame() const { return all_same_; }
size_t GetInputSize() const { return input_type_.size(); }
size_t GetOutputSize() const { return output_type_.size(); }

View File

@ -24,7 +24,7 @@
namespace mindspore {
struct TypeIdManager* TypeIdManager::Get() {
struct TypeIdManager *TypeIdManager::Get() {
static TypeIdManager manager;
return &manager;
}

View File

@ -35,14 +35,14 @@ TypePtr AnfNode::Type() const { return (abstract_ == nullptr) ? nullptr : abstra
BaseShapePtr AnfNode::Shape() const { return (abstract_ == nullptr) ? nullptr : abstract_->BuildShape(); }
std::string AnfNode::ToString() const {
return mindspore::label_manage::Label(const_cast<AnfNode*>(this)->shared_from_base<AnfNode>()->debug_info());
return mindspore::label_manage::Label(const_cast<AnfNode *>(this)->shared_from_base<AnfNode>()->debug_info());
}
CNode::CNode(const std::vector<AnfNodePtr>& inputs, const FuncGraphPtr& func_graph)
CNode::CNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &func_graph)
: AnfNode(func_graph), inputs_(inputs), stop_gradient_(false) {}
// Check if CNode is an apply with the specific Primitive.
bool CNode::IsApply(const PrimitivePtr& value) const {
bool CNode::IsApply(const PrimitivePtr &value) const {
if (value == nullptr) {
return false;
}
@ -57,7 +57,7 @@ bool CNode::IsApply(const PrimitivePtr& value) const {
return false;
}
void CNode::set_input(size_t i, const AnfNodePtr& new_input) { inputs_[i] = new_input; }
void CNode::set_input(size_t i, const AnfNodePtr &new_input) { inputs_[i] = new_input; }
std::string CNode::DebugString(int recursive_level) const {
std::ostringstream buffer;
@ -68,7 +68,7 @@ std::string CNode::DebugString(int recursive_level) const {
buffer << ToString() << "{";
bool is_first_node = true;
int idx = 0;
for (auto& node : inputs_) {
for (auto &node : inputs_) {
MS_EXCEPTION_IF_NULL(node);
if (is_first_node) {
is_first_node = false;
@ -85,7 +85,7 @@ std::string CNode::DebugString(int recursive_level) const {
return buffer.str();
}
OperatorInfoPtr CNode::set_operator_info(const OperatorInfoPtr& operator_info) {
OperatorInfoPtr CNode::set_operator_info(const OperatorInfoPtr &operator_info) {
if (operator_info_ != nullptr) {
MS_LOG(WARNING) << "The CNode: " << ToString() << " has already been set OperatorInfo: " << operator_info_->name()
<< ", using the new one: " << operator_info->name();
@ -173,11 +173,11 @@ std::string ValueNode::fullname_with_scope() {
return fullname_with_scope_;
}
void CNode::accept(AnfVisitor* v) { v->Visit(shared_from_base<CNode>()); }
void ValueNode::accept(AnfVisitor* v) { v->Visit(shared_from_base<ValueNode>()); }
void Parameter::accept(AnfVisitor* v) { v->Visit(shared_from_base<Parameter>()); }
void CNode::accept(AnfVisitor *v) { v->Visit(shared_from_base<CNode>()); }
void ValueNode::accept(AnfVisitor *v) { v->Visit(shared_from_base<ValueNode>()); }
void Parameter::accept(AnfVisitor *v) { v->Visit(shared_from_base<Parameter>()); }
bool IsPrimitiveCNode(const AnfNodePtr& node, const PrimitivePtr& value) {
bool IsPrimitiveCNode(const AnfNodePtr &node, const PrimitivePtr &value) {
MS_EXCEPTION_IF_NULL(node);
auto cnode = node->cast<CNodePtr>();
if (cnode != nullptr) {
@ -186,7 +186,7 @@ bool IsPrimitiveCNode(const AnfNodePtr& node, const PrimitivePtr& value) {
return false;
}
PrimitivePtr GetCNodePrimitive(const AnfNodePtr& node) {
PrimitivePtr GetCNodePrimitive(const AnfNodePtr &node) {
if (node == nullptr) {
return nullptr;
}
@ -217,7 +217,7 @@ std::string GetCNodeFuncName(const CNodePtr cnode) {
return "";
}
bool IsPrimitive(const AnfNodePtr& node, const PrimitivePtr& value) {
bool IsPrimitive(const AnfNodePtr &node, const PrimitivePtr &value) {
if (IsValueNode<Primitive>(node)) {
PrimitivePtr fn_value = GetValueNode<PrimitivePtr>(node);
MS_EXCEPTION_IF_NULL(value);
@ -229,7 +229,7 @@ bool IsPrimitive(const AnfNodePtr& node, const PrimitivePtr& value) {
}
namespace id_generator {
static std::unordered_map<std::string, int> node_ids;
std::string get_id(const AnfNodePtr& node) {
std::string get_id(const AnfNodePtr &node) {
auto type_name = node->type_name();
if (node_ids.find(type_name) == node_ids.end()) {
node_ids[type_name] = 0;

View File

@ -39,15 +39,15 @@ struct is_shared_ptr<std::shared_ptr<T>> : public std::true_type {};
class Base : public std::enable_shared_from_this<Base> {
public:
constexpr Base() = default;
Base(const Base& other) : std::enable_shared_from_this<Base>(other) {}
virtual bool operator==(const Base& rhs) {
Base(const Base &other) : std::enable_shared_from_this<Base>(other) {}
virtual bool operator==(const Base &rhs) {
if (this == &rhs) {
return true;
}
return false;
}
virtual Base& operator=(const Base&) { return *this; }
virtual Base &operator=(const Base &) { return *this; }
virtual ~Base() = default;
virtual std::size_t hash() const { return tid(); }
virtual std::string ToString() const { return type_name(); }
@ -57,14 +57,14 @@ class Base : public std::enable_shared_from_this<Base> {
virtual const bool IsFromTypeId(uint32_t tid) const;
virtual std::string type_name() const { return "Base"; }
static uint32_t GetTypeId(const char* const type_key);
static uint32_t GetTypeId(const char *const type_key);
virtual uint32_t tid() const {
static const uint32_t tid = GetTypeId(typeid(Base).name());
return tid;
}
template <typename T,
typename std::enable_if<!is_shared_ptr<T>::value && std::is_base_of<Base, T>::value, T>::type* = nullptr>
typename std::enable_if<!is_shared_ptr<T>::value && std::is_base_of<Base, T>::value, T>::type * = nullptr>
inline bool isa() const {
static const uint32_t tid = GetTypeId(typeid(T).name());
return this->IsFromTypeId(tid);
@ -90,9 +90,9 @@ using BasePtr = std::shared_ptr<Base>;
using BaseWeakPtr = std::weak_ptr<Base>;
template <typename T, typename U>
inline T* cast(U* source) {
inline T *cast(U *source) {
if (source != nullptr && source->template isa<T>()) {
return static_cast<T*>(source);
return static_cast<T *>(source);
} else {
return nullptr;
}
@ -100,7 +100,7 @@ inline T* cast(U* source) {
template <
typename T, typename U,
typename std::enable_if<std::is_base_of<Base, T>::value && std::is_base_of<Base, U>::value, T>::type* = nullptr>
typename std::enable_if<std::is_base_of<Base, T>::value && std::is_base_of<Base, U>::value, T>::type * = nullptr>
inline std::shared_ptr<T> dyn_cast(const std::shared_ptr<U> r) {
if (r != nullptr && r->template isa<T>()) {
return std::static_pointer_cast<T>(r);
@ -143,7 +143,7 @@ struct MS_EXPORT TypeIdManager {
std::mutex mutex;
std::atomic<uint32_t> type_counter{0};
std::unordered_map<std::string, uint32_t> map;
static TypeIdManager* Get();
static TypeIdManager *Get();
TypeIdManager() : mutex(), type_counter(0), map() {}
};
} // namespace mindspore

View File

@ -48,11 +48,11 @@ std::string Keyword::ToString() const {
return buffer.str();
}
bool Keyword::operator==(const Type& other) const {
bool Keyword::operator==(const Type &other) const {
if (!IsSameObjectType(*this, other)) {
return false;
}
const auto& other_keyword = static_cast<const Keyword&>(other);
const auto &other_keyword = static_cast<const Keyword &>(other);
return (other_keyword.key_ == key_ && *other_keyword.value_ == *value_);
}
@ -87,11 +87,11 @@ std::string Slice::ToString() const {
return buffer.str();
}
bool Slice::operator==(const Type& other) const {
bool Slice::operator==(const Type &other) const {
if (!IsSameObjectType(*this, other)) {
return false;
}
auto other_slice = static_cast<const Slice&>(other);
auto other_slice = static_cast<const Slice &>(other);
return (*start_ == *other_slice.start_ && *stop_ == *other_slice.stop_ && *step_ == *other_slice.step_);
}
@ -122,11 +122,11 @@ std::string TensorType::DumpText() const {
}
}
bool TensorType::operator==(const Type& other) const {
bool TensorType::operator==(const Type &other) const {
if (!IsSameObjectType(*this, other)) {
return false;
}
auto other_elem_type = static_cast<const TensorType&>(other).element_type_;
auto other_elem_type = static_cast<const TensorType &>(other).element_type_;
// When element_type_ = nullptr, which means any type of Array.
if (element_type_ == nullptr && other_elem_type == nullptr) {
return true;
@ -141,7 +141,7 @@ Function::Function() : Object(kObjectTypeFunction) {
retval_ = nullptr;
}
Function::Function(const std::vector<TypePtr>& args, const TypePtr retval)
Function::Function(const std::vector<TypePtr> &args, const TypePtr retval)
: Object(kObjectTypeFunction, false), args_(args), retval_(retval) {}
TypePtr Function::DeepCopy() const {
@ -151,7 +151,7 @@ TypePtr Function::DeepCopy() const {
TypePtrList args;
TypePtr retval = nullptr;
(void)std::transform(args_.begin(), args_.end(), std::back_inserter(args),
[](const TypePtr& arg) { return arg->DeepCopy(); });
[](const TypePtr &arg) { return arg->DeepCopy(); });
if (retval_ != nullptr) {
retval = retval_->DeepCopy();
}
@ -159,12 +159,12 @@ TypePtr Function::DeepCopy() const {
}
}
bool Function::operator==(const Type& other) const {
bool Function::operator==(const Type &other) const {
if (!IsSameObjectType(*this, other)) {
return false;
}
const auto& other_function = static_cast<const Function&>(other);
const auto &other_function = static_cast<const Function &>(other);
if ((retval_ != nullptr) && (other_function.retval_ != nullptr)) {
if (*retval_ != *other_function.retval_) {
return false;
@ -188,7 +188,7 @@ std::string Function::ToString() const {
} else {
buffer << "Func[(";
bool begin = true;
for (auto& attr : args_) {
for (auto &attr : args_) {
if (!begin) {
buffer << ", ";
} else {
@ -242,34 +242,34 @@ std::string JTagged::DumpText() const {
return buffer.str();
}
std::ostream& operator<<(std::ostream& os, const std::shared_ptr<Problem> problem) {
std::ostream &operator<<(std::ostream &os, const std::shared_ptr<Problem> problem) {
MS_EXCEPTION_IF_NULL(problem);
os << problem->ToString();
return os;
}
std::size_t TypeHasher::operator()(TypePtr const& type) const {
std::size_t TypeHasher::operator()(TypePtr const &type) const {
MS_EXCEPTION_IF_NULL(type);
std::size_t hash = std::hash<size_t>()(type->type_id());
return hash;
}
std::size_t TypeListHasher::operator()(const TypePtrList& type_list) const {
std::size_t TypeListHasher::operator()(const TypePtrList &type_list) const {
std::size_t hash_sum = 0;
for (auto& type : type_list) {
for (auto &type : type_list) {
auto type_id = static_cast<std::size_t>(type->type_id());
hash_sum = hash_combine(hash_sum, type_id);
}
return hash_sum;
}
bool TypeEqual::operator()(TypePtr const& t1, TypePtr const& t2) const {
bool TypeEqual::operator()(TypePtr const &t1, TypePtr const &t2) const {
MS_EXCEPTION_IF_NULL(t1);
MS_EXCEPTION_IF_NULL(t2);
return t1->type_id() == t2->type_id();
}
bool TypeListEqual::operator()(TypePtrList const& lhs, TypePtrList const& rhs) const {
bool TypeListEqual::operator()(TypePtrList const &lhs, TypePtrList const &rhs) const {
if (lhs.size() != rhs.size()) {
return false;
}
@ -332,7 +332,7 @@ TypePtr TypeIdToType(TypeId id) {
namespace {
template <typename T>
TypePtr StringToNumberType(const std::string& type_name, const std::string& num_type_name) {
TypePtr StringToNumberType(const std::string &type_name, const std::string &num_type_name) {
TypePtr type = nullptr;
if (type_name == num_type_name) {
type = std::make_shared<T>();
@ -344,14 +344,14 @@ TypePtr StringToNumberType(const std::string& type_name, const std::string& num_
}
auto bits = std::stoi(type_name.substr(num_type_name.size()));
type = std::make_shared<T>(bits);
} catch (const std::exception& e) {
} catch (const std::exception &e) {
MS_LOG(EXCEPTION) << "" << num_type_name << " convert from string error " << e.what();
}
}
return type;
}
std::vector<TypePtr> StringToVectorOfType(const std::string& type_names) {
std::vector<TypePtr> StringToVectorOfType(const std::string &type_names) {
std::vector<TypePtr> types;
if (type_names.length() == 0) {
return types;
@ -371,7 +371,7 @@ std::vector<TypePtr> StringToVectorOfType(const std::string& type_names) {
return types;
}
TypePtr TensorStrToType(const std::string& type_name) {
TypePtr TensorStrToType(const std::string &type_name) {
TypePtr type = nullptr;
if (type_name == "Tensor") {
type = std::make_shared<TensorType>();
@ -388,7 +388,7 @@ TypePtr TensorStrToType(const std::string& type_name) {
return nullptr;
}
type = std::make_shared<TensorType>(element_type);
} catch (const std::exception& e) {
} catch (const std::exception &e) {
MS_LOG(EXCEPTION) << "" << type_name << " convert from string error " << e.what();
}
}
@ -396,7 +396,7 @@ TypePtr TensorStrToType(const std::string& type_name) {
return type;
}
TypePtr ListStrToType(const std::string& type_name) {
TypePtr ListStrToType(const std::string &type_name) {
TypePtr type = nullptr;
if (type_name == "List") {
type = std::make_shared<List>();
@ -410,12 +410,12 @@ TypePtr ListStrToType(const std::string& type_name) {
std::string element_strs = type_name.substr(start, end - start);
std::vector<TypePtr> element_types = StringToVectorOfType(element_strs);
bool wrong =
std::any_of(element_types.begin(), element_types.end(), [](const TypePtr& x) { return x == nullptr; });
std::any_of(element_types.begin(), element_types.end(), [](const TypePtr &x) { return x == nullptr; });
if (wrong) {
return nullptr;
}
type = std::make_shared<List>(element_types);
} catch (const std::exception& e) {
} catch (const std::exception &e) {
MS_LOG(EXCEPTION) << "" << type_name << " convert from string error " << e.what();
}
}
@ -423,7 +423,7 @@ TypePtr ListStrToType(const std::string& type_name) {
return type;
}
TypePtr TupleStrToType(const std::string& type_name) {
TypePtr TupleStrToType(const std::string &type_name) {
TypePtr type = nullptr;
if (type_name == "Tuple") {
type = std::make_shared<Tuple>();
@ -437,19 +437,19 @@ TypePtr TupleStrToType(const std::string& type_name) {
std::string element_strs = type_name.substr(start, end - start);
std::vector<TypePtr> element_types = StringToVectorOfType(element_strs);
bool wrong =
std::any_of(element_types.begin(), element_types.end(), [](const TypePtr& x) { return x == nullptr; });
std::any_of(element_types.begin(), element_types.end(), [](const TypePtr &x) { return x == nullptr; });
if (wrong) {
return nullptr;
}
type = std::make_shared<Tuple>(element_types);
} catch (const std::exception& e) {
} catch (const std::exception &e) {
MS_LOG(EXCEPTION) << "" << type_name << " convert from string error " << e.what();
}
}
return type;
}
TypePtr FunctionStrToType(const std::string& type_name) {
TypePtr FunctionStrToType(const std::string &type_name) {
TypePtr type = nullptr;
if (type_name == "Function") {
@ -478,12 +478,12 @@ TypePtr FunctionStrToType(const std::string& type_name) {
std::vector<TypePtr> args_type = StringToVectorOfType(str_args);
TypePtr retval = StringToType(str_retval);
bool wrong = std::any_of(args_type.begin(), args_type.end(), [](const TypePtr& x) { return x == nullptr; });
bool wrong = std::any_of(args_type.begin(), args_type.end(), [](const TypePtr &x) { return x == nullptr; });
if (retval == nullptr || wrong) {
return nullptr;
}
type = std::make_shared<Function>(args_type, retval);
} catch (const std::exception& e) {
} catch (const std::exception &e) {
MS_LOG(EXCEPTION) << "" << type_name << " convert from string error " << e.what();
}
}
@ -491,7 +491,7 @@ TypePtr FunctionStrToType(const std::string& type_name) {
}
} // namespace
TypePtr StringToType(const std::string& type_name) {
TypePtr StringToType(const std::string &type_name) {
TypePtr type = nullptr;
if (type_name.compare("None") == 0) {
type = std::make_shared<TypeNone>();
@ -542,7 +542,7 @@ TypePtr StringToType(const std::string& type_name) {
return type;
}
bool IsIdentidityOrSubclass(TypePtr const& x, TypePtr const& base_type) {
bool IsIdentidityOrSubclass(TypePtr const &x, TypePtr const &base_type) {
if (x == nullptr || base_type == nullptr) {
MS_LOG(ERROR) << "Type is nullptr.";
return false;
@ -564,7 +564,7 @@ bool IsIdentidityOrSubclass(TypePtr const& x, TypePtr const& base_type) {
}
}
bool IsSubType(TypePtr const& t1, TypePtr const& t2) {
bool IsSubType(TypePtr const &t1, TypePtr const &t2) {
MS_EXCEPTION_IF_NULL(t1);
if (t1->type_id() == kTypeUnknown) {
return false;
@ -576,17 +576,17 @@ bool IsSubType(TypePtr const& t1, TypePtr const& t2) {
}
REGISTER_PYBIND_DEFINE(
typing, ([](py::module* const m) {
typing, ([](py::module *const m) {
auto m_sub = m->def_submodule("typing", "submodule for dtype");
py::enum_<TypeId>(m_sub, "TypeId");
(void)m_sub.def("is_subclass", &IsIdentidityOrSubclass, "is equal or subclass");
(void)m_sub.def("load_type", &TypeIdToType, "load type");
(void)m_sub.def(
"dump_type", [](const TypePtr& t) { return t->type_id(); }, "dump type");
"dump_type", [](const TypePtr &t) { return t->type_id(); }, "dump type");
(void)py::class_<Type, std::shared_ptr<Type>>(m_sub, "Type")
.def_readonly(PYTHON_DTYPE_FLAG, &mindspore::Type::parse_info_)
.def("__eq__",
[](const TypePtr& t1, const TypePtr& t2) {
[](const TypePtr &t1, const TypePtr &t2) {
if (t1 != nullptr && t2 != nullptr) {
return *t1 == *t2;
}
@ -595,7 +595,7 @@ REGISTER_PYBIND_DEFINE(
.def("__hash__", &Type::hash)
.def("__str__", &Type::ToString)
.def("__repr__", &Type::ReprString)
.def("__deepcopy__", [](const TypePtr& t, py::dict) {
.def("__deepcopy__", [](const TypePtr &t, py::dict) {
if (t == nullptr) {
return static_cast<TypePtr>(nullptr);
}
@ -605,21 +605,21 @@ REGISTER_PYBIND_DEFINE(
(void)py::class_<Bool, Type, std::shared_ptr<Bool>>(m_sub, "Bool")
.def(py::init())
.def(py::pickle(
[](const Bool&) { // __getstate__
[](const Bool &) { // __getstate__
return py::make_tuple();
},
[](const py::tuple&) { // __setstate__
[](const py::tuple &) { // __setstate__
return std::make_shared<Bool>();
}));
(void)py::class_<Int, Type, std::shared_ptr<Int>>(m_sub, "Int")
.def(py::init())
.def(py::init<int>(), py::arg("nbits"))
.def(py::pickle(
[](const Int& t) { // __getstate__
[](const Int &t) { // __getstate__
/* Return a tuple that fully encodes the state of the object */
return py::make_tuple(py::int_(t.nbits()));
},
[](const py::tuple& t) { // __setstate__
[](const py::tuple &t) { // __setstate__
if (t.size() != 1) {
throw std::runtime_error("Invalid state!");
}
@ -631,11 +631,11 @@ REGISTER_PYBIND_DEFINE(
.def(py::init())
.def(py::init<int>(), py::arg("nbits"))
.def(py::pickle(
[](const UInt& t) { // __getstate__
[](const UInt &t) { // __getstate__
/* Return a tuple that fully encodes the state of the object */
return py::make_tuple(py::int_(t.nbits()));
},
[](const py::tuple& t) { // __setstate__
[](const py::tuple &t) { // __setstate__
if (t.size() != 1) {
throw std::runtime_error("Invalid state!");
}
@ -647,11 +647,11 @@ REGISTER_PYBIND_DEFINE(
.def(py::init())
.def(py::init<int>(), py::arg("nbits"))
.def(py::pickle(
[](const Float& t) { // __getstate__
[](const Float &t) { // __getstate__
/* Return a tuple that fully encodes the state of the object */
return py::make_tuple(py::int_(t.nbits()));
},
[](const py::tuple& t) { // __setstate__
[](const py::tuple &t) { // __setstate__
if (t.size() != 1) {
throw std::runtime_error("Invalid state!");
}
@ -670,11 +670,11 @@ REGISTER_PYBIND_DEFINE(
.def(py::init<TypePtr>(), py::arg("element"))
.def("element_type", &TensorType::element)
.def(py::pickle(
[](const TensorType& t) { // __getstate__
[](const TensorType &t) { // __getstate__
/* Return a tuple that fully encodes the state of the object */
return py::make_tuple(py::int_(static_cast<int>(t.element()->type_id())));
},
[](const py::tuple& t) { // __setstate__
[](const py::tuple &t) { // __setstate__
if (t.size() != 1) {
throw std::runtime_error("Invalid state!");
}

View File

@ -60,7 +60,7 @@ using StringPtr = std::shared_ptr<String>;
class Keyword : public Object {
public:
Keyword() : Object(kObjectTypeKeyword, false), key_(""), value_(nullptr) {}
Keyword(const std::string& key, const TypePtr& value) : Object(kObjectTypeKeyword, false), key_(key), value_(value) {}
Keyword(const std::string &key, const TypePtr &value) : Object(kObjectTypeKeyword, false), key_(key), value_(value) {}
~Keyword() override = default;
MS_DECLARE_PARENT(Keyword, Object)
@ -70,7 +70,7 @@ class Keyword : public Object {
std::string ToString() const override;
std::string DumpText() const override;
bool operator==(const Type& other) const override;
bool operator==(const Type &other) const override;
std::string GetKey() const { return key_; }
TypePtr GetValue() const { return value_; }
@ -84,7 +84,7 @@ using KeywordPtr = std::shared_ptr<Keyword>;
class Slice : public Object {
public:
Slice() : Object(kObjectTypeSlice), start_(nullptr), stop_(nullptr), step_(nullptr) {}
Slice(const TypePtr& start, const TypePtr& stop, const TypePtr& step)
Slice(const TypePtr &start, const TypePtr &stop, const TypePtr &step)
: Object(kObjectTypeSlice, false), start_(start), stop_(stop), step_(step) {}
~Slice() override = default;
@ -95,7 +95,7 @@ class Slice : public Object {
std::string ToString() const override;
std::string DumpText() const override;
bool operator==(const Type& other) const override;
bool operator==(const Type &other) const override;
TypePtr get_start() const { return start_; }
TypePtr get_stop() const { return stop_; }
@ -111,19 +111,19 @@ using SlicePtr = std::shared_ptr<Slice>;
class TensorType : public Object {
public:
TensorType() : Object(kObjectTypeTensorType) {}
explicit TensorType(const TypePtr& ele) : Object(kObjectTypeTensorType, false), element_type_(ele) {}
explicit TensorType(const TypePtr &ele) : Object(kObjectTypeTensorType, false), element_type_(ele) {}
~TensorType() override = default;
MS_DECLARE_PARENT(TensorType, Object)
TypeId generic_type_id() const override { return kObjectTypeTensorType; }
const TypePtr element() const { return element_type_; }
void set_element(const TypePtr& element_type) { element_type_ = element_type; }
void set_element(const TypePtr &element_type) { element_type_ = element_type; }
TypePtr DeepCopy() const override;
std::string ToString() const override;
std::string ToReprString() const override { return "tensor"; }
std::string DumpText() const override;
bool operator==(const Type& other) const override;
bool operator==(const Type &other) const override;
private:
TypePtr element_type_;
@ -133,7 +133,7 @@ using TensorTypePtr = std::shared_ptr<TensorType>;
class Function : public Object {
public:
Function();
Function(const std::vector<TypePtr>& args, const TypePtr retval);
Function(const std::vector<TypePtr> &args, const TypePtr retval);
~Function() override = default;
MS_DECLARE_PARENT(Function, Object)
@ -141,11 +141,11 @@ class Function : public Object {
// Add temporarily for return abstraction to avoid type checking.
bool IsTransparent() const { return (args_.empty()) && (retval_ == nullptr); }
const std::vector<TypePtr>& args() const { return args_; }
const TypePtr& retval() const { return retval_; }
const std::vector<TypePtr> &args() const { return args_; }
const TypePtr &retval() const { return retval_; }
TypePtr DeepCopy() const override;
bool operator==(const Type& other) const override;
bool operator==(const Type &other) const override;
std::string ToString() const override;
std::string ToReprString() const override { return "function"; }
@ -158,7 +158,7 @@ using FunctionPtr = std::shared_ptr<Function>;
class JTagged : public Object {
public:
JTagged() : Object(kObjectTypeJTagged) {}
explicit JTagged(const TypePtr& subtype) : Object(kObjectTypeJTagged, false), subtype_(subtype) {}
explicit JTagged(const TypePtr &subtype) : Object(kObjectTypeJTagged, false), subtype_(subtype) {}
~JTagged() override = default;
MS_DECLARE_PARENT(JTagged, Object)
@ -213,7 +213,7 @@ using TypeTypePtr = std::shared_ptr<TypeType>;
class Problem : public Type {
public:
Problem() : Type(kMetaTypeProblem), kind_(Named("unknown")) {}
explicit Problem(const Named& kind) : Type(kMetaTypeProblem), kind_(kind) {}
explicit Problem(const Named &kind) : Type(kMetaTypeProblem), kind_(kind) {}
~Problem() override = default;
MS_DECLARE_PARENT(Problem, Type)
@ -222,7 +222,7 @@ class Problem : public Type {
std::string ToString() const override { return kind_.name(); }
std::string DumpText() const override { return "ProblemType"; }
friend std::ostream& operator<<(std::ostream& os, const std::shared_ptr<Problem> problem);
friend std::ostream &operator<<(std::ostream &os, const std::shared_ptr<Problem> problem);
private:
Named kind_;
@ -246,29 +246,29 @@ using ExternalPtr = std::shared_ptr<External>;
// helper template
template <class T>
TypePtr Clone(const T& t) {
TypePtr Clone(const T &t) {
return t.Clone();
}
TypePtr StringToType(const std::string& type_name);
TypePtr StringToType(const std::string &type_name);
// Judge whether x is predicate or is a subclass of predicate.
bool IsIdentidityOrSubclass(TypePtr const& x, TypePtr const& base_type);
bool IsIdentidityOrSubclass(TypePtr const &x, TypePtr const &base_type);
// Whether t1 is identity or a subclass of t2.
bool IsSubType(TypePtr const& t1, TypePtr const& t2 = nullptr);
bool IsSubType(TypePtr const &t1, TypePtr const &t2 = nullptr);
struct TypeHasher {
std::size_t operator()(TypePtr const& type) const;
std::size_t operator()(TypePtr const &type) const;
};
struct TypeListHasher {
std::size_t operator()(const TypePtrList& type_list) const;
std::size_t operator()(const TypePtrList &type_list) const;
};
struct TypeEqual {
bool operator()(TypePtr const& t1, TypePtr const& t2) const;
bool operator()(TypePtr const &t1, TypePtr const &t2) const;
};
struct TypeListEqual {
bool operator()(TypePtrList const& lhs, TypePtrList const& rhs) const;
bool operator()(TypePtrList const &lhs, TypePtrList const &rhs) const;
};
extern const TypePtr kTypeExternal;

View File

@ -24,7 +24,7 @@
#include "pybind_api/export_flags.h"
namespace mindspore {
static std::string DumpTypeVector(const std::vector<TypePtr>& elements, bool is_dumptext) {
static std::string DumpTypeVector(const std::vector<TypePtr> &elements, bool is_dumptext) {
std::ostringstream oss;
bool begin = true;
int cnt = 0;
@ -65,7 +65,7 @@ TypePtr List::DeepCopy() const {
} else {
TypePtrList elements;
(void)std::transform(elements_.begin(), elements_.end(), std::back_inserter(elements),
[](const TypePtr& ele) { return ele->DeepCopy(); });
[](const TypePtr &ele) { return ele->DeepCopy(); });
auto copy = std::make_shared<List>(elements);
return copy;
}
@ -78,11 +78,11 @@ const TypePtr List::operator[](std::size_t dim) const {
return elements_[dim];
}
bool List::operator==(const Type& other) const {
bool List::operator==(const Type &other) const {
if (!IsSameObjectType(*this, other)) {
return false;
}
const List& other_list = static_cast<const List&>(other);
const List &other_list = static_cast<const List &>(other);
if (elements_.size() != other_list.elements_.size()) {
return false;
}
@ -94,8 +94,8 @@ bool List::operator==(const Type& other) const {
return true;
}
Class::Class(const Named& tag, const ClassAttrVector& attributes,
const std::unordered_map<std::string, ValuePtr>& methods)
Class::Class(const Named &tag, const ClassAttrVector &attributes,
const std::unordered_map<std::string, ValuePtr> &methods)
: Object(kObjectTypeClass, false), attributes_(attributes), tag_(tag), methods_(methods) {}
std::string List::ToString() const {
@ -122,7 +122,7 @@ std::string List::DumpText() const {
return buffer.str();
}
bool Class::operator==(const Type& other) const {
bool Class::operator==(const Type &other) const {
// Class is cached for each pyobj in ParseDataClass, so ClassPtr is one by one map to pyobj.
return &other == this;
}
@ -143,7 +143,7 @@ std::string Class::ToString() const {
} else {
bool begin = true;
buffer << "cls." << tag_ << "[";
for (auto& attr : attributes_) {
for (auto &attr : attributes_) {
if (!begin) {
buffer << ", ";
} else {
@ -163,7 +163,7 @@ std::string Class::DumpText() const {
} else {
bool begin = true;
buffer << "Cls." << tag_ << "[";
for (auto& attr : attributes_) {
for (auto &attr : attributes_) {
if (!begin) {
buffer << ", ";
} else {
@ -182,17 +182,17 @@ TypePtr Tuple::DeepCopy() const {
} else {
TypePtrList elements;
(void)std::transform(elements_.begin(), elements_.end(), std::back_inserter(elements),
[](const TypePtr& ele) { return ele->DeepCopy(); });
[](const TypePtr &ele) { return ele->DeepCopy(); });
auto copy = std::make_shared<Tuple>(elements);
return copy;
}
}
bool Tuple::operator==(const Type& other) const {
bool Tuple::operator==(const Type &other) const {
if (!IsSameObjectType(*this, other)) {
return false;
}
auto other_tuple = static_cast<const Tuple&>(other);
auto other_tuple = static_cast<const Tuple &>(other);
if (elements_.size() != other_tuple.elements_.size()) {
return false;
}
@ -242,7 +242,7 @@ TypePtr Dictionary::DeepCopy() const {
std::vector<std::pair<std::string, TypePtr>> kv;
(void)std::transform(
key_values_.begin(), key_values_.end(), std::back_inserter(kv),
[](const std::pair<std::string, TypePtr>& item) { return std::make_pair(item.first, item.second->DeepCopy()); });
[](const std::pair<std::string, TypePtr> &item) { return std::make_pair(item.first, item.second->DeepCopy()); });
return std::make_shared<Dictionary>(kv);
}
}
@ -259,7 +259,7 @@ std::string Dictionary::ToString() const {
std::ostringstream buffer;
std::vector<std::string> keys;
std::vector<TypePtr> values;
for (const auto& kv : key_values_) {
for (const auto &kv : key_values_) {
keys.push_back(kv.first);
values.push_back(kv.second);
}
@ -276,12 +276,12 @@ std::string Dictionary::ToString() const {
std::string Dictionary::DumpText() const { return ToString(); }
bool Dictionary::operator==(const mindspore::Type& other) const {
bool Dictionary::operator==(const mindspore::Type &other) const {
if (!IsSameObjectType(*this, other)) {
return false;
}
const auto& other_dict = static_cast<const Dictionary&>(other);
const auto &other_dict = static_cast<const Dictionary &>(other);
if (key_values_.size() != other_dict.key_values_.size()) {
return false;
}

View File

@ -40,10 +40,10 @@ namespace mindspore {
class List : public Object {
public:
List() : Object(kObjectTypeList) {}
List(const std::initializer_list<TypePtr>& objs)
List(const std::initializer_list<TypePtr> &objs)
: Object(kObjectTypeList, false), elements_(objs.begin(), objs.end()) {}
// Shadow copy;
explicit List(const TypePtrList& obj) : Object(kObjectTypeList, false), elements_(obj) {}
explicit List(const TypePtrList &obj) : Object(kObjectTypeList, false), elements_(obj) {}
~List() override {}
MS_DECLARE_PARENT(List, Object)
@ -51,7 +51,7 @@ class List : public Object {
TypeId generic_type_id() const override { return kObjectTypeList; }
TypePtr DeepCopy() const override;
bool operator==(const Type& other) const override;
bool operator==(const Type &other) const override;
std::size_t size() const { return elements_.size(); }
TypePtrList elements() const { return elements_; }
std::string ToString() const override;
@ -68,22 +68,22 @@ using ClassAttrVector = std::vector<std::pair<std::string, TypePtr>>;
class Class : public Object {
public:
Class() : Object(kObjectTypeClass), tag_(Named("Class")) {}
Class(const Named& tag, const ClassAttrVector& attributes, const std::unordered_map<std::string, ValuePtr>& methods);
Class(const Named &tag, const ClassAttrVector &attributes, const std::unordered_map<std::string, ValuePtr> &methods);
~Class() override {}
MS_DECLARE_PARENT(Class, Object)
TypeId generic_type_id() const override { return kObjectTypeClass; }
bool operator==(const Type& other) const override;
bool operator==(const Type &other) const override;
TypePtr DeepCopy() const override;
std::string ToString() const override;
std::string DumpText() const override;
void set_value(const std::unordered_map<std::string, ValuePtr>& v) { attributes_value_ = v; }
void set_value(const std::unordered_map<std::string, ValuePtr> &v) { attributes_value_ = v; }
Named tag() { return tag_; }
std::unordered_map<std::string, ValuePtr> GetValue() { return attributes_value_; }
std::unordered_map<std::string, ValuePtr> methods() { return methods_; }
ClassAttrVector& GetAttributes() { return attributes_; }
ClassAttrVector &GetAttributes() { return attributes_; }
ClassAttrVector attributes_;
@ -99,11 +99,11 @@ class Tuple : public Object {
public:
Tuple() : Object(kObjectTypeTuple) {}
// usage : Tuple t = {std::make_shared<Bool>(), std::make_shared<Int>(32)};
Tuple(const std::initializer_list<TypePtr>& objs)
Tuple(const std::initializer_list<TypePtr> &objs)
: Object(kObjectTypeTuple, false), elements_(objs.begin(), objs.end()) {}
// Shadow copy
explicit Tuple(const TypePtrList& objs) : Object(kObjectTypeTuple, false), elements_(objs.begin(), objs.end()) {}
explicit Tuple(const TypePtrList &objs) : Object(kObjectTypeTuple, false), elements_(objs.begin(), objs.end()) {}
~Tuple() override {}
MS_DECLARE_PARENT(Tuple, Object)
@ -115,7 +115,7 @@ class Tuple : public Object {
std::string ToReprString() const override { return "tuple_"; }
std::string DumpText() const override;
const TypePtr operator[](size_t dim) const;
bool operator==(const Type& other) const override;
bool operator==(const Type &other) const override;
TypePtrList elements() const { return elements_; }
std::size_t size() const { return elements_.size(); }
@ -128,7 +128,7 @@ using TuplePtr = std::shared_ptr<Tuple>;
class Dictionary : public Object {
public:
Dictionary() : Object(kObjectTypeDictionary) {}
explicit Dictionary(const std::vector<std::pair<std::string, TypePtr>>& key_values)
explicit Dictionary(const std::vector<std::pair<std::string, TypePtr>> &key_values)
: Object(kObjectTypeDictionary, false), key_values_(key_values) {}
~Dictionary() override {}
@ -136,7 +136,7 @@ class Dictionary : public Object {
TypeId generic_type_id() const override { return kObjectTypeDictionary; }
bool operator==(const Type& other) const override;
bool operator==(const Type &other) const override;
TypePtr DeepCopy() const override;
std::string ToString() const override;
std::string DumpText() const override;

View File

@ -24,11 +24,11 @@
#include "pybind_api/export_flags.h"
namespace mindspore {
bool Number::operator==(const Type& other) const {
bool Number::operator==(const Type &other) const {
if (!IsSameObjectType(*this, other)) {
return false;
}
auto other_number = static_cast<const Number&>(other);
auto other_number = static_cast<const Number &>(other);
return ((number_type_ == other_number.number_type_) && (nbits_ == other_number.nbits_));
}

View File

@ -49,12 +49,12 @@ class Number : public Object {
TypeId type_id() const override { return number_type_; }
TypeId generic_type_id() const override { return kObjectTypeNumber; }
bool operator==(const Type& other) const override;
bool operator==(const Type &other) const override;
TypePtr DeepCopy() const override { return std::make_shared<Number>(); }
std::string ToString() const override { return "Number"; }
std::string ToReprString() const override { return "number"; }
std::string DumpText() const override { return "Number"; }
std::string GetTypeName(const std::string& type_name) const {
std::string GetTypeName(const std::string &type_name) const {
std::ostringstream oss;
oss << type_name;
if (nbits() != 0) {

View File

@ -51,7 +51,7 @@ class RefKeyType : public Object {
class RefType : public Object {
public:
RefType() : Object(kObjectTypeRef) {}
RefType(const TypePtr& subtype, const TypePtr& subtype_origin)
RefType(const TypePtr &subtype, const TypePtr &subtype_origin)
: Object(kObjectTypeRef, false), subtype_(subtype), subtype_origin_(subtype_origin) {}
~RefType() override {}
MS_DECLARE_PARENT(RefType, Object)

View File

@ -69,7 +69,7 @@ TypeId FloatBitsToTypeId(const int nbits) {
}
}
const char* MetaIdLabel(const TypeId& v) {
const char *MetaIdLabel(const TypeId &v) {
switch (v) {
case kTypeUnknown:
return "kTypeUnknown";
@ -92,7 +92,7 @@ const char* MetaIdLabel(const TypeId& v) {
}
}
const char* ObjectIdLabel(const TypeId& v) {
const char *ObjectIdLabel(const TypeId &v) {
switch (v) {
case kObjectTypeNumber:
return "kObjectTypeNumber";
@ -129,7 +129,7 @@ const char* ObjectIdLabel(const TypeId& v) {
}
}
const char* NumberIdLabel(const TypeId& v) {
const char *NumberIdLabel(const TypeId &v) {
switch (v) {
case kNumberTypeBool:
return "kNumberTypeBool";
@ -166,7 +166,7 @@ const char* NumberIdLabel(const TypeId& v) {
}
}
const char* TypeIdLabel(const TypeId& v) {
const char *TypeIdLabel(const TypeId &v) {
if (v < kMetaTypeEnd) {
return MetaIdLabel(v);
} else {
@ -190,14 +190,14 @@ TypeId NormalizeTypeId(const TypeId type_id) {
}
}
bool IsSameObjectType(const Type& lhs, const Type& rhs) {
bool IsSameObjectType(const Type &lhs, const Type &rhs) {
if ((lhs.meta_type() != kMetaTypeObject) || (rhs.meta_type() != kMetaTypeObject)) {
return false;
}
return lhs.object_type() == rhs.object_type();
}
size_t GetTypeByte(const TypePtr& type_ptr) {
size_t GetTypeByte(const TypePtr &type_ptr) {
if (type_ptr && type_ptr->isa<Number>()) {
auto number = dyn_cast<Number>(type_ptr);
if (!number) {
@ -212,9 +212,9 @@ size_t GetTypeByte(const TypePtr& type_ptr) {
}
}
bool Type::operator==(const Value& other) const {
bool Type::operator==(const Value &other) const {
if (other.isa<Type>()) {
auto other_type = static_cast<const Type*>(&other);
auto other_type = static_cast<const Type *>(&other);
return *this == *other_type;
} else {
return false;
@ -226,12 +226,12 @@ abstract::AbstractBasePtr Type::ToAbstract() {
return ptr;
}
std::ostream& operator<<(std::ostream& os, const Type& type) {
std::ostream &operator<<(std::ostream &os, const Type &type) {
os << type.ToString();
return os;
}
std::ostream& operator<<(std::ostream& os, const TypePtr type) {
std::ostream &operator<<(std::ostream &os, const TypePtr type) {
os << type->ToString();
return os;
}
@ -244,17 +244,17 @@ bool Object::equal(const TypePtr other) const {
return false;
}
std::ostream& operator<<(std::ostream& os, const Object& obj) {
std::ostream &operator<<(std::ostream &os, const Object &obj) {
os << obj.ToString();
return os;
}
std::ostream& operator<<(std::ostream& os, const std::shared_ptr<Object> obj) {
std::ostream &operator<<(std::ostream &os, const std::shared_ptr<Object> obj) {
os << obj->ToString();
return os;
}
std::ostream& operator<<(std::ostream& os, const TypePtrList& types) {
std::ostream &operator<<(std::ostream &os, const TypePtrList &types) {
os << "[";
for (size_t i = 0; i < types.size(); ++i) {
if (i > 0) {

View File

@ -95,10 +95,10 @@ enum TypeId : int {
TypeId IntBitsToTypeId(const int nbits);
TypeId UIntBitsToTypeId(const int nbits);
TypeId FloatBitsToTypeId(const int nbits);
const char* TypeIdLabel(const TypeId& v);
const char *TypeIdLabel(const TypeId &v);
TypeId NormalizeTypeId(const TypeId type_id);
bool IsSameObjectType(const Type& lhs, const Type& rhs);
size_t GetTypeByte(const TypePtr& type_ptr);
bool IsSameObjectType(const Type &lhs, const Type &rhs);
size_t GetTypeByte(const TypePtr &type_ptr);
// Base class for all types
// forward declaration.
@ -110,14 +110,14 @@ class Type : public Value {
~Type() override = default;
MS_DECLARE_PARENT(Type, Value)
bool operator==(const Value& other) const override;
bool operator==(const Value &other) const override;
TypeId meta_type() const { return meta_type_; }
virtual TypeId type_id() const { return meta_type_; }
virtual TypeId generic_type_id() const { return kMetaTypeType; }
virtual bool operator!=(const Type& other) const { return !(*this == other); }
virtual bool operator==(const Type& other) const { return this->type_id() == other.type_id(); }
virtual bool operator!=(const Type &other) const { return !(*this == other); }
virtual bool operator==(const Type &other) const { return this->type_id() == other.type_id(); }
virtual bool equal(const TypePtr other) const { return *this == *other; }
virtual TypeId object_type() const { return kTypeUnknown; }
@ -134,8 +134,8 @@ class Type : public Value {
bool IsUnknown() const { return (meta_type_ == kMetaTypeType); }
bool IsGeneric() const { return is_generic_; }
abstract::AbstractBasePtr ToAbstract() override;
friend std::ostream& operator<<(std::ostream& os, const Type& type);
friend std::ostream& operator<<(std::ostream& os, const TypePtr type);
friend std::ostream &operator<<(std::ostream &os, const Type &type);
friend std::ostream &operator<<(std::ostream &os, const TypePtr type);
const bool parse_info_ = true;
@ -163,14 +163,14 @@ class Object : public Type {
bool equal(const TypePtr other) const override;
std::string ToString() const override { return std::string("Object:") + TypeIdLabel(object_type_); }
friend std::ostream& operator<<(std::ostream& os, const Object& obj);
friend std::ostream& operator<<(std::ostream& os, const std::shared_ptr<Object> obj);
friend std::ostream &operator<<(std::ostream &os, const Object &obj);
friend std::ostream &operator<<(std::ostream &os, const std::shared_ptr<Object> obj);
private:
const TypeId object_type_;
};
std::ostream& operator<<(std::ostream& os, const TypePtrList& types);
std::ostream &operator<<(std::ostream &os, const TypePtrList &types);
} // namespace mindspore
#endif // MINDSPORE_CCSRC_IR_DTYPE_TYPE_H_

View File

@ -61,7 +61,7 @@ FuncGraph::FuncGraph()
AbstractFunctionPtr FuncGraph::abstract() {
AbstractBasePtrList args_spec_list;
for (auto& p : parameters_) {
for (auto &p : parameters_) {
MS_EXCEPTION_IF_NULL(p);
if (p->abstract() == nullptr) {
MS_LOG(ERROR) << "Error!!";
@ -78,7 +78,7 @@ AbstractFunctionPtr FuncGraph::abstract() {
return std::make_shared<VirtualAbstractClosure>(args_spec_list, output()->abstract());
}
abstract::AbstractBasePtr FuncGraph::MakeAbstractClosure(const abstract::AnalysisContextPtr& context) {
abstract::AbstractBasePtr FuncGraph::MakeAbstractClosure(const abstract::AnalysisContextPtr &context) {
AnalysisContextPtr temp_context = context;
if (temp_context == nullptr) {
temp_context = abstract::AnalysisContext::DummyContext();
@ -96,7 +96,7 @@ AnfNodePtr FuncGraph::output() const {
}
}
void FuncGraph::set_output(const AnfNodePtr& value, bool force_new_ret) {
void FuncGraph::set_output(const AnfNodePtr &value, bool force_new_ret) {
if (force_new_ret || return_ == nullptr) {
std::vector<AnfNodePtr> params({NewValueNode(prim::kPrimReturn), value});
FuncGraphPtr this_graph = shared_from_base<FuncGraph>();
@ -125,7 +125,7 @@ ParameterPtr FuncGraph::add_parameter() {
return p;
}
void FuncGraph::add_parameter(const ParameterPtr& p) {
void FuncGraph::add_parameter(const ParameterPtr &p) {
if (manager_.lock()) {
std::vector<AnfNodePtr> new_params = parameters_;
new_params.push_back(p);
@ -135,7 +135,7 @@ void FuncGraph::add_parameter(const ParameterPtr& p) {
}
}
ParameterPtr FuncGraph::AddWeightParameter(const std::string& name) {
ParameterPtr FuncGraph::AddWeightParameter(const std::string &name) {
FuncGraphPtr this_graph = shared_from_base<FuncGraph>();
ParameterPtr p = std::make_shared<Parameter>(this_graph);
p->set_name(name);
@ -154,14 +154,14 @@ ParameterPtr FuncGraph::AddWeightParameter(const std::string& name) {
return p;
}
bool FuncGraph::has_flag(const std::string& flag) {
bool FuncGraph::has_flag(const std::string &flag) {
if (flags_.count(flag)) {
return flags_[flag];
}
return false;
}
CNodePtr FuncGraph::NewCNode(const std::vector<AnfNodePtr>& inputs) {
CNodePtr FuncGraph::NewCNode(const std::vector<AnfNodePtr> &inputs) {
CNodePtr cnode = std::make_shared<CNode>(inputs, shared_from_base<FuncGraph>());
if (has_flag(GRAPH_FLAG_HAS_EFFECT)) {
order_.push_back(cnode);
@ -170,7 +170,7 @@ CNodePtr FuncGraph::NewCNode(const std::vector<AnfNodePtr>& inputs) {
return cnode;
}
CNodePtr FuncGraph::NewCNodeWithScope(const std::vector<AnfNodePtr>& inputs, const ScopePtr& scope) {
CNodePtr FuncGraph::NewCNodeWithScope(const std::vector<AnfNodePtr> &inputs, const ScopePtr &scope) {
CNodePtr app = NewCNode(inputs);
app->set_scope(scope);
return app;
@ -178,13 +178,13 @@ CNodePtr FuncGraph::NewCNodeWithScope(const std::vector<AnfNodePtr>& inputs, con
void FuncGraph::DumpCNodeList() {
MS_LOG(INFO) << "FuncGraph " << ToString() << " has following CNode in code order:";
for (const auto& cnode : order_) {
for (const auto &cnode : order_) {
MS_LOG(INFO) << cnode->DebugString();
}
}
std::string FuncGraph::ToString() const {
return mindspore::label_manage::Label(const_cast<FuncGraph*>(this)->shared_from_base<FuncGraph>()->debug_info());
return mindspore::label_manage::Label(const_cast<FuncGraph *>(this)->shared_from_base<FuncGraph>()->debug_info());
}
GraphDebugInfoPtr FuncGraph::debug_info() {
@ -195,38 +195,38 @@ GraphDebugInfoPtr FuncGraph::debug_info() {
return this->debug_info_;
}
const AnfNodeSet& FuncGraph::nodes() {
const AnfNodeSet &FuncGraph::nodes() {
auto mng = manager_.lock();
MS_EXCEPTION_IF_NULL(mng);
auto& nodes = mng->nodes();
auto &nodes = mng->nodes();
return nodes[shared_from_base<FuncGraph>()];
}
const AnfNodeCounterMap& FuncGraph::value_nodes() {
const AnfNodeCounterMap &FuncGraph::value_nodes() {
auto mng = manager_.lock();
MS_EXCEPTION_IF_NULL(mng);
auto& cts = mng->valuenodes();
auto &cts = mng->valuenodes();
return cts[shared_from_base<FuncGraph>()];
}
const AnfNodeCounterMap& FuncGraph::free_variables_direct() {
const AnfNodeCounterMap &FuncGraph::free_variables_direct() {
auto mng = manager_.lock();
MS_EXCEPTION_IF_NULL(mng);
auto& fv_direct = mng->free_variables_direct();
auto &fv_direct = mng->free_variables_direct();
return fv_direct[shared_from_base<FuncGraph>()];
}
const BaseRefCounterMap& FuncGraph::free_variables_total() {
const BaseRefCounterMap &FuncGraph::free_variables_total() {
auto mng = manager_.lock();
MS_EXCEPTION_IF_NULL(mng);
auto& fv_total = mng->free_variables_total();
auto &fv_total = mng->free_variables_total();
return fv_total[shared_from_base<FuncGraph>()];
}
std::vector<AnfNodePtr> FuncGraph::free_variables_nodes() {
std::vector<AnfNodePtr> nodes;
const auto& fv_total = this->free_variables_total();
for (auto& p : fv_total) {
const auto &fv_total = this->free_variables_total();
for (auto &p : fv_total) {
auto key = p.first;
if (utils::isa<AnfNodePtr>(key)) {
nodes.push_back(utils::cast<AnfNodePtr>(key));
@ -238,8 +238,8 @@ std::vector<AnfNodePtr> FuncGraph::free_variables_nodes() {
std::vector<FuncGraphPtr> FuncGraph::free_variables_func_graphs() {
std::vector<FuncGraphPtr> func_graphs;
const auto& fv_total = this->free_variables_total();
for (auto& p : fv_total) {
const auto &fv_total = this->free_variables_total();
for (auto &p : fv_total) {
auto key = p.first;
if (utils::isa<FuncGraphPtr>(key)) {
func_graphs.push_back(utils::cast<FuncGraphPtr>(key));
@ -249,31 +249,31 @@ std::vector<FuncGraphPtr> FuncGraph::free_variables_func_graphs() {
return func_graphs;
}
const FuncGraphCounterMap& FuncGraph::func_graphs_used() {
const FuncGraphCounterMap &FuncGraph::func_graphs_used() {
auto mng = manager_.lock();
MS_EXCEPTION_IF_NULL(mng);
auto& used = mng->func_graphs_used();
auto &used = mng->func_graphs_used();
return used[shared_from_base<FuncGraph>()];
}
const FuncGraphSet& FuncGraph::func_graphs_used_total() {
const FuncGraphSet &FuncGraph::func_graphs_used_total() {
auto mng = manager_.lock();
MS_EXCEPTION_IF_NULL(mng);
auto& used = mng->func_graphs_used_total(shared_from_base<FuncGraph>());
auto &used = mng->func_graphs_used_total(shared_from_base<FuncGraph>());
return used;
}
const FuncGraphCounterMap& FuncGraph::func_graph_users() {
const FuncGraphCounterMap &FuncGraph::func_graph_users() {
auto mng = manager_.lock();
MS_EXCEPTION_IF_NULL(mng);
auto& users = mng->func_graph_users();
auto &users = mng->func_graph_users();
return users[shared_from_base<FuncGraph>()];
}
const AnfNodeCounterMap& FuncGraph::func_graph_user_cnodes() {
const AnfNodeCounterMap &FuncGraph::func_graph_user_cnodes() {
auto mng = manager_.lock();
MS_EXCEPTION_IF_NULL(mng);
auto& users = mng->func_graph_user_cnodes();
auto &users = mng->func_graph_user_cnodes();
return users[shared_from_base<FuncGraph>()];
}
@ -288,13 +288,13 @@ FuncGraphPtr FuncGraph::parent() {
return mng->parent(shared_from_base<FuncGraph>());
}
const FuncGraphSet& FuncGraph::children() {
const FuncGraphSet &FuncGraph::children() {
auto mng = manager_.lock();
MS_EXCEPTION_IF_NULL(mng);
return mng->children(shared_from_base<FuncGraph>());
}
const FuncGraphSet& FuncGraph::scope() {
const FuncGraphSet &FuncGraph::scope() {
auto mng = manager_.lock();
MS_EXCEPTION_IF_NULL(mng);
return mng->scopes(shared_from_base<FuncGraph>());
@ -312,9 +312,9 @@ std::shared_ptr<std::list<FuncGraphPtr>> FuncGraph::recursive_graphs() {
return mng->recursive_graphs(shared_from_base<FuncGraph>());
}
void FuncGraph::DumpFuncGraph(const std::string& path) { draw::Draw(path + ".dot", shared_from_base<FuncGraph>()); }
void FuncGraph::DumpFuncGraph(const std::string &path) { draw::Draw(path + ".dot", shared_from_base<FuncGraph>()); }
AnfNodePtr FuncGraph::GetDefaultValueByName(const std::string& name) {
AnfNodePtr FuncGraph::GetDefaultValueByName(const std::string &name) {
auto itr = this->parameter_default_value_.find(name);
if (itr == parameter_default_value_.end()) {
return nullptr;
@ -330,9 +330,9 @@ AnfNodePtr FuncGraph::GetDefaultValueByName(const std::string& name) {
}
// set the default values
void FuncGraph::SetDefaultValues(const std::vector<std::string>& name_list, const std::vector<AnfNodePtr>& value_list) {
void FuncGraph::SetDefaultValues(const std::vector<std::string> &name_list, const std::vector<AnfNodePtr> &value_list) {
auto all_is_null = std::all_of(value_list.begin(), value_list.end(),
[](const AnfNodePtr& node) { return IsValueNode<NullObj>(node); });
[](const AnfNodePtr &node) { return IsValueNode<NullObj>(node); });
if (value_list.empty()) {
all_is_null = true;
}
@ -348,7 +348,7 @@ void FuncGraph::ClearDefaultValues() { parameter_default_value_.clear(); }
size_t FuncGraph::GetDefaultValueCount() {
int null_count =
std::count_if(parameter_default_value_.begin(), parameter_default_value_.end(),
[](const std::pair<std::string, AnfNodePtr>& pair) { return IsValueNode<NullObj>(pair.second); });
[](const std::pair<std::string, AnfNodePtr> &pair) { return IsValueNode<NullObj>(pair.second); });
return parameter_default_value_.size() - IntToSize(null_count);
}
@ -425,7 +425,7 @@ int FuncGraph::GetPositionalArgsCount() const {
return count - kwonlyargs_count_ - SizeToInt(hyper_param_count_);
}
AnfNodePtr FuncGraph::GetParameterByName(const std::string& name) {
AnfNodePtr FuncGraph::GetParameterByName(const std::string &name) {
for (size_t i = 0; i < parameters_.size(); ++i) {
MS_EXCEPTION_IF_NULL(parameters_[i]);
auto param_cast = parameters_[i]->cast<ParameterPtr>();
@ -437,9 +437,9 @@ AnfNodePtr FuncGraph::GetParameterByName(const std::string& name) {
return nullptr;
}
void FuncGraph::GenerateVarParams(const FuncGraphPtr& specialized_graph,
std::vector<AnfNodePtr>* specialized_parameter_list,
std::unordered_map<AnfNodePtr, AnfNodePtr>* repl_nodes, int variable_args_count,
void FuncGraph::GenerateVarParams(const FuncGraphPtr &specialized_graph,
std::vector<AnfNodePtr> *specialized_parameter_list,
std::unordered_map<AnfNodePtr, AnfNodePtr> *repl_nodes, int variable_args_count,
int pos_args_input_count) {
// if there is variable argument, pass the input arguments that does not match positional args to it as a tuple
if (specialized_graph->has_vararg()) {
@ -472,14 +472,14 @@ void FuncGraph::GenerateVarParams(const FuncGraphPtr& specialized_graph,
}
}
void FuncGraph::GenerateKwParams(const FuncGraphPtr& specialized_graph,
std::vector<AnfNodePtr>* specialized_parameter_list,
const std::vector<abstract::AbstractKeywordArgPtr>& kwarg_list,
std::unordered_map<AnfNodePtr, AnfNodePtr>* repl_nodes) {
void FuncGraph::GenerateKwParams(const FuncGraphPtr &specialized_graph,
std::vector<AnfNodePtr> *specialized_parameter_list,
const std::vector<abstract::AbstractKeywordArgPtr> &kwarg_list,
std::unordered_map<AnfNodePtr, AnfNodePtr> *repl_nodes) {
std::vector<AnfNodePtr> kwarg_keys_tuple_nodes = {NewValueNode(prim::kPrimMakeTuple)};
std::vector<AnfNodePtr> kwarg_values_tuple_nodes = {NewValueNode(prim::kPrimMakeTuple)};
for (const auto& kwarg : kwarg_list) {
for (const auto &kwarg : kwarg_list) {
MS_EXCEPTION_IF_NULL(kwarg);
std::string kw_param_name = kwarg->get_key();
MS_EXCEPTION_IF_NULL(specialized_graph);
@ -493,7 +493,7 @@ void FuncGraph::GenerateKwParams(const FuncGraphPtr& specialized_graph,
std::string param_name = specialized_graph->GetVariableKwargName() + "[" + kw_param_name + "]";
MS_EXCEPTION_IF_NULL(specialized_parameter_list);
auto find_kw_arg_in_list = std::any_of(specialized_parameter_list->begin(), specialized_parameter_list->end(),
[param_name](const AnfNodePtr& node) {
[param_name](const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
auto param = node->cast<ParameterPtr>();
return param != nullptr && param->name() == param_name;
@ -526,10 +526,10 @@ void FuncGraph::GenerateKwParams(const FuncGraphPtr& specialized_graph,
GenerateKwargReplNode(specialized_graph, repl_nodes, kwarg_keys_tuple_nodes, kwarg_values_tuple_nodes);
}
void FuncGraph::GenerateKwargReplNode(const FuncGraphPtr& specialized_graph,
std::unordered_map<AnfNodePtr, AnfNodePtr>* repl_nodes,
const std::vector<AnfNodePtr>& kwarg_keys_tuple_nodes,
const std::vector<AnfNodePtr>& kwarg_values_tuple_nodes) {
void FuncGraph::GenerateKwargReplNode(const FuncGraphPtr &specialized_graph,
std::unordered_map<AnfNodePtr, AnfNodePtr> *repl_nodes,
const std::vector<AnfNodePtr> &kwarg_keys_tuple_nodes,
const std::vector<AnfNodePtr> &kwarg_values_tuple_nodes) {
if (has_kwarg()) {
MS_EXCEPTION_IF_NULL(specialized_graph);
TraceManager::DebugTrace(
@ -544,7 +544,7 @@ void FuncGraph::GenerateKwargReplNode(const FuncGraphPtr& specialized_graph,
}
}
bool FuncGraph::NeedGenerate(const std::vector<abstract::AbstractKeywordArgPtr>& kwarg_list) {
bool FuncGraph::NeedGenerate(const std::vector<abstract::AbstractKeywordArgPtr> &kwarg_list) {
// if the function does not have any vararg/kwarg/kwonly/default value/kw args input
// return the original graph
if (!has_vararg() && kwonlyargs_count() == 0 && !has_kwarg() && GetDefaultValueCount() == 0 && kwarg_list.empty()) {
@ -558,9 +558,9 @@ bool FuncGraph::NeedGenerate(const std::vector<abstract::AbstractKeywordArgPtr>&
return true;
}
void FuncGraph::GenerateDefaultValue(const FuncGraphPtr& specialized_graph,
const std::vector<AnfNodePtr>& specialized_parameter_list,
std::unordered_map<AnfNodePtr, AnfNodePtr>* repl_nodes) {
void FuncGraph::GenerateDefaultValue(const FuncGraphPtr &specialized_graph,
const std::vector<AnfNodePtr> &specialized_parameter_list,
std::unordered_map<AnfNodePtr, AnfNodePtr> *repl_nodes) {
MS_EXCEPTION_IF_NULL(specialized_graph);
for (size_t i = 0; i < specialized_graph->parameters().size() - hyper_param_count(); ++i) {
auto param_node = specialized_graph->parameters()[i];
@ -583,10 +583,10 @@ void FuncGraph::GenerateDefaultValue(const FuncGraphPtr& specialized_graph,
}
}
FuncGraphPtr FuncGraph::GenerateGraph(const AbstractBasePtrList& args_spec_list) {
FuncGraphPtr FuncGraph::GenerateGraph(const AbstractBasePtrList &args_spec_list) {
std::vector<abstract::AbstractKeywordArgPtr> kwarg_list;
size_t arguments_count = args_spec_list.size();
for (const auto& arg : args_spec_list) {
for (const auto &arg : args_spec_list) {
// if it is a keyword argument
MS_EXCEPTION_IF_NULL(arg);
if (arg->isa<abstract::AbstractKeywordArg>()) {
@ -619,11 +619,11 @@ FuncGraphPtr FuncGraph::GenerateGraph(const AbstractBasePtrList& args_spec_list)
MS_EXCEPTION_IF_NULL(specialized_graph);
auto params = specialized_graph->parameters();
(void)std::transform(params.end() - SizeToInt(hyper_param_count()), params.end(),
std::back_inserter(specialized_parameter_list), [](const AnfNodePtr& node) { return node; });
std::back_inserter(specialized_parameter_list), [](const AnfNodePtr &node) { return node; });
std::shared_ptr<mindspore::FuncGraphManager> manager = mindspore::Manage(specialized_graph, false);
auto tr = manager->Transact();
for (auto& node_pair : repl_nodes) {
for (auto &node_pair : repl_nodes) {
MS_LOG(DEBUG) << "GenerateGraph replace:" << node_pair.first->DebugString() << "-"
<< node_pair.second->DebugString();
(void)tr.Replace(node_pair.first, node_pair.second);
@ -638,7 +638,7 @@ FuncGraphPtr FuncGraph::GenerateGraph(const AbstractBasePtrList& args_spec_list)
return specialized_graph;
}
void FuncGraph::add_parameter_obj_node(const AnfNodePtr& p) { paramter_obj_nodes_.push_back(p); }
void FuncGraph::add_parameter_obj_node(const AnfNodePtr &p) { paramter_obj_nodes_.push_back(p); }
std::list<CNodePtr> FuncGraph::GetOrderedCnodes() {
if (has_flag(GRAPH_FLAG_HAS_EFFECT)) {
@ -651,7 +651,7 @@ std::list<CNodePtr> FuncGraph::GetOrderedCnodes() {
std::list<CNodePtr> cnodes;
auto nodes = TopoSort(get_return(), SuccDepends, BelongSameGraph);
for (const auto& node : nodes) {
for (const auto &node : nodes) {
auto cnode = dyn_cast<CNode>(node);
if (cnode) {
cnodes.push_back(cnode);
@ -679,7 +679,7 @@ void FuncGraph::EraseUnusedNodeInOrder() {
}
}
void FuncGraph::EraseUnusedNodeInOrder(const AnfNodePtr& n) {
void FuncGraph::EraseUnusedNodeInOrder(const AnfNodePtr &n) {
if (has_flag(GRAPH_FLAG_HAS_EFFECT) && n && n->isa<CNode>()) {
order_.remove(n->cast<CNodePtr>());
MS_LOG(DEBUG) << "Remove the node" << n->DebugString() << " from order list.";
@ -690,7 +690,7 @@ void FuncGraph::CheckOrder() {
if (has_flag(GRAPH_FLAG_HAS_EFFECT)) {
MS_LOG(DEBUG) << "Check graph " << ToString();
for (auto it = order_.begin(); it != order_.end(); (void)it++) {
for (const auto& input_node : (*it)->inputs()) {
for (const auto &input_node : (*it)->inputs()) {
if (input_node && input_node->isa<CNode>() && input_node->func_graph() == shared_from_base<FuncGraph>()) {
// Need to reorder the wrong order node.
auto found = std::find(order_.begin(), it, input_node);
@ -705,7 +705,7 @@ void FuncGraph::CheckOrder() {
}
auto mng = manager_.lock();
if (mng != nullptr) {
const auto& nodes = mng->nodes()[shared_from_base<FuncGraph>()];
const auto &nodes = mng->nodes()[shared_from_base<FuncGraph>()];
if (nodes.size() != (order_.size() + parameters_.size())) {
DumpCNodeList();
MS_LOG(EXCEPTION) << "CNode order size " << order_.size() << " is not equal to managed node size "
@ -718,7 +718,7 @@ void FuncGraph::CheckOrder() {
const char kPrimHasEffect[] = "_side_effect_flag";
bool FuncGraph::HasEffect(const CNodePtr& cnode) {
bool FuncGraph::HasEffect(const CNodePtr &cnode) {
auto prim = GetCNodePrimitive(cnode);
if (prim != nullptr && prim->isa<prim::DoSignaturePrimitive>()) {
auto do_sig = prim->cast<prim::DoSignaturePrimitivePtr>();
@ -739,9 +739,9 @@ bool FuncGraph::HasEffect(const CNodePtr& cnode) {
return false;
}
std::shared_ptr<OrderedSet<CNodePtr>> FindRoots(const std::vector<CNodePtr>& segment) {
std::shared_ptr<OrderedSet<CNodePtr>> FindRoots(const std::vector<CNodePtr> &segment) {
std::shared_ptr<OrderedSet<CNodePtr>> roots = std::make_shared<OrderedSet<CNodePtr>>(segment);
for (const auto& node : segment) {
for (const auto &node : segment) {
if (roots->size() == 1) {
return roots;
}
@ -757,9 +757,9 @@ std::shared_ptr<OrderedSet<CNodePtr>> FindRoots(const std::vector<CNodePtr>& seg
return roots;
}
std::shared_ptr<OrderedSet<CNodePtr>> FindLeaves(const std::vector<CNodePtr>& segment) {
std::shared_ptr<OrderedSet<CNodePtr>> FindLeaves(const std::vector<CNodePtr> &segment) {
std::shared_ptr<OrderedSet<CNodePtr>> nodes = std::make_shared<OrderedSet<CNodePtr>>(segment);
for (const auto& node : segment) {
for (const auto &node : segment) {
if (nodes->size() == 1) {
return nodes;
}
@ -790,7 +790,7 @@ void FuncGraph::ReleaseFullOrderToEffectOrder() {
if (has_flag(GRAPH_FLAG_HAS_EFFECT)) {
std::list<AnfNodePtr> depends_order;
std::vector<CNodePtr> segment;
for (const auto& cnode : order_) {
for (const auto &cnode : order_) {
if (IsPrimitiveCNode(cnode, prim::kPrimReturn)) {
continue;
}
@ -830,7 +830,7 @@ void FuncGraph::ReleaseFullOrderToEffectOrder() {
}
}
void FuncGraph::SetEffectDepends(const std::vector<AnfNodePtr>& depend_inputs) {
void FuncGraph::SetEffectDepends(const std::vector<AnfNodePtr> &depend_inputs) {
auto old_ret = output();
std::vector<AnfNodePtr> inputs{NewValueNode(prim::kPrimDepend), old_ret};
(void)inputs.insert(inputs.end(), depend_inputs.begin(), depend_inputs.end());

View File

@ -26,29 +26,29 @@
// namespace to support intermediate representation definition
namespace mindspore {
Cloner::Cloner(const FuncGraphPtrList& func_graphs, bool clone_all_valuenodes, bool clone_all_child_graphs,
bool clone_all_used_graphs, const TraceInfoPtr& relation, const TraceInfoPtr& target_relation)
Cloner::Cloner(const FuncGraphPtrList &func_graphs, bool clone_all_valuenodes, bool clone_all_child_graphs,
bool clone_all_used_graphs, const TraceInfoPtr &relation, const TraceInfoPtr &target_relation)
: clone_all_valuenodes_(clone_all_valuenodes),
clone_all_child_graphs_(clone_all_child_graphs),
clone_all_used_graphs_(clone_all_used_graphs),
relation_(relation),
target_relation_(target_relation == nullptr ? relation : target_relation) {
for (auto& func_graph : func_graphs) {
for (auto &func_graph : func_graphs) {
AddClone(func_graph);
}
scope_ = kDefaultScope;
type_ = kBasic;
}
void Cloner::AddClone(const FuncGraphPtr& func_graph, const FuncGraphPtr& target_func_graph,
const AnfNodePtrList& params, CloneType type) {
void Cloner::AddClone(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph,
const AnfNodePtrList &params, CloneType type) {
if (func_graph != nullptr) {
todo_.push_back({.origin = func_graph, .target = target_func_graph, .params = params});
type_ = type;
}
}
void Cloner::CloneNode(const AnfNodePtr& node, const FuncGraphPtr& target) {
void Cloner::CloneNode(const AnfNodePtr &node, const FuncGraphPtr &target) {
MS_EXCEPTION_IF_NULL(node);
if (repl_node_.find(node) != repl_node_.end() || node->isa<ValueNode>()) {
return;
@ -60,7 +60,7 @@ void Cloner::CloneNode(const AnfNodePtr& node, const FuncGraphPtr& target) {
}
}
void Cloner::CloneParameter(const AnfNodePtr& node, const FuncGraphPtr& target, bool is_add) {
void Cloner::CloneParameter(const AnfNodePtr &node, const FuncGraphPtr &target, bool is_add) {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(target);
TraceManager::DebugTrace(node->debug_info(), relation_);
@ -77,7 +77,7 @@ void Cloner::CloneParameter(const AnfNodePtr& node, const FuncGraphPtr& target,
TraceManager::EndTrace();
}
void Cloner::CloneCNode(const AnfNodePtr& node, const FuncGraphPtr& target) {
void Cloner::CloneCNode(const AnfNodePtr &node, const FuncGraphPtr &target) {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(target);
TraceManager::DebugTrace(node->debug_info(), relation_);
@ -91,7 +91,7 @@ void Cloner::CloneCNode(const AnfNodePtr& node, const FuncGraphPtr& target) {
TraceManager::EndTrace();
}
void Cloner::CloneValueNode(const AnfNodePtr& node) {
void Cloner::CloneValueNode(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
TraceManager::DebugTrace(node->debug_info(), relation_);
ValueNodePtr new_const = NewValueNode(GetValueNode(node));
@ -102,7 +102,7 @@ void Cloner::CloneValueNode(const AnfNodePtr& node) {
TraceManager::EndTrace();
}
void Cloner::CloneValueNode(const AnfNodePtr& node, const FuncGraphPtr& target) {
void Cloner::CloneValueNode(const AnfNodePtr &node, const FuncGraphPtr &target) {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(target);
TraceManager::DebugTrace(node->debug_info(), relation_);
@ -114,14 +114,14 @@ void Cloner::CloneValueNode(const AnfNodePtr& node, const FuncGraphPtr& target)
TraceManager::EndTrace();
}
void Cloner::CloneValueNodes(const FuncGraphPtr& func_graph) {
void Cloner::CloneValueNodes(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(manager_);
if (!clone_all_valuenodes_) {
return;
}
auto& value_nodes = manager_->valuenodes()[func_graph];
for (auto& value_node : value_nodes) {
auto &value_nodes = manager_->valuenodes()[func_graph];
for (auto &value_node : value_nodes) {
auto old_node = value_node.first;
MS_EXCEPTION_IF_NULL(old_node);
if (repl_node_.count(old_node) == 0) {
@ -130,38 +130,38 @@ void Cloner::CloneValueNodes(const FuncGraphPtr& func_graph) {
}
}
void Cloner::AddChildGraphs(const FuncGraphPtr& func_graph) {
void Cloner::AddChildGraphs(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(manager_);
if (!clone_all_child_graphs_) {
return;
}
auto& scopes = manager_->scopes(func_graph);
for (auto& graph : scopes) {
auto &scopes = manager_->scopes(func_graph);
for (auto &graph : scopes) {
if (graph != func_graph) {
todo_.push_back({graph, nullptr, {}});
}
}
}
void Cloner::AddTotalGraphs(const FuncGraphPtr& func_graph) {
void Cloner::AddTotalGraphs(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(manager_);
if (!clone_all_used_graphs_) {
return;
}
auto& used_graphs = manager_->func_graphs_used()[func_graph];
for (auto& used_graph : used_graphs) {
auto &used_graphs = manager_->func_graphs_used()[func_graph];
for (auto &used_graph : used_graphs) {
todo_.push_back({used_graph.first, nullptr, {}});
}
}
void Cloner::CloneFuncGraphDefaultValues(const FuncGraphPtr& func_graph, const FuncGraphPtr& target_func_graph) {
void Cloner::CloneFuncGraphDefaultValues(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph) {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(target_func_graph);
for (auto& item : func_graph->parameter_default_value()) {
for (auto &item : func_graph->parameter_default_value()) {
auto nodes = DeepLinkedGraphSearch(item.second);
for (auto& node : nodes) {
for (auto &node : nodes) {
MS_EXCEPTION_IF_NULL(node);
if (node->isa<CNode>()) {
CloneNode(node, target_func_graph);
@ -172,7 +172,7 @@ void Cloner::CloneFuncGraphDefaultValues(const FuncGraphPtr& func_graph, const F
}
}
void Cloner::CloneFuncGraphValueNodes(const FuncGraphPtr& func_graph, const FuncGraphPtr& target_func_graph) {
void Cloner::CloneFuncGraphValueNodes(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph) {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(target_func_graph);
MS_EXCEPTION_IF_NULL(manager_);
@ -182,15 +182,15 @@ void Cloner::CloneFuncGraphValueNodes(const FuncGraphPtr& func_graph, const Func
}
target_func_graph->set_return(return_node);
auto& value_nodes = manager_->func_graph_valuenodes()[func_graph];
for (auto& value_node : value_nodes) {
auto &value_nodes = manager_->func_graph_valuenodes()[func_graph];
for (auto &value_node : value_nodes) {
CloneValueNode(value_node.first, target_func_graph);
}
}
void Cloner::InlineCloneParameters(const FuncGraphPtr& func_graph, const AnfNodePtrList& params) {
void Cloner::InlineCloneParameters(const FuncGraphPtr &func_graph, const AnfNodePtrList &params) {
MS_EXCEPTION_IF_NULL(func_graph);
auto& old_params = func_graph->parameters();
auto &old_params = func_graph->parameters();
if (old_params.size() != params.size()) {
MS_LOG(EXCEPTION) << "Origin params size[" << old_params.size() << "], inline params size[" << params.size() << "]";
return;
@ -200,7 +200,7 @@ void Cloner::InlineCloneParameters(const FuncGraphPtr& func_graph, const AnfNode
}
}
void Cloner::SetFuncGraphInfo(const FuncGraphPtr& func_graph, FuncGraphPtr* const target_func_graph) {
void Cloner::SetFuncGraphInfo(const FuncGraphPtr &func_graph, FuncGraphPtr *const target_func_graph) {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(target_func_graph);
TraceManager::DebugTrace(func_graph->debug_info(), target_relation_);
@ -215,33 +215,33 @@ void Cloner::SetFuncGraphInfo(const FuncGraphPtr& func_graph, FuncGraphPtr* cons
TraceManager::EndTrace();
}
void Cloner::CloneParameters(const FuncGraphPtr& func_graph, const FuncGraphPtr& target_func_graph) {
void Cloner::CloneParameters(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph) {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(target_func_graph);
auto& params = func_graph->parameters();
for (auto& param : params) {
auto &params = func_graph->parameters();
for (auto &param : params) {
CloneParameter(param, target_func_graph, true);
}
repl_func_graph_[func_graph] = target_func_graph;
}
void Cloner::GenParameters(const FuncGraphPtr& func_graph) {
void Cloner::GenParameters(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(func_graph);
auto& free_vars = manager_->free_variables_total();
auto &free_vars = manager_->free_variables_total();
auto iter = free_vars.find(func_graph);
if (iter == free_vars.end()) {
return;
}
for (auto& fv_map : iter->second) {
auto& free_var = fv_map.first;
for (auto &fv_map : iter->second) {
auto &free_var = fv_map.first;
if (utils::isa<AnfNodePtr>(free_var)) {
repl_func_graph_params_[func_graph].push_back(AddParameter(func_graph, utils::cast<AnfNodePtr>(free_var)));
}
}
}
void Cloner::CloneParameter(const ParameterPtr& param, const AnfNodePtr& node) {
void Cloner::CloneParameter(const ParameterPtr &param, const AnfNodePtr &node) {
param->set_abstract(node->abstract());
if (node->isa<Parameter>()) {
ParameterPtr old_param = dyn_cast<Parameter>(node);
@ -252,7 +252,7 @@ void Cloner::CloneParameter(const ParameterPtr& param, const AnfNodePtr& node) {
}
}
ParameterPtr Cloner::AddParameter(const FuncGraphPtr& func_graph, const AnfNodePtr& node, bool is_add) {
ParameterPtr Cloner::AddParameter(const FuncGraphPtr &func_graph, const AnfNodePtr &node, bool is_add) {
TraceManager::DebugTrace(std::make_shared<TraceCopy>(node->debug_info()));
ParameterPtr param = std::make_shared<Parameter>(func_graph);
TraceManager::EndTrace();
@ -265,11 +265,11 @@ ParameterPtr Cloner::AddParameter(const FuncGraphPtr& func_graph, const AnfNodeP
return param;
}
void Cloner::AddParameters(const FuncGraphPtr& func_graph, const AnfNodePtrList& params,
AnfNodePtrList* const lift_params, AnfNodePtrList* const input_params) {
void Cloner::AddParameters(const FuncGraphPtr &func_graph, const AnfNodePtrList &params,
AnfNodePtrList *const lift_params, AnfNodePtrList *const input_params) {
AnfNodePtrList parameters;
std::unordered_set<AnfNodePtr> old_params;
for (auto& param : func_graph->parameters()) {
for (auto &param : func_graph->parameters()) {
auto iter = repl_node_.find(param);
if (iter != repl_node_.end()) {
(void)old_params.insert(iter->second);
@ -280,7 +280,7 @@ void Cloner::AddParameters(const FuncGraphPtr& func_graph, const AnfNodePtrList&
}
}
AnfNodePtr new_param = nullptr;
for (auto& param : params) {
for (auto &param : params) {
auto old_param = repl_node_[param];
if (old_param->isa<CNode>() && old_param->func_graph() == func_graph) {
repl_node_[old_param] = old_param;
@ -301,10 +301,10 @@ void Cloner::AddParameters(const FuncGraphPtr& func_graph, const AnfNodePtrList&
func_graph->set_parameters(parameters);
}
void Cloner::AddInputs(const FuncGraphPtr& func_graph_user, const FuncGraphPtr& func_graph,
const AnfNodePtrList& params) {
void Cloner::AddInputs(const FuncGraphPtr &func_graph_user, const FuncGraphPtr &func_graph,
const AnfNodePtrList &params) {
AnfNodePtr node = nullptr;
auto& repl_func_graph = repl_map_func_graph_[func_graph_user];
auto &repl_func_graph = repl_map_func_graph_[func_graph_user];
auto iter = repl_func_graph.find(func_graph);
if (iter == repl_func_graph.end()) {
node = func_graph_user->NewCNode({NewValueNode(prim::kPrimPartial), NewValueNode(func_graph)});
@ -322,9 +322,9 @@ void Cloner::AddInputs(const FuncGraphPtr& func_graph_user, const FuncGraphPtr&
OrderParameters(func_graph, inputs);
}
void Cloner::OrderParameters(const FuncGraphPtr& func_graph, const AnfNodePtrList& inputs) {
void Cloner::OrderParameters(const FuncGraphPtr &func_graph, const AnfNodePtrList &inputs) {
std::unordered_set<AnfNodePtr> old_params;
for (auto& param : func_graph->parameters()) {
for (auto &param : func_graph->parameters()) {
(void)old_params.insert(repl_node_[param]);
}
std::unordered_set<AnfNodePtr> new_params;
@ -339,7 +339,7 @@ void Cloner::OrderParameters(const FuncGraphPtr& func_graph, const AnfNodePtrLis
(void)new_params.insert(new_param);
}
}
for (auto& param : func_graph->parameters()) {
for (auto &param : func_graph->parameters()) {
if (new_params.find(param) == new_params.end()) {
parameters.push_back(param);
}
@ -347,9 +347,9 @@ void Cloner::OrderParameters(const FuncGraphPtr& func_graph, const AnfNodePtrLis
func_graph->set_parameters(parameters);
}
void Cloner::SetEdges(const FuncGraphPtr& func_graph) {
void Cloner::SetEdges(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(func_graph);
for (auto& node : func_graph->nodes()) {
for (auto &node : func_graph->nodes()) {
if (node == nullptr) {
continue;
}
@ -358,17 +358,17 @@ void Cloner::SetEdges(const FuncGraphPtr& func_graph) {
continue;
}
auto cnode = node->cast<CNodePtr>();
auto& inputs = cnode->inputs();
auto &inputs = cnode->inputs();
for (size_t i = 0; i < inputs.size(); i++) {
auto& input = inputs[i];
auto &input = inputs[i];
if (IsValueNode<FuncGraph>(input)) {
auto graph = GetValueNode<FuncGraphPtr>(input);
auto& repl_func_graph = repl_map_func_graph_[func_graph];
auto &repl_func_graph = repl_map_func_graph_[func_graph];
if (repl_func_graph.find(graph) != repl_func_graph.end()) {
transaction_.SetEdge(cnode, SizeToInt(i), repl_func_graph[graph]);
}
} else {
auto& repl_node = repl_map_node_[func_graph];
auto &repl_node = repl_map_node_[func_graph];
if (repl_node.find(input) != repl_node.end()) {
transaction_.SetEdge(cnode, SizeToInt(i), repl_node[input]);
}
@ -377,8 +377,8 @@ void Cloner::SetEdges(const FuncGraphPtr& func_graph) {
}
}
void Cloner::LiftParameters(const FuncGraphPtr& func_graph_user, const FuncGraphPtr& func_graph,
const AnfNodePtrList& params) {
void Cloner::LiftParameters(const FuncGraphPtr &func_graph_user, const FuncGraphPtr &func_graph,
const AnfNodePtrList &params) {
AnfNodePtrList lift_params;
AnfNodePtrList input_params;
AddParameters(func_graph_user, params, &lift_params, &input_params);
@ -386,16 +386,16 @@ void Cloner::LiftParameters(const FuncGraphPtr& func_graph_user, const FuncGraph
if (lift_params.empty()) {
return;
}
for (auto& user : func_graph_user->func_graph_users()) {
for (auto &user : func_graph_user->func_graph_users()) {
LiftParameters(user.first, func_graph_user, lift_params);
}
}
void Cloner::Lift() {
for (auto& func_graph_params : repl_func_graph_params_) {
auto& func_graph = func_graph_params.first;
auto& params = func_graph_params.second;
for (auto& user : func_graph->func_graph_users()) {
for (auto &func_graph_params : repl_func_graph_params_) {
auto &func_graph = func_graph_params.first;
auto &params = func_graph_params.second;
for (auto &user : func_graph->func_graph_users()) {
LiftParameters(user.first, func_graph, params);
}
}
@ -404,18 +404,18 @@ void Cloner::Lift() {
void Cloner::LiftParameters() {
MS_EXCEPTION_IF_NULL(manager_);
transaction_ = manager_->Transact();
const FuncGraphSet& func_graphs = manager_->func_graphs();
for (auto& func_graph : func_graphs) {
const FuncGraphSet &func_graphs = manager_->func_graphs();
for (auto &func_graph : func_graphs) {
GenParameters(func_graph);
}
Lift();
for (auto& func_graph : func_graphs) {
for (auto &func_graph : func_graphs) {
SetEdges(func_graph);
}
transaction_.Commit();
}
bool Cloner::CheckStatus(const FuncGraphPtr& func_graph, bool is_inline) {
bool Cloner::CheckStatus(const FuncGraphPtr &func_graph, bool is_inline) {
MS_EXCEPTION_IF_NULL(func_graph);
// Make sure only inline once
if (status_.count(func_graph) != 0) {
@ -430,12 +430,12 @@ bool Cloner::CheckStatus(const FuncGraphPtr& func_graph, bool is_inline) {
return true;
}
void Cloner::CloneAllNodes(const FuncGraphPtr& func_graph, const FuncGraphPtr& target_func_graph) {
void Cloner::CloneAllNodes(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph) {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(target_func_graph);
MS_EXCEPTION_IF_NULL(manager_);
const AnfNodeSet& nodes = manager_->nodes()[func_graph];
for (auto& node : nodes) {
const AnfNodeSet &nodes = manager_->nodes()[func_graph];
for (auto &node : nodes) {
CloneNode(node, target_func_graph);
}
}
@ -449,7 +449,7 @@ void Cloner::Run() {
// Basic and Inline Clone
FuncGraphPtrList func_graphs;
(void)std::transform(todo_.begin(), todo_.end(), std::back_inserter(func_graphs),
[](const CloneInfo& item) -> FuncGraphPtr { return item.origin; });
[](const CloneInfo &item) -> FuncGraphPtr { return item.origin; });
manager_ = Manage(func_graphs, false);
CloneNodes();
LinkEdges();
@ -495,13 +495,13 @@ void Cloner::CloneNodes() {
}
void Cloner::LinkEdges() {
for (auto& node_pair : nodes_) {
for (auto &node_pair : nodes_) {
CNodePtr old_node = node_pair.first;
CNodePtr new_node = node_pair.second;
MS_EXCEPTION_IF_NULL(old_node);
MS_EXCEPTION_IF_NULL(new_node);
for (auto& input : old_node->inputs()) {
auto& new_input = (repl_node_.count(input) == 0) ? input : repl_node_[input];
for (auto &input : old_node->inputs()) {
auto &new_input = (repl_node_.count(input) == 0) ? input : repl_node_[input];
new_node->add_input(new_input);
}
}
@ -509,10 +509,10 @@ void Cloner::LinkEdges() {
// For the graphs cloned, update its default value map to the cloned nodes
void Cloner::SetDefaults() {
for (auto& item : graph_set_) {
for (auto &item : graph_set_) {
MS_EXCEPTION_IF_NULL(item);
if (repl_func_graph_.count(item) != 0) {
for (auto& param_def : item->parameter_default_value()) {
for (auto &param_def : item->parameter_default_value()) {
MS_EXCEPTION_IF_NULL(repl_func_graph_[item]);
if (repl_node_.count(param_def.second) != 0) {
repl_func_graph_[item]->set_param_default_value(param_def.first, repl_node_[param_def.second]);
@ -524,7 +524,7 @@ void Cloner::SetDefaults() {
}
}
AnfNodePtr Cloner::CloneDisconnected(const AnfNodePtr& root) {
AnfNodePtr Cloner::CloneDisconnected(const AnfNodePtr &root) {
MS_EXCEPTION_IF_NULL(root);
if (repl_func_graph_.find(root->func_graph()) == repl_func_graph_.end()) {
MS_LOG(EXCEPTION) << "Cannot find func graph " << root->func_graph()->ToString() << " in cloner.";
@ -537,7 +537,7 @@ AnfNodePtr Cloner::CloneDisconnected(const AnfNodePtr& root) {
MS_LOG(EXCEPTION) << "Failed in clone for node " << root->DebugString() << ".";
}
AnfNodePtr Cloner::operator[](const AnfNodePtr& node) {
AnfNodePtr Cloner::operator[](const AnfNodePtr &node) {
#ifdef ENABLE_PROFILE
double time = GetTime();
#endif
@ -548,7 +548,7 @@ AnfNodePtr Cloner::operator[](const AnfNodePtr& node) {
return ((repl_node_.count(node) == 0) ? node : repl_node_[node]);
}
FuncGraphPtr Cloner::operator[](const FuncGraphPtr& func_graph) {
FuncGraphPtr Cloner::operator[](const FuncGraphPtr &func_graph) {
#ifdef ENABLE_PROFILE
double time = GetTime();
#endif
@ -559,14 +559,14 @@ FuncGraphPtr Cloner::operator[](const FuncGraphPtr& func_graph) {
return ((repl_func_graph_.count(func_graph) == 0) ? func_graph : repl_func_graph_[func_graph]);
}
FuncGraphPtr BasicClone(const FuncGraphPtr& func_graph) {
FuncGraphPtr BasicClone(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(func_graph);
Cloner cloner({func_graph}, false, true, true, std::make_shared<TraceCopy>(), nullptr);
return cloner[func_graph];
}
AnfNodePtr InlineClone(const FuncGraphPtr& func_graph, const FuncGraphPtr& target_func_graph,
const AnfNodePtrList& func_graph_args, const ScopePtr& scope) {
AnfNodePtr InlineClone(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph,
const AnfNodePtrList &func_graph_args, const ScopePtr &scope) {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(target_func_graph);
Cloner cloner({}, false);
@ -577,14 +577,14 @@ AnfNodePtr InlineClone(const FuncGraphPtr& func_graph, const FuncGraphPtr& targe
return cloner[func_graph->output()];
}
FuncGraphPtr LiftingClone(const FuncGraphPtr& func_graph) {
FuncGraphPtr LiftingClone(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(func_graph);
Cloner cloner({}, false);
cloner.AddClone(func_graph, nullptr, {}, kLifting);
return cloner[func_graph];
}
ClonerPtr SpecializerClone(const FuncGraphPtr& func_graph, const TraceInfoPtr& relation) {
ClonerPtr SpecializerClone(const FuncGraphPtr &func_graph, const TraceInfoPtr &relation) {
MS_EXCEPTION_IF_NULL(func_graph);
FuncGraphPtrList func_graphs = {func_graph};
ClonerPtr cloner =
@ -599,14 +599,14 @@ ClonerPtr SpecializerClone(const FuncGraphPtr& func_graph, const TraceInfoPtr& r
return cloner;
}
FuncGraphPtr TransformableClone(const FuncGraphPtr& func_graph, const TraceInfoPtr& relation) {
FuncGraphPtr TransformableClone(const FuncGraphPtr &func_graph, const TraceInfoPtr &relation) {
MS_EXCEPTION_IF_NULL(func_graph);
TraceManager::DebugTrace(func_graph->debug_info(), relation);
auto new_func_graph = std::make_shared<FuncGraph>();
TraceManager::EndTrace();
auto& parameters = func_graph->parameters();
(void)std::for_each(parameters.begin(), parameters.end(), [&new_func_graph](const AnfNodePtr& param) -> void {
auto &parameters = func_graph->parameters();
(void)std::for_each(parameters.begin(), parameters.end(), [&new_func_graph](const AnfNodePtr &param) -> void {
MS_EXCEPTION_IF_NULL(param);
TraceManager::DebugTrace(std::make_shared<TraceCopy>(param->debug_info()));
(void)new_func_graph->add_parameter();
@ -622,7 +622,7 @@ FuncGraphPtr TransformableClone(const FuncGraphPtr& func_graph, const TraceInfoP
new_func_graph->set_kwonlyargs_count(func_graph->kwonlyargs_count());
new_func_graph->set_hyper_param_count(func_graph->hyper_param_count());
new_func_graph->set_is_generate(func_graph->is_generated());
for (auto& item : func_graph->parameter_default_value()) {
for (auto &item : func_graph->parameter_default_value()) {
new_func_graph->set_param_default_value(item.first, cloner[item.second]);
}

View File

@ -43,26 +43,26 @@ struct CloneInfo {
class Cloner {
public:
explicit Cloner(const FuncGraphPtrList& func_graphs = {}, bool clone_all_valuenodes = false,
explicit Cloner(const FuncGraphPtrList &func_graphs = {}, bool clone_all_valuenodes = false,
bool clone_all_child_graphs = true, bool clone_all_used_graphs = false,
const TraceInfoPtr& relation = std::make_shared<TraceCopy>(),
const TraceInfoPtr& target_relation = nullptr);
const TraceInfoPtr &relation = std::make_shared<TraceCopy>(),
const TraceInfoPtr &target_relation = nullptr);
~Cloner() = default;
void AddClone(const FuncGraphPtr& func_graph, const FuncGraphPtr& target_func_graph = nullptr,
const AnfNodePtrList& params = {}, CloneType type = kBasic);
void AddClone(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph = nullptr,
const AnfNodePtrList &params = {}, CloneType type = kBasic);
void Run();
// Interfaces for specializer
AnfNodePtr CloneDisconnected(const AnfNodePtr& root);
AnfNodePtr operator[](const AnfNodePtr& node);
FuncGraphPtr operator[](const FuncGraphPtr& func_graph);
AnfNodePtr CloneDisconnected(const AnfNodePtr &root);
AnfNodePtr operator[](const AnfNodePtr &node);
FuncGraphPtr operator[](const FuncGraphPtr &func_graph);
// Map of replicate nodes and graphs
std::unordered_map<AnfNodePtr, AnfNodePtr>* cloned_node() { return &repl_node_; }
std::unordered_map<AnfNodePtr, AnfNodePtr> *cloned_node() { return &repl_node_; }
std::unordered_map<FuncGraphPtr, FuncGraphPtr> cloned_func_graph() { return repl_func_graph_; }
// Scope of cloned graphs
void set_scope(const ScopePtr& scope) { scope_ = scope; }
void set_scope(const ScopePtr &scope) { scope_ = scope; }
const ScopePtr scope() const { return scope_; }
std::unordered_map<AnfNodePtr, AnfNodePtr> repl_node_;
@ -71,31 +71,31 @@ class Cloner {
void CloneNodes();
void LinkEdges();
void SetDefaults();
void CloneNode(const AnfNodePtr& node, const FuncGraphPtr& target);
void CloneValueNode(const AnfNodePtr& node);
void CloneValueNode(const AnfNodePtr& node, const FuncGraphPtr& target);
void CloneCNode(const AnfNodePtr& node, const FuncGraphPtr& target);
void CloneParameter(const AnfNodePtr& node, const FuncGraphPtr& target, bool is_add = false);
void CloneValueNodes(const FuncGraphPtr& func_graph);
void AddChildGraphs(const FuncGraphPtr& func_graph);
void AddTotalGraphs(const FuncGraphPtr& func_graph);
bool CheckStatus(const FuncGraphPtr& func_graph, bool is_inline);
void CloneAllNodes(const FuncGraphPtr& func_graph, const FuncGraphPtr& target_func_graph);
void CloneFuncGraphValueNodes(const FuncGraphPtr& func_graph, const FuncGraphPtr& target_func_graph);
void CloneFuncGraphDefaultValues(const FuncGraphPtr& func_graph, const FuncGraphPtr& target_func_graph);
void InlineCloneParameters(const FuncGraphPtr& func_graph, const AnfNodePtrList& params);
void SetFuncGraphInfo(const FuncGraphPtr& func_graph, FuncGraphPtr* const target_func_graph);
void CloneParameters(const FuncGraphPtr& func_graph, const FuncGraphPtr& target_func_graph);
void GenParameters(const FuncGraphPtr& func_graph);
void CloneParameter(const ParameterPtr& param, const AnfNodePtr& node);
ParameterPtr AddParameter(const FuncGraphPtr& func_graph, const AnfNodePtr& node, bool is_add = true);
void AddParameters(const FuncGraphPtr& func_graph, const AnfNodePtrList& params, AnfNodePtrList* const lift_params,
AnfNodePtrList* const input_params);
void AddInputs(const FuncGraphPtr& func_graph_user, const FuncGraphPtr& func_graph, const AnfNodePtrList& params);
void OrderParameters(const FuncGraphPtr& func_graph, const AnfNodePtrList& inputs);
void SetEdges(const FuncGraphPtr& func_graph);
void LiftParameters(const FuncGraphPtr& func_graph_user, const FuncGraphPtr& func_graph,
const AnfNodePtrList& params);
void CloneNode(const AnfNodePtr &node, const FuncGraphPtr &target);
void CloneValueNode(const AnfNodePtr &node);
void CloneValueNode(const AnfNodePtr &node, const FuncGraphPtr &target);
void CloneCNode(const AnfNodePtr &node, const FuncGraphPtr &target);
void CloneParameter(const AnfNodePtr &node, const FuncGraphPtr &target, bool is_add = false);
void CloneValueNodes(const FuncGraphPtr &func_graph);
void AddChildGraphs(const FuncGraphPtr &func_graph);
void AddTotalGraphs(const FuncGraphPtr &func_graph);
bool CheckStatus(const FuncGraphPtr &func_graph, bool is_inline);
void CloneAllNodes(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph);
void CloneFuncGraphValueNodes(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph);
void CloneFuncGraphDefaultValues(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph);
void InlineCloneParameters(const FuncGraphPtr &func_graph, const AnfNodePtrList &params);
void SetFuncGraphInfo(const FuncGraphPtr &func_graph, FuncGraphPtr *const target_func_graph);
void CloneParameters(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph);
void GenParameters(const FuncGraphPtr &func_graph);
void CloneParameter(const ParameterPtr &param, const AnfNodePtr &node);
ParameterPtr AddParameter(const FuncGraphPtr &func_graph, const AnfNodePtr &node, bool is_add = true);
void AddParameters(const FuncGraphPtr &func_graph, const AnfNodePtrList &params, AnfNodePtrList *const lift_params,
AnfNodePtrList *const input_params);
void AddInputs(const FuncGraphPtr &func_graph_user, const FuncGraphPtr &func_graph, const AnfNodePtrList &params);
void OrderParameters(const FuncGraphPtr &func_graph, const AnfNodePtrList &inputs);
void SetEdges(const FuncGraphPtr &func_graph);
void LiftParameters(const FuncGraphPtr &func_graph_user, const FuncGraphPtr &func_graph,
const AnfNodePtrList &params);
void Lift();
void LiftParameters();
@ -118,17 +118,17 @@ class Cloner {
std::unordered_map<FuncGraphPtr, AnfNodePtrList> repl_func_graph_params_;
};
FuncGraphPtr BasicClone(const FuncGraphPtr& func_graph);
FuncGraphPtr BasicClone(const FuncGraphPtr &func_graph);
AnfNodePtr InlineClone(const FuncGraphPtr& func_graph, const FuncGraphPtr& target_func_graph,
const AnfNodePtrList& func_graph_args, const ScopePtr& scope = nullptr);
AnfNodePtr InlineClone(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph,
const AnfNodePtrList &func_graph_args, const ScopePtr &scope = nullptr);
FuncGraphPtr LiftingClone(const FuncGraphPtr& func_graph);
FuncGraphPtr LiftingClone(const FuncGraphPtr &func_graph);
ClonerPtr SpecializerClone(const FuncGraphPtr& func_graph, const TraceInfoPtr& relation);
ClonerPtr SpecializerClone(const FuncGraphPtr &func_graph, const TraceInfoPtr &relation);
FuncGraphPtr TransformableClone(const FuncGraphPtr& func_graph,
const TraceInfoPtr& relation = std::make_shared<TraceTransform>());
FuncGraphPtr TransformableClone(const FuncGraphPtr &func_graph,
const TraceInfoPtr &relation = std::make_shared<TraceTransform>());
} // namespace mindspore
#endif // MINDSPORE_CCSRC_IR_FUNC_GRAPH_CLONER_H_

View File

@ -27,17 +27,17 @@
namespace mindspore {
FuncGraphManagerPtr MakeManager(const std::vector<FuncGraphPtr>& func_graphs, bool manage) {
FuncGraphManagerPtr MakeManager(const std::vector<FuncGraphPtr> &func_graphs, bool manage) {
auto m = std::make_shared<FuncGraphManager>(func_graphs, manage);
m->Init();
return m;
}
FuncGraphManagerPtr Manage(const std::vector<FuncGraphPtr>& func_graphs, bool manage) {
FuncGraphManagerPtr Manage(const std::vector<FuncGraphPtr> &func_graphs, bool manage) {
FuncGraphManagerPtr m = nullptr;
bool root = false;
for (auto& fg : func_graphs) {
for (auto &fg : func_graphs) {
if (fg == nullptr) {
continue;
}
@ -53,7 +53,7 @@ FuncGraphManagerPtr Manage(const std::vector<FuncGraphPtr>& func_graphs, bool ma
root = true;
}
for (auto& fg : func_graphs) {
for (auto &fg : func_graphs) {
if (fg == nullptr) {
continue;
}
@ -67,7 +67,7 @@ FuncGraphManagerPtr Manage(FuncGraphPtr func_graph, bool manage) {
return Manage(func_graphs, manage);
}
FuncGraphManager::FuncGraphManager(const std::vector<FuncGraphPtr>& roots, bool manage)
FuncGraphManager::FuncGraphManager(const std::vector<FuncGraphPtr> &roots, bool manage)
: roots_(roots), is_manage_(manage) {
Reset();
}
@ -103,12 +103,12 @@ void FuncGraphManager::Init() {
auto roots = roots_;
roots_ = FuncGraphSet();
for (auto& fg : roots) {
for (auto &fg : roots) {
AddFuncGraph(fg, true);
}
}
FuncGraphSet& FuncGraphManager::func_graph_parents_total(const FuncGraphPtr& fg) const {
FuncGraphSet &FuncGraphManager::func_graph_parents_total(const FuncGraphPtr &fg) const {
MS_EXCEPTION_IF_NULL(fg);
MS_LOG(DEBUG) << "Start func_graph_parents_total func graph " << fg->ToString();
func_graph_parents_total_->Recompute(fg);
@ -116,7 +116,7 @@ FuncGraphSet& FuncGraphManager::func_graph_parents_total(const FuncGraphPtr& fg)
return func_graph_parents_total_->func_graph_parents_total_analysis()[fg];
}
FuncGraphPtr FuncGraphManager::parent(const FuncGraphPtr& fg) const {
FuncGraphPtr FuncGraphManager::parent(const FuncGraphPtr &fg) const {
MS_EXCEPTION_IF_NULL(fg);
MS_EXCEPTION_IF_NULL(func_graph_parent_);
MS_LOG(DEBUG) << "Start parents func graph " << fg->ToString();
@ -129,7 +129,7 @@ FuncGraphPtr FuncGraphManager::parent(const FuncGraphPtr& fg) const {
return func_graph_parent_->parent_analysis()[fg];
}
FuncGraphSet& FuncGraphManager::children(const FuncGraphPtr& fg) const {
FuncGraphSet &FuncGraphManager::children(const FuncGraphPtr &fg) const {
MS_EXCEPTION_IF_NULL(fg);
MS_EXCEPTION_IF_NULL(children_);
MS_LOG(DEBUG) << "Start child func graph " << fg->ToString();
@ -137,7 +137,7 @@ FuncGraphSet& FuncGraphManager::children(const FuncGraphPtr& fg) const {
return children_->children_analysis()[fg];
}
FuncGraphSet& FuncGraphManager::scopes(const FuncGraphPtr& fg) const {
FuncGraphSet &FuncGraphManager::scopes(const FuncGraphPtr &fg) const {
MS_EXCEPTION_IF_NULL(fg);
MS_EXCEPTION_IF_NULL(scopes_);
MS_LOG(DEBUG) << "Start scopes func graph:" << fg->ToString();
@ -146,19 +146,19 @@ FuncGraphSet& FuncGraphManager::scopes(const FuncGraphPtr& fg) const {
return scopes_->scope_analysis()[fg];
}
FVTotalMap& FuncGraphManager::free_variables_total() const {
FVTotalMap &FuncGraphManager::free_variables_total() const {
MS_EXCEPTION_IF_NULL(free_variables_total_);
free_variables_total_->Recompute();
return free_variables_total_->fv_total_analysis();
}
FuncGraphSet& FuncGraphManager::func_graphs_used_total(const FuncGraphPtr& fg) const {
FuncGraphSet &FuncGraphManager::func_graphs_used_total(const FuncGraphPtr &fg) const {
MS_EXCEPTION_IF_NULL(func_graphs_used_total_);
func_graphs_used_total_->Recompute(fg);
return func_graphs_used_total_->func_graph_used_total_analysis()[fg];
}
bool FuncGraphManager::recursive(const FuncGraphPtr& fg) const {
bool FuncGraphManager::recursive(const FuncGraphPtr &fg) const {
MS_EXCEPTION_IF_NULL(fg);
recursive_->Recompute(fg);
if (recursive_->recursive_analysis().count(fg) == 0) {
@ -168,7 +168,7 @@ bool FuncGraphManager::recursive(const FuncGraphPtr& fg) const {
return recursive_->recursive_analysis()[fg];
}
std::shared_ptr<std::list<FuncGraphPtr>> FuncGraphManager::recursive_graphs(const FuncGraphPtr& fg) const {
std::shared_ptr<std::list<FuncGraphPtr>> FuncGraphManager::recursive_graphs(const FuncGraphPtr &fg) const {
MS_EXCEPTION_IF_NULL(fg);
if (recursive(fg)) {
if (!recursive_->recursive_map().count(fg)) {
@ -185,7 +185,7 @@ std::shared_ptr<std::list<FuncGraphPtr>> FuncGraphManager::recursive_graphs(cons
}
}
bool FuncGraphManager::func_graph_j_total(const FuncGraphPtr& fg) const {
bool FuncGraphManager::func_graph_j_total(const FuncGraphPtr &fg) const {
MS_EXCEPTION_IF_NULL(j_total_);
MS_EXCEPTION_IF_NULL(fg);
j_total_->Recompute(fg);
@ -225,10 +225,10 @@ void FuncGraphManager::Clear() {
signals_->InvalidateComputer();
}
void FuncGraphManager::KeepRoots(const std::vector<FuncGraphPtr>& func_graphs) {
void FuncGraphManager::KeepRoots(const std::vector<FuncGraphPtr> &func_graphs) {
MS_LOG(DEBUG) << "Start keep roots";
bool root_exist = false;
for (auto& item : func_graphs) {
for (auto &item : func_graphs) {
if (roots_.contains(item)) {
root_exist = true;
break;
@ -245,17 +245,17 @@ void FuncGraphManager::KeepRoots(const std::vector<FuncGraphPtr>& func_graphs) {
roots = roots_;
} else {
roots_.clear();
for (auto& item : roots) {
for (auto &item : roots) {
AddFuncGraph(item, true);
}
}
FuncGraphSet keep;
for (auto& item : roots) {
for (auto &item : roots) {
MS_LOG(DEBUG) << "roots: " << item->ToString();
keep.update(func_graphs_used_total(item));
#ifdef DEBUG
for (auto& k : keep) {
for (auto &k : keep) {
MS_LOG(DEBUG) << "keep: " << k->ToString();
}
#endif
@ -264,7 +264,7 @@ void FuncGraphManager::KeepRoots(const std::vector<FuncGraphPtr>& func_graphs) {
} else {
Clear();
FuncGraphSet roots(func_graphs);
for (auto& item : roots) {
for (auto &item : roots) {
AddFuncGraph(item, true);
}
}
@ -276,7 +276,7 @@ void FuncGraphManager::RemoveRoots() {
MaybeDropFuncGraphs(func_graphs_, true);
}
void FuncGraphManager::AddIntoManaged(const FuncGraphPtr& fg) {
void FuncGraphManager::AddIntoManaged(const FuncGraphPtr &fg) {
MS_EXCEPTION_IF_NULL(fg);
if (is_manage_) {
if (fg->manager() != nullptr && (&(*fg->manager()) != this)) {
@ -288,7 +288,7 @@ void FuncGraphManager::AddIntoManaged(const FuncGraphPtr& fg) {
func_graphs_.add(fg);
}
void FuncGraphManager::MaybeDropFuncGraphs(const FuncGraphSet& func_graphs, bool ignore_users) {
void FuncGraphManager::MaybeDropFuncGraphs(const FuncGraphSet &func_graphs, bool ignore_users) {
FuncGraphSet todo(func_graphs);
std::set<FuncGraphPtr> dropped;
// int count = 0;
@ -301,7 +301,7 @@ void FuncGraphManager::MaybeDropFuncGraphs(const FuncGraphSet& func_graphs, bool
continue;
}
MS_EXCEPTION_IF_NULL(func_graph_users_);
auto& users = func_graph_users_->count_func_graphs_map()[func_graph];
auto &users = func_graph_users_->count_func_graphs_map()[func_graph];
if (!users.empty() && !ignore_users) {
MS_LOG(DEBUG) << "Cannot drop as users not empty: " << func_graph->ToString();
continue;
@ -315,7 +315,7 @@ void FuncGraphManager::MaybeDropFuncGraphs(const FuncGraphSet& func_graphs, bool
todo.update(MaybeDropNodes(return_vec));
}
MS_EXCEPTION_IF_NULL(signals_);
for (auto& fg : dropped) {
for (auto &fg : dropped) {
MS_EXCEPTION_IF_NULL(fg);
signals_->DropFuncGraph(fg);
all_nodes_.difference_update(fg->parameters());
@ -331,7 +331,7 @@ void FuncGraphManager::ProcessEdge(AnfNodePtr node, int index, AnfNodePtr inp, E
MS_EXCEPTION_IF_NULL(inp);
if (direction == kDecEdge) {
MS_LOG(DEBUG) << "Remove node " << node->ToString() << " input[" << index << "] " << inp->ToString();
auto& users_node = node_users_[inp];
auto &users_node = node_users_[inp];
if (!users_node.contains(make_pair(node, index))) {
return;
}
@ -346,26 +346,26 @@ void FuncGraphManager::ProcessEdge(AnfNodePtr node, int index, AnfNodePtr inp, E
MS_LOG(DEBUG) << "Input[" << index << "] is const graph " << inp->ToString();
AddFuncGraph(GetValueNode<FuncGraphPtr>(inp));
}
auto& users_node = node_users_[inp];
auto &users_node = node_users_[inp];
users_node.add(make_pair(node, index));
MS_EXCEPTION_IF_NULL(signals_);
signals_->AddEdge(node, index, inp);
}
}
void FuncGraphManager::ProcessInputs(const AnfNodePtr& node, EdgeProcessDirection direction) {
void FuncGraphManager::ProcessInputs(const AnfNodePtr &node, EdgeProcessDirection direction) {
MS_EXCEPTION_IF_NULL(node);
if (node->isa<CNode>()) {
auto cnode = node->cast<CNodePtr>();
int index = 0;
for (auto& inp : cnode->inputs()) {
for (auto &inp : cnode->inputs()) {
ProcessEdge(cnode, index, inp, direction);
++index;
}
}
}
IncludeType FuncGraphManager::Limit(const AnfNodePtr& node) {
IncludeType FuncGraphManager::Limit(const AnfNodePtr &node) {
if (all_nodes_.contains(node)) {
return EXCLUDE;
} else {
@ -373,9 +373,9 @@ IncludeType FuncGraphManager::Limit(const AnfNodePtr& node) {
}
}
void FuncGraphManager::AcquireNodes(const std::vector<AnfNodePtr>& nodes) {
void FuncGraphManager::AcquireNodes(const std::vector<AnfNodePtr> &nodes) {
AnfNodeSet acq;
for (auto& node : nodes) {
for (auto &node : nodes) {
std::function<IncludeType(AnfNodePtr)> limit = std::bind(&FuncGraphManager::Limit, this, std::placeholders::_1);
AnfNodeSet new_nodes = AnfNodeSet(DeepScopedGraphSearch(node, limit));
@ -384,7 +384,7 @@ void FuncGraphManager::AcquireNodes(const std::vector<AnfNodePtr>& nodes) {
acq.update(new_nodes);
}
for (auto& node : acq) {
for (auto &node : acq) {
MS_EXCEPTION_IF_NULL(node);
FuncGraphPtr fg = node->func_graph();
if (fg != nullptr) {
@ -395,7 +395,7 @@ void FuncGraphManager::AcquireNodes(const std::vector<AnfNodePtr>& nodes) {
}
}
FuncGraphSetPtr FuncGraphManager::MaybeDropNodes(const std::vector<AnfNodePtr>& nodes) {
FuncGraphSetPtr FuncGraphManager::MaybeDropNodes(const std::vector<AnfNodePtr> &nodes) {
AnfNodeSet nodes_ordered(nodes);
FuncGraphSetPtr func_graphs_to_check = std::make_shared<FuncGraphSet>();
MS_EXCEPTION_IF_NULL(signals_);
@ -406,7 +406,7 @@ FuncGraphSetPtr FuncGraphManager::MaybeDropNodes(const std::vector<AnfNodePtr>&
if (!all_nodes_.contains(node)) {
continue;
}
AnfNodeIndexSet& users = node_users_[node];
AnfNodeIndexSet &users = node_users_[node];
std::vector<AnfNodePtr> parameters;
if (!users.empty() ||
@ -431,13 +431,13 @@ FuncGraphSetPtr FuncGraphManager::MaybeDropNodes(const std::vector<AnfNodePtr>&
return func_graphs_to_check;
}
void FuncGraphManager::SetParameters(const FuncGraphPtr& fg, const std::vector<AnfNodePtr>& parameters) {
void FuncGraphManager::SetParameters(const FuncGraphPtr &fg, const std::vector<AnfNodePtr> &parameters) {
auto tr = Transact();
tr.SetParameters(fg, parameters);
tr.Commit();
}
bool FuncGraphManager::Replace(const AnfNodePtr& old_node, const AnfNodePtr& new_node) {
bool FuncGraphManager::Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node) {
auto tr = Transact();
bool success = tr.Replace(old_node, new_node);
if (success) {
@ -446,13 +446,13 @@ bool FuncGraphManager::Replace(const AnfNodePtr& old_node, const AnfNodePtr& new
return success;
}
void FuncGraphManager::SetEdge(const AnfNodePtr& node, int index, const AnfNodePtr& value) {
void FuncGraphManager::SetEdge(const AnfNodePtr &node, int index, const AnfNodePtr &value) {
auto tr = Transact();
tr.SetEdge(node, index, value);
tr.Commit();
}
void FuncGraphManager::MoveAllCNodeDropGraph(FuncGraphPtr source, FuncGraphPtr target, const ScopePtr& scope) {
void FuncGraphManager::MoveAllCNodeDropGraph(FuncGraphPtr source, FuncGraphPtr target, const ScopePtr &scope) {
AnfNodePtr source_return = source->get_return();
AnfNodePtr source_output = source->output();
AnfNodePtr source_prim = source_return->cast<CNodePtr>()->input(0);
@ -466,23 +466,23 @@ void FuncGraphManager::MoveAllCNodeDropGraph(FuncGraphPtr source, FuncGraphPtr t
(void)all_nodes_.erase(source_return);
(void)node_users_.erase(source_return);
signals_->DropNode(source_return);
for (auto& node : source->nodes()) {
for (auto &node : source->nodes()) {
node->set_func_graph(target);
if (node->scope() == kDefaultScope) {
node->set_scope(scope);
}
}
for (auto& used : source->func_graphs_used()) {
for (auto &used : source->func_graphs_used()) {
(void)func_graph_users_->Inc(used.first, target, used.second);
(void)this->func_graph_users()[used.first].erase(source);
}
for (auto& child : this->func_graph_child_direct()[source]) {
for (auto &child : this->func_graph_child_direct()[source]) {
(void)func_graph_parents_direct_->Inc(child.first, target, child.second);
(void)this->func_graph_parents_direct()[child.first].erase(source);
}
for (auto& fv_count : this->free_variables_direct()[source]) {
for (auto &fv_count : this->free_variables_direct()[source]) {
auto fv_g = fv_count.first->func_graph();
auto& count_on_g = this->func_graph_child_direct()[fv_g];
auto &count_on_g = this->func_graph_child_direct()[fv_g];
auto pair = count_on_g.find(source);
if (fv_g != target && pair != count_on_g.end()) {
(void)func_graph_child_direct_->Inc(fv_g, target, pair->second);
@ -504,9 +504,9 @@ FuncGraphTransaction FuncGraphManager::Transact() {
return tr;
}
void FuncGraphManager::ParseChanges(const std::vector<Change>& changes, EdgeTupleCounter* add_edges,
EdgeTupleCounter* rm_edges, Counter<AnfNodePtr>* adds, Counter<AnfNodePtr>* rms) {
for (auto& iter : changes) {
void FuncGraphManager::ParseChanges(const std::vector<Change> &changes, EdgeTupleCounter *add_edges,
EdgeTupleCounter *rm_edges, Counter<AnfNodePtr> *adds, Counter<AnfNodePtr> *rms) {
for (auto &iter : changes) {
auto operation = iter.op;
auto args = iter.args;
if (operation == Change::kTxSetEdge) {
@ -521,10 +521,10 @@ void FuncGraphManager::ParseChanges(const std::vector<Change>& changes, EdgeTupl
auto param = args.cast<ArgsOfSetParams>();
MS_EXCEPTION_IF_NULL(param.func_graph);
auto old_parameters = param.func_graph->parameters();
for (auto& p : param.params) {
for (auto &p : param.params) {
(*adds)[p] += 1;
}
for (auto& p : old_parameters) {
for (auto &p : old_parameters) {
(*rms)[p] += 1;
}
param.func_graph->set_parameters(param.params);
@ -532,7 +532,7 @@ void FuncGraphManager::ParseChanges(const std::vector<Change>& changes, EdgeTupl
}
}
void FuncGraphManager::CommitChanges(const std::vector<Change>& changes) {
void FuncGraphManager::CommitChanges(const std::vector<Change> &changes) {
EdgeTupleCounter add_edges;
EdgeTupleCounter rm_edges;
Counter<AnfNodePtr> adds;
@ -540,7 +540,7 @@ void FuncGraphManager::CommitChanges(const std::vector<Change>& changes) {
ParseChanges(changes, &add_edges, &rm_edges, &adds, &rms);
auto sub_edges = add_edges - rm_edges;
for (auto& iter : sub_edges) {
for (auto &iter : sub_edges) {
auto root_node = iter.first.first;
int index = iter.first.second.first;
auto new_node = iter.first.second.second;
@ -550,12 +550,12 @@ void FuncGraphManager::CommitChanges(const std::vector<Change>& changes) {
auto sub_nodes = adds - rms;
std::vector<AnfNodePtr> nodes;
(void)std::transform(sub_nodes.begin(), sub_nodes.end(), std::back_inserter(nodes),
[](const std::pair<const AnfNodePtr, int>& iter) -> AnfNodePtr { return iter.first; });
[](const std::pair<const AnfNodePtr, int> &iter) -> AnfNodePtr { return iter.first; });
AcquireNodes(nodes);
auto sub_edges_reverse = rm_edges - add_edges;
for (auto& iter : sub_edges_reverse) {
for (auto &iter : sub_edges_reverse) {
auto root_node = iter.first.first;
int index = iter.first.second.first;
auto old_node = iter.first.second.second;
@ -566,17 +566,17 @@ void FuncGraphManager::CommitChanges(const std::vector<Change>& changes) {
std::vector<AnfNodePtr> nodes_reverse;
(void)std::transform(sub_nodes_reverse.begin(), sub_nodes_reverse.end(), std::back_inserter(nodes_reverse),
[](const std::pair<const AnfNodePtr, int>& iter) -> AnfNodePtr { return iter.first; });
[](const std::pair<const AnfNodePtr, int> &iter) -> AnfNodePtr { return iter.first; });
auto drop_func_graphs = MaybeDropNodes(nodes_reverse);
MaybeDropFuncGraphs(*drop_func_graphs);
}
void FuncGraphTransaction::SetParameters(FuncGraphPtr fg, const std::vector<AnfNodePtr>& params) {
void FuncGraphTransaction::SetParameters(FuncGraphPtr fg, const std::vector<AnfNodePtr> &params) {
changes_.emplace_back(Change::kTxSetParams, ArgsOfSetParams{fg, params});
}
bool FuncGraphTransaction::Replace(const AnfNodePtr& old_node, const AnfNodePtr& new_node) {
bool FuncGraphTransaction::Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node) {
MS_EXCEPTION_IF_NULL(old_node);
MS_EXCEPTION_IF_NULL(new_node);
FuncGraphPtr old_func_graph = old_node->func_graph();
@ -585,14 +585,14 @@ bool FuncGraphTransaction::Replace(const AnfNodePtr& old_node, const AnfNodePtr&
return false;
}
auto users = manager_->node_users()[old_node];
for (auto& node : users) {
for (auto &node : users) {
SetEdge(node.first, node.second, new_node);
}
return true;
}
void FuncGraphTransaction::SetEdge(const AnfNodePtr& src_node, int k, const AnfNodePtr& v) {
void FuncGraphTransaction::SetEdge(const AnfNodePtr &src_node, int k, const AnfNodePtr &v) {
if (k < 0) {
MS_LOG(EXCEPTION) << "Invalid value k = " << k;
}
@ -610,7 +610,7 @@ void FuncGraphTransaction::Commit() {
manager_->CommitChanges(changes);
}
FuncGraphAnalysis::FuncGraphAnalysis(const FuncGraphManager* const manager)
FuncGraphAnalysis::FuncGraphAnalysis(const FuncGraphManager *const manager)
: manager_(manager), include_func_graph_none_(false) {
manager_->signals()->AddFuncGraph.connect(this, &FuncGraphAnalysis::OnAddFuncGraph);
manager_->signals()->DropFuncGraph.connect(this, &FuncGraphAnalysis::OnDropFuncGraph);
@ -619,7 +619,7 @@ FuncGraphAnalysis::FuncGraphAnalysis(const FuncGraphManager* const manager)
manager_->signals()->MoveAllCNode.connect(this, &FuncGraphAnalysis::OnMoveAllCNode);
}
NodesCollector::NodesCollector(const FuncGraphManager* const m) : DepCollector(m), nodes_analysis_() {
NodesCollector::NodesCollector(const FuncGraphManager *const m) : DepCollector(m), nodes_analysis_() {
include_func_graph_none_ = true;
nodes_analysis_[nullptr] = AnfNodeSet();
@ -646,7 +646,7 @@ void NodesCollector::OnDropNode(AnfNodePtr n) {
void NodesCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) {
// change the owner of node except for the src's return node
for (auto& it : nodes_analysis_[src]) {
for (auto &it : nodes_analysis_[src]) {
nodes_analysis_[dst].add(it);
}
(void)nodes_analysis_.erase(src);
@ -654,15 +654,15 @@ void NodesCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) {
void DepCollector::OnAddEdge(AnfNodePtr node, int index, AnfNodePtr inp) { OnModEdge(node, index, inp, kIncEdge); }
DepCollector::DepCollector(const FuncGraphManager* const manager) : FuncGraphAnalysis(manager) {
DepCollector::DepCollector(const FuncGraphManager *const manager) : FuncGraphAnalysis(manager) {
MS_EXCEPTION_IF_NULL(manager_);
manager_->signals()->InvalidateCollector.connect(this, &DepCollector::OnInvalidateCollector);
}
void DepCollector::OnDropEdge(AnfNodePtr node, int index, AnfNodePtr inp) { OnModEdge(node, index, inp, kDecEdge); }
bool CounterAnfNodeCollector::Inc(const FuncGraphPtr& func_graph, const AnfNodePtr& key, int count = 1) {
auto& d = count_nodes_map_[func_graph];
bool CounterAnfNodeCollector::Inc(const FuncGraphPtr &func_graph, const AnfNodePtr &key, int count = 1) {
auto &d = count_nodes_map_[func_graph];
if (d.count(key) == 0) {
d[key] = count;
return true;
@ -672,9 +672,9 @@ bool CounterAnfNodeCollector::Inc(const FuncGraphPtr& func_graph, const AnfNodeP
return false;
}
bool CounterAnfNodeCollector::Dec(const FuncGraphPtr& func_graph, const AnfNodePtr& key, int count = 1) {
bool CounterAnfNodeCollector::Dec(const FuncGraphPtr &func_graph, const AnfNodePtr &key, int count = 1) {
MS_EXCEPTION_IF_NULL(func_graph);
auto& d = count_nodes_map_[func_graph];
auto &d = count_nodes_map_[func_graph];
if (d.count(key) != 0) {
if (d[key] == count) {
(void)d.erase(key);
@ -690,7 +690,7 @@ bool CounterAnfNodeCollector::Dec(const FuncGraphPtr& func_graph, const AnfNodeP
return false;
}
bool CounterAnfNodeCollector::Mod(const FuncGraphPtr& func_graph, const AnfNodePtr& key, int count) {
bool CounterAnfNodeCollector::Mod(const FuncGraphPtr &func_graph, const AnfNodePtr &key, int count) {
if (count > 0) {
return Inc(func_graph, key, count);
} else if (count < 0) {
@ -701,8 +701,8 @@ bool CounterAnfNodeCollector::Mod(const FuncGraphPtr& func_graph, const AnfNodeP
}
}
bool CounterFuncGraphCollector::Inc(const FuncGraphPtr& func_graph, const FuncGraphPtr& key, int count = 1) {
auto& d = count_func_graphs_map_[func_graph];
bool CounterFuncGraphCollector::Inc(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count = 1) {
auto &d = count_func_graphs_map_[func_graph];
if (d.count(key) == 0) {
d[key] = count;
return true;
@ -712,8 +712,8 @@ bool CounterFuncGraphCollector::Inc(const FuncGraphPtr& func_graph, const FuncGr
return false;
}
bool CounterFuncGraphCollector::Dec(const FuncGraphPtr& func_graph, const FuncGraphPtr& key, int count = 1) {
auto& d = count_func_graphs_map_[func_graph];
bool CounterFuncGraphCollector::Dec(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count = 1) {
auto &d = count_func_graphs_map_[func_graph];
if (d.count(key) != 0) {
if (d[key] == count) {
(void)d.erase(key);
@ -729,7 +729,7 @@ bool CounterFuncGraphCollector::Dec(const FuncGraphPtr& func_graph, const FuncGr
return false;
}
bool CounterFuncGraphCollector::Mod(const FuncGraphPtr& func_graph, const FuncGraphPtr& key, int count) {
bool CounterFuncGraphCollector::Mod(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count) {
if (count > 0) {
return Inc(func_graph, key, count);
} else if (count < 0) {
@ -748,7 +748,7 @@ void ValueNodesCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, EdgePr
}
void ValueNodesCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) {
for (auto& it : count_nodes_map_[src]) {
for (auto &it : count_nodes_map_[src]) {
(void)Inc(dst, it.first, it.second);
}
(void)count_nodes_map_.erase(src);
@ -762,7 +762,7 @@ void FuncGraphValueNodesCollector::OnModEdge(AnfNodePtr, int, AnfNodePtr inp, Ed
}
void FuncGraphValueNodesCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) {
for (auto& it : count_nodes_map_[src]) {
for (auto &it : count_nodes_map_[src]) {
(void)Inc(dst, it.first, it.second);
}
(void)count_nodes_map_.erase(src);
@ -779,7 +779,7 @@ void FVDirectCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, EdgeProc
}
void FVDirectCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) {
for (auto& it : count_nodes_map_[src]) {
for (auto &it : count_nodes_map_[src]) {
FuncGraphPtr fg2 = it.first->func_graph();
if (fg2 != dst) {
(void)Inc(dst, it.first, it.second);
@ -788,7 +788,7 @@ void FVDirectCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) {
(void)count_nodes_map_.erase(src);
}
static FuncGraphPtr ParentProxy(const FuncGraphPtr& fg) {
static FuncGraphPtr ParentProxy(const FuncGraphPtr &fg) {
FuncGraphPtr gn = std::make_shared<FuncGraph>();
(void)gn->transforms().insert(std::make_pair("proxy", FuncGraphTransform(fg)));
return gn;
@ -805,7 +805,7 @@ void FuncGraphChildDirect::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, EdgeP
}
void FuncGraphChildDirect::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) {
for (auto& it : count_func_graphs_map_[src]) {
for (auto &it : count_func_graphs_map_[src]) {
FuncGraphPtr fg = it.first;
if (fg != dst) {
(void)Inc(dst, fg, it.second);
@ -835,7 +835,7 @@ void FuncGraphParentsDirectCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr
}
void FuncGraphParentsDirectCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) {
for (auto& it : count_func_graphs_map_[src]) {
for (auto &it : count_func_graphs_map_[src]) {
if (it.first != dst) {
(void)Inc(dst, it.first, it.second);
}
@ -852,7 +852,7 @@ void FuncGraphsUsedCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, Ed
void FuncGraphsUsedCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) {
// all graph use in src need to change to dst, so meger the to dst use
for (auto& it : count_func_graphs_map_[src]) {
for (auto &it : count_func_graphs_map_[src]) {
(void)Inc(dst, it.first, it.second);
}
(void)count_func_graphs_map_[dst].erase(src);
@ -879,7 +879,7 @@ void FuncGraphUserNodesCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp
}
void FuncGraphUserNodesCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) {
for (auto& it : count_nodes_map_[src]) {
for (auto &it : count_nodes_map_[src]) {
(void)Inc(dst, it.first, it.second);
}
(void)count_nodes_map_.erase(src);
@ -895,13 +895,13 @@ void FuncGraphJDirectCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp,
void FuncGraphJDirectCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) {
// all graph use in src need to change to dst, so meger the to dst use
for (auto& it : count_func_graphs_map_[src]) {
for (auto &it : count_func_graphs_map_[src]) {
(void)Inc(dst, it.first, it.second);
}
(void)count_func_graphs_map_.erase(src);
}
DepComputer::DepComputer(const FuncGraphManager* const manager) : FuncGraphAnalysis(manager) {
DepComputer::DepComputer(const FuncGraphManager *const manager) : FuncGraphAnalysis(manager) {
MS_EXCEPTION_IF_NULL(manager_);
manager_->signals()->InvalidateComputer.connect(this, &DepComputer::OnInvalidateComputer);
validate_ = false;
@ -914,20 +914,20 @@ void DepComputer::Recompute() {
}
}
void DepComputer::Recompute(const FuncGraphPtr& fg) {
void DepComputer::Recompute(const FuncGraphPtr &fg) {
if (func_graphs_validate_.count(fg) == 0 || !func_graphs_validate_[fg]) {
RealRecompute(fg);
func_graphs_validate_[fg] = true;
}
}
FuncGraphSetPtr FuncGraphParentsTotalComputer::SeekParents(const FuncGraphPtr& fg, const FuncGraphSetPtr& path) {
FuncGraphSetPtr FuncGraphParentsTotalComputer::SeekParents(const FuncGraphPtr &fg, const FuncGraphSetPtr &path) {
if (path == nullptr || path->contains(fg)) {
return std::make_shared<FuncGraphSet>();
}
FuncGraphSetPtr parents = std::make_shared<FuncGraphSet>();
FuncGraphToFuncGraphCounterMap& deps = *all_parents_direct_;
for (auto& dep : deps[fg]) {
FuncGraphToFuncGraphCounterMap &deps = *all_parents_direct_;
for (auto &dep : deps[fg]) {
MS_EXCEPTION_IF_NULL(dep.first);
auto proxy = dep.first->transforms().find("proxy");
if (proxy != dep.first->transforms().end()) {
@ -950,7 +950,7 @@ void FuncGraphParentsTotalComputer::RealRecompute(FuncGraphPtr fg) {
MS_LOG(DEBUG) << "FuncGraphParentsTotalComputer end: " << func_graph_parents_total_analysis_[fg].size();
}
bool set_len_compare(const FuncGraphSetPair& lhs, const FuncGraphSetPair& rhs) {
bool set_len_compare(const FuncGraphSetPair &lhs, const FuncGraphSetPair &rhs) {
auto l1 = lhs.second.size();
auto l2 = rhs.second.size();
return l1 < l2;
@ -970,9 +970,9 @@ void ParentComputer::RealRecompute(FuncGraphPtr fg) {
} else {
// return nearest parent as parent
FuncGraphSet deps_copy(deps);
for (auto& dep : deps) {
for (auto &dep : deps) {
auto parent_deps = this->manager_->func_graph_parents_total(dep);
for (auto& p_d : parent_deps) {
for (auto &p_d : parent_deps) {
if (deps_copy.count(p_d)) {
(void)deps_copy.erase(p_d);
}
@ -988,7 +988,7 @@ void ParentComputer::RealRecompute(FuncGraphPtr fg) {
void ChildrenComputer::RealRecompute(FuncGraphPtr fg) {
MS_EXCEPTION_IF_NULL(manager_);
auto used_fg_total = manager_->func_graphs_used_total(fg);
for (auto& used_fg : used_fg_total) {
for (auto &used_fg : used_fg_total) {
if (manager_->parent(used_fg) == fg) {
children_analysis_[fg].add(used_fg);
}
@ -997,11 +997,11 @@ void ChildrenComputer::RealRecompute(FuncGraphPtr fg) {
void ScopeComputer::RealRecompute(FuncGraphPtr fg) {
MS_EXCEPTION_IF_NULL(manager_);
auto& children = manager_->children(fg);
auto &children = manager_->children(fg);
scope_analysis_[fg] = FuncGraphSet();
scope_analysis_[fg].add(fg);
for (auto& child : children) {
for (auto &child : children) {
scope_analysis_[fg].add(child);
}
}
@ -1010,20 +1010,20 @@ void FVTotalComputer::RealRecompute() {
auto manager = DepComputer::manager_;
MS_EXCEPTION_IF_NULL(manager);
for (auto& fg : manager->func_graphs()) {
for (auto &fg : manager->func_graphs()) {
fv_total_analysis_[fg] = OrderedMap<BaseRef, int, BaseRefHash>();
count_nodes_map_[fg] = OrderedMap<AnfNodePtr, int>();
count_func_graphs_map_[fg] = OrderedMap<FuncGraphPtr, int>();
}
for (auto& fg : manager->func_graphs()) {
for (auto &fg : manager->func_graphs()) {
AnfNodeCounterMap items = manager->free_variables_direct()[fg];
for (auto& iter : items) {
for (auto &iter : items) {
auto curr = fg;
while (curr) {
(void)CounterAnfNodeCollector::Mod(curr, iter.first, iter.second);
curr = manager->parent(curr);
const AnfNodeSet& nodes = manager->nodes()[curr];
const AnfNodeSet &nodes = manager->nodes()[curr];
if (nodes.contains(iter.first)) {
break;
}
@ -1031,7 +1031,7 @@ void FVTotalComputer::RealRecompute() {
}
auto items_fg = manager->func_graphs_used()[fg];
for (auto& iter : items_fg) {
for (auto &iter : items_fg) {
auto p = manager->parent(iter.first);
if (p == nullptr) {
continue;
@ -1043,13 +1043,13 @@ void FVTotalComputer::RealRecompute() {
}
}
}
for (auto& fg : manager->func_graphs()) {
auto& fvp = count_nodes_map_[fg];
auto& fvg = count_func_graphs_map_[fg];
for (auto& item : fvp) {
for (auto &fg : manager->func_graphs()) {
auto &fvp = count_nodes_map_[fg];
auto &fvg = count_func_graphs_map_[fg];
for (auto &item : fvp) {
fv_total_analysis_[fg][item.first] = item.second;
}
for (auto& item : fvg) {
for (auto &item : fvg) {
fv_total_analysis_[fg][item.first] = item.second;
}
}
@ -1057,15 +1057,15 @@ void FVTotalComputer::RealRecompute() {
void FuncGraphsUsedTotalComputer::RealRecompute(FuncGraphPtr fg) {
MS_EXCEPTION_IF_NULL(manager_);
auto& used = this->manager_->func_graphs_used();
auto &used = this->manager_->func_graphs_used();
std::vector<FuncGraphPtr> todo;
std::vector<FuncGraphPtr> todo_new;
todo.push_back(fg);
while (!todo.empty()) {
todo_new.clear();
for (auto& gt : todo) {
for (auto& item : used[gt]) {
for (auto &gt : todo) {
for (auto &item : used[gt]) {
auto used_fg = item.first;
if (used_fg == fg) {
func_graph_used_total_analysis_[fg].add(used_fg);
@ -1082,17 +1082,17 @@ void FuncGraphsUsedTotalComputer::RealRecompute(FuncGraphPtr fg) {
}
}
bool CheckRecursive(const FuncGraphManager* const manager, const FuncGraphPtr& fg) {
bool CheckRecursive(const FuncGraphManager *const manager, const FuncGraphPtr &fg) {
MS_EXCEPTION_IF_NULL(manager);
auto& used = manager->func_graphs_used();
auto &used = manager->func_graphs_used();
std::vector<FuncGraphPtr> todo;
std::vector<FuncGraphPtr> todo_new;
todo.push_back(fg);
FuncGraphSet used_total;
while (!todo.empty()) {
todo_new.clear();
for (auto& gt : todo) {
for (auto& item : used[gt]) {
for (auto &gt : todo) {
for (auto &item : used[gt]) {
auto used_g = item.first;
if (used_g == fg) {
return true;
@ -1112,7 +1112,7 @@ void RecursiveComputer::RealRecompute(FuncGraphPtr fg) {
this->recursive_analysis_[fg] = CheckRecursive(this->manager_, fg);
}
void RecursiveComputer::CheckRecursiveGraphs(const FuncGraphPtr& fg, std::list<FuncGraphPtr>* trace) {
void RecursiveComputer::CheckRecursiveGraphs(const FuncGraphPtr &fg, std::list<FuncGraphPtr> *trace) {
MS_EXCEPTION_IF_NULL(trace);
auto res = std::find(trace->begin(), trace->end(), fg);
// find recursive
@ -1124,7 +1124,7 @@ void RecursiveComputer::CheckRecursiveGraphs(const FuncGraphPtr& fg, std::list<F
}
} else {
trace->push_back(fg);
auto& used_fgs = manager_->func_graphs_used()[fg];
auto &used_fgs = manager_->func_graphs_used()[fg];
for (auto iter = used_fgs.begin(); iter != used_fgs.end(); (void)iter++) {
CheckRecursiveGraphs(iter->first, trace);
}
@ -1135,14 +1135,14 @@ void RecursiveComputer::CheckRecursiveGraphs(const FuncGraphPtr& fg, std::list<F
}
}
bool FuncGraphJTotalComputer::SeekJ(const FuncGraphPtr& fg, const FuncGraphSetPtr& path) {
bool FuncGraphJTotalComputer::SeekJ(const FuncGraphPtr &fg, const FuncGraphSetPtr &path) {
MS_EXCEPTION_IF_NULL(path);
if (path->contains(fg)) {
MS_LOG(DEBUG) << "" << fg->ToString() << " had been checked";
return false;
}
MS_EXCEPTION_IF_NULL(manager_);
auto& func_graph_counter_map = manager_->func_graph_j_direct();
auto &func_graph_counter_map = manager_->func_graph_j_direct();
if (!func_graph_counter_map[fg].empty()) {
// check g1->J(fg)->g2->g cycle;
auto contains_j =
@ -1156,8 +1156,8 @@ bool FuncGraphJTotalComputer::SeekJ(const FuncGraphPtr& fg, const FuncGraphSetPt
path->add(fg);
// check if func graphs used contains J(func_graph);
auto& used = this->manager_->func_graphs_used();
for (auto& item : used[fg]) {
auto &used = this->manager_->func_graphs_used();
for (auto &item : used[fg]) {
auto used_g = item.first;
if (SeekJ(used_g, path)) {
MS_LOG(DEBUG) << "" << fg->ToString() << " users func graph " << used_g->ToString()

View File

@ -46,13 +46,13 @@ class FuncGraphManager;
using FuncGraphManagerPtr = std::shared_ptr<FuncGraphManager>;
struct AnfNodeIndexPairHasher {
std::size_t operator()(const std::pair<AnfNodePtr, int>& p1) const {
return std::hash<const AnfNode*>{}(p1.first.get());
std::size_t operator()(const std::pair<AnfNodePtr, int> &p1) const {
return std::hash<const AnfNode *>{}(p1.first.get());
}
};
struct AnfNodeIndexPairEqual {
bool operator()(const std::pair<AnfNodePtr, int>& lhs, const std::pair<AnfNodePtr, int>& rhs) const {
bool operator()(const std::pair<AnfNodePtr, int> &lhs, const std::pair<AnfNodePtr, int> &rhs) const {
return lhs == rhs;
}
};
@ -63,14 +63,14 @@ using FuncGraphSetPair = std::pair<FuncGraphPtr, FuncGraphSet>;
using FuncGraphSetPtr = std::shared_ptr<FuncGraphSet>;
using EdgeTuple = std::pair<AnfNodePtr, std::pair<int, AnfNodePtr>>;
struct EdgeTupleHasher {
std::size_t operator()(const EdgeTuple& p1) const {
return hash_combine({std::hash<AnfNode*>{}(p1.first.get()), std::hash<int>{}(p1.second.first),
std::hash<AnfNode*>{}(p1.second.second.get())});
std::size_t operator()(const EdgeTuple &p1) const {
return hash_combine({std::hash<AnfNode *>{}(p1.first.get()), std::hash<int>{}(p1.second.first),
std::hash<AnfNode *>{}(p1.second.second.get())});
}
};
struct EdgeTupleEqual {
bool operator()(const EdgeTuple& lhs, const EdgeTuple& rhs) const {
bool operator()(const EdgeTuple &lhs, const EdgeTuple &rhs) const {
return lhs.first == rhs.first && lhs.second.first == rhs.second.first && lhs.second.second == rhs.second.second;
}
};
@ -82,9 +82,9 @@ using EdgeTupleCounter = Counter<EdgeTuple, EdgeTupleHasher, EdgeTupleEqual>;
// FuncGraphManagerPtr: return created manager
FuncGraphManagerPtr Manage(FuncGraphPtr func_graph, bool manage = true);
FuncGraphManagerPtr Manage(const std::vector<FuncGraphPtr>& func_graphs, bool manage = true);
FuncGraphManagerPtr Manage(const std::vector<FuncGraphPtr> &func_graphs, bool manage = true);
FuncGraphManagerPtr MakeManager(const std::vector<FuncGraphPtr>& func_graphs = {}, bool manage = true);
FuncGraphManagerPtr MakeManager(const std::vector<FuncGraphPtr> &func_graphs = {}, bool manage = true);
struct Signals {
Signal<void(FuncGraphPtr)> AddFuncGraph;
@ -106,7 +106,7 @@ using FuncGraphToAnfNodeCounterMap = OrderedMap<FuncGraphPtr, OrderedMap<AnfNode
// analysis base class
class FuncGraphAnalysis {
public:
explicit FuncGraphAnalysis(const FuncGraphManager* const manager);
explicit FuncGraphAnalysis(const FuncGraphManager *const manager);
virtual ~FuncGraphAnalysis() { manager_ = nullptr; }
@ -130,7 +130,7 @@ class FuncGraphAnalysis {
virtual void OnDropEdge(AnfNodePtr, int, AnfNodePtr) {}
const FuncGraphManager* manager_;
const FuncGraphManager *manager_;
bool include_func_graph_none_;
};
@ -139,7 +139,7 @@ using FuncGraphToAnfNodeMap = OrderedMap<FuncGraphPtr, AnfNodeSet>;
// graphs analysis which compute in write, read needn't recompute
class DepCollector : public FuncGraphAnalysis {
public:
explicit DepCollector(const FuncGraphManager* manager);
explicit DepCollector(const FuncGraphManager *manager);
~DepCollector() override = default;
void Reset() { ExtraReset(); }
@ -155,10 +155,10 @@ class DepCollector : public FuncGraphAnalysis {
class NodesCollector final : public DepCollector {
public:
explicit NodesCollector(const FuncGraphManager* m);
explicit NodesCollector(const FuncGraphManager *m);
~NodesCollector() override = default;
const FuncGraphToAnfNodeMap& nodes_analysis() const { return nodes_analysis_; }
const FuncGraphToAnfNodeMap &nodes_analysis() const { return nodes_analysis_; }
size_t size() const override { return nodes_analysis_.size(); }
void OnAddFuncGraph(FuncGraphPtr fg) override { nodes_analysis_[fg] = AnfNodeSet(); }
@ -176,16 +176,16 @@ class NodesCollector final : public DepCollector {
class CounterFuncGraphCollector : public DepCollector {
public:
explicit CounterFuncGraphCollector(const FuncGraphManager* m) : DepCollector(m) {}
explicit CounterFuncGraphCollector(const FuncGraphManager *m) : DepCollector(m) {}
~CounterFuncGraphCollector() override = default;
FuncGraphToFuncGraphCounterMap& count_func_graphs_map() { return count_func_graphs_map_; }
FuncGraphToFuncGraphCounterMap &count_func_graphs_map() { return count_func_graphs_map_; }
// inherit from FuncGraphAnalysis
size_t size() const override { return count_func_graphs_map_.size(); }
void OnAddFuncGraph(FuncGraphPtr fg) final { count_func_graphs_map_[fg] = OrderedMap<FuncGraphPtr, int>(); }
void OnDropFuncGraph(FuncGraphPtr fg) final { (void)count_func_graphs_map_.erase(fg); }
bool Inc(const FuncGraphPtr& func_graph, const FuncGraphPtr& key, int count);
bool Dec(const FuncGraphPtr& func_graph, const FuncGraphPtr& key, int count);
bool Mod(const FuncGraphPtr& func_graph, const FuncGraphPtr& key, int count);
bool Inc(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count);
bool Dec(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count);
bool Mod(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count);
FuncGraphToFuncGraphCounterMap count_func_graphs_map_;
@ -195,17 +195,17 @@ class CounterFuncGraphCollector : public DepCollector {
class CounterAnfNodeCollector : public DepCollector {
public:
explicit CounterAnfNodeCollector(const FuncGraphManager* m) : DepCollector(m) {}
explicit CounterAnfNodeCollector(const FuncGraphManager *m) : DepCollector(m) {}
~CounterAnfNodeCollector() override = default;
FuncGraphToAnfNodeCounterMap& count_nodes_map() { return count_nodes_map_; }
FuncGraphToAnfNodeCounterMap &count_nodes_map() { return count_nodes_map_; }
size_t size() const override { return count_nodes_map_.size(); }
void OnAddFuncGraph(FuncGraphPtr fg) final { count_nodes_map_[fg] = OrderedMap<AnfNodePtr, int>(); }
void OnDropFuncGraph(FuncGraphPtr fg) final { (void)count_nodes_map_.erase(fg); }
bool Inc(const FuncGraphPtr& func_graph, const AnfNodePtr& key, int count);
bool Dec(const FuncGraphPtr& func_graph, const AnfNodePtr& key, int count);
bool Mod(const FuncGraphPtr& func_graph, const AnfNodePtr& key, int count);
bool Inc(const FuncGraphPtr &func_graph, const AnfNodePtr &key, int count);
bool Dec(const FuncGraphPtr &func_graph, const AnfNodePtr &key, int count);
bool Mod(const FuncGraphPtr &func_graph, const AnfNodePtr &key, int count);
FuncGraphToAnfNodeCounterMap count_nodes_map_;
@ -215,7 +215,7 @@ class CounterAnfNodeCollector : public DepCollector {
class ValueNodesCollector final : public CounterAnfNodeCollector {
public:
explicit ValueNodesCollector(const FuncGraphManager* m) : CounterAnfNodeCollector(m) {}
explicit ValueNodesCollector(const FuncGraphManager *m) : CounterAnfNodeCollector(m) {}
~ValueNodesCollector() override = default;
void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override;
@ -225,7 +225,7 @@ class ValueNodesCollector final : public CounterAnfNodeCollector {
class FuncGraphValueNodesCollector final : public CounterAnfNodeCollector {
public:
explicit FuncGraphValueNodesCollector(const FuncGraphManager* m) : CounterAnfNodeCollector(m) {}
explicit FuncGraphValueNodesCollector(const FuncGraphManager *m) : CounterAnfNodeCollector(m) {}
~FuncGraphValueNodesCollector() override = default;
void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override;
@ -235,7 +235,7 @@ class FuncGraphValueNodesCollector final : public CounterAnfNodeCollector {
class FVDirectCollector final : public CounterAnfNodeCollector {
public:
explicit FVDirectCollector(const FuncGraphManager* m) : CounterAnfNodeCollector(m) {}
explicit FVDirectCollector(const FuncGraphManager *m) : CounterAnfNodeCollector(m) {}
~FVDirectCollector() override = default;
void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override;
@ -245,7 +245,7 @@ class FVDirectCollector final : public CounterAnfNodeCollector {
class FuncGraphChildDirect final : public CounterFuncGraphCollector {
public:
explicit FuncGraphChildDirect(const FuncGraphManager* m) : CounterFuncGraphCollector(m) {}
explicit FuncGraphChildDirect(const FuncGraphManager *m) : CounterFuncGraphCollector(m) {}
void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override;
~FuncGraphChildDirect() override = default;
@ -260,7 +260,7 @@ class FuncGraphChildDirect final : public CounterFuncGraphCollector {
// 2.direct parent: if graph g's node a used free_variable node in graph f, g's direct parent is f key is g, value is f
class FuncGraphParentsDirectCollector final : public CounterFuncGraphCollector {
public:
explicit FuncGraphParentsDirectCollector(const FuncGraphManager* m) : CounterFuncGraphCollector(m) {}
explicit FuncGraphParentsDirectCollector(const FuncGraphManager *m) : CounterFuncGraphCollector(m) {}
~FuncGraphParentsDirectCollector() override = default;
void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override;
@ -271,7 +271,7 @@ class FuncGraphParentsDirectCollector final : public CounterFuncGraphCollector {
// graph's all used graphs: key is g, value is g used graph
class FuncGraphsUsedCollector final : public CounterFuncGraphCollector {
public:
explicit FuncGraphsUsedCollector(const FuncGraphManager* m) : CounterFuncGraphCollector(m) {}
explicit FuncGraphsUsedCollector(const FuncGraphManager *m) : CounterFuncGraphCollector(m) {}
void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override;
~FuncGraphsUsedCollector() override = default;
@ -282,7 +282,7 @@ class FuncGraphsUsedCollector final : public CounterFuncGraphCollector {
// graph's all user graphs: key is g, value is graphs who used g
class FuncGraphUsersCollector final : public CounterFuncGraphCollector {
public:
explicit FuncGraphUsersCollector(const FuncGraphManager* m) : CounterFuncGraphCollector(m) {}
explicit FuncGraphUsersCollector(const FuncGraphManager *m) : CounterFuncGraphCollector(m) {}
void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override;
~FuncGraphUsersCollector() override = default;
@ -293,7 +293,7 @@ class FuncGraphUsersCollector final : public CounterFuncGraphCollector {
// graph's all user cnodes: key is g, value is cnodes who used g
class FuncGraphUserNodesCollector final : public CounterAnfNodeCollector {
public:
explicit FuncGraphUserNodesCollector(const FuncGraphManager* m) : CounterAnfNodeCollector(m) {}
explicit FuncGraphUserNodesCollector(const FuncGraphManager *m) : CounterAnfNodeCollector(m) {}
void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override;
~FuncGraphUserNodesCollector() override = default;
@ -303,7 +303,7 @@ class FuncGraphUserNodesCollector final : public CounterAnfNodeCollector {
class FuncGraphJDirectCollector final : public CounterFuncGraphCollector {
public:
explicit FuncGraphJDirectCollector(const FuncGraphManager* m) : CounterFuncGraphCollector(m) {}
explicit FuncGraphJDirectCollector(const FuncGraphManager *m) : CounterFuncGraphCollector(m) {}
void OnMoveAllCNode(FuncGraphPtr src, const FuncGraphPtr dst) override;
~FuncGraphJDirectCollector() override = default;
@ -316,7 +316,7 @@ using FuncGraphToFuncGraphSetMap = OrderedMap<FuncGraphPtr, FuncGraphSet>;
// graphs analysis which need dynamic compute by DepCollector in each read
class DepComputer : public FuncGraphAnalysis {
public:
explicit DepComputer(const FuncGraphManager* manager);
explicit DepComputer(const FuncGraphManager *manager);
~DepComputer() override = default;
void Reset() {
@ -329,11 +329,11 @@ class DepComputer : public FuncGraphAnalysis {
void Recompute();
void Recompute(const FuncGraphPtr& fg);
void Recompute(const FuncGraphPtr &fg);
bool IsValidate() const { return validate_; }
bool IsValidate(const FuncGraphPtr& fg) { return func_graphs_validate_[fg]; }
bool IsValidate(const FuncGraphPtr &fg) { return func_graphs_validate_[fg]; }
void OnAddFuncGraph(FuncGraphPtr) final { Reset(); }
@ -354,10 +354,10 @@ class DepComputer : public FuncGraphAnalysis {
// graph g's all direct or proxy parents
class FuncGraphParentsTotalComputer final : public DepComputer {
public:
explicit FuncGraphParentsTotalComputer(const FuncGraphManager* m) : DepComputer(m), all_parents_direct_(nullptr) {}
explicit FuncGraphParentsTotalComputer(const FuncGraphManager *m) : DepComputer(m), all_parents_direct_(nullptr) {}
~FuncGraphParentsTotalComputer() override { all_parents_direct_ = nullptr; }
FuncGraphToFuncGraphSetMap& func_graph_parents_total_analysis() { return func_graph_parents_total_analysis_; }
FuncGraphToFuncGraphSetMap &func_graph_parents_total_analysis() { return func_graph_parents_total_analysis_; }
size_t size() const override { return func_graph_parents_total_analysis_.size(); }
@ -369,10 +369,10 @@ class FuncGraphParentsTotalComputer final : public DepComputer {
void RealRecompute(FuncGraphPtr fg) override;
private:
FuncGraphSetPtr SeekParents(const FuncGraphPtr& fg, const FuncGraphSetPtr& path = std::make_shared<FuncGraphSet>());
FuncGraphSetPtr SeekParents(const FuncGraphPtr &fg, const FuncGraphSetPtr &path = std::make_shared<FuncGraphSet>());
// when SeekParents calls itself recursively, it can access these variables by class member
// other than pass by formal parameters, it can save 1 parameter for SeekParents().
FuncGraphToFuncGraphCounterMap* all_parents_direct_;
FuncGraphToFuncGraphCounterMap *all_parents_direct_;
};
using FuncGraphToFuncGraphMap = OrderedMap<FuncGraphPtr, FuncGraphPtr>;
@ -380,10 +380,10 @@ using FuncGraphToFuncGraphMap = OrderedMap<FuncGraphPtr, FuncGraphPtr>;
// graph's nearest parent in parents total
class ParentComputer final : public DepComputer {
public:
explicit ParentComputer(const FuncGraphManager* m) : DepComputer(m) {}
explicit ParentComputer(const FuncGraphManager *m) : DepComputer(m) {}
~ParentComputer() override = default;
FuncGraphToFuncGraphMap& parent_analysis() { return parent_analysis_; }
FuncGraphToFuncGraphMap &parent_analysis() { return parent_analysis_; }
size_t size() const override { return parent_analysis_.size(); }
@ -398,10 +398,10 @@ class ParentComputer final : public DepComputer {
// graph's children graph except self
class ChildrenComputer final : public DepComputer {
public:
explicit ChildrenComputer(const FuncGraphManager* m) : DepComputer(m) {}
explicit ChildrenComputer(const FuncGraphManager *m) : DepComputer(m) {}
~ChildrenComputer() override = default;
FuncGraphToFuncGraphSetMap& children_analysis() { return children_analysis_; }
FuncGraphToFuncGraphSetMap &children_analysis() { return children_analysis_; }
size_t size() const override { return children_analysis_.size(); }
@ -416,10 +416,10 @@ class ChildrenComputer final : public DepComputer {
// graph's children graph include self
class ScopeComputer final : public DepComputer {
public:
explicit ScopeComputer(const FuncGraphManager* m) : DepComputer(m) {}
explicit ScopeComputer(const FuncGraphManager *m) : DepComputer(m) {}
~ScopeComputer() override = default;
FuncGraphToFuncGraphSetMap& scope_analysis() { return scope_analysis_; }
FuncGraphToFuncGraphSetMap &scope_analysis() { return scope_analysis_; }
size_t size() const override { return scope_analysis_.size(); }
@ -435,11 +435,11 @@ using FVTotalMap = OrderedMap<FuncGraphPtr, OrderedMap<BaseRef, int, BaseRefHash
class FVTotalComputer final : public DepComputer, public CounterAnfNodeCollector, public CounterFuncGraphCollector {
public:
explicit FVTotalComputer(const FuncGraphManager* m)
explicit FVTotalComputer(const FuncGraphManager *m)
: DepComputer(m), CounterAnfNodeCollector(m), CounterFuncGraphCollector(m) {}
~FVTotalComputer() override = default;
FVTotalMap& fv_total_analysis() { return fv_total_analysis_; }
FVTotalMap &fv_total_analysis() { return fv_total_analysis_; }
size_t size() const override { return fv_total_analysis_.size(); }
@ -453,10 +453,10 @@ class FVTotalComputer final : public DepComputer, public CounterAnfNodeCollector
class FuncGraphsUsedTotalComputer final : public DepComputer {
public:
explicit FuncGraphsUsedTotalComputer(const FuncGraphManager* m) : DepComputer(m) {}
explicit FuncGraphsUsedTotalComputer(const FuncGraphManager *m) : DepComputer(m) {}
~FuncGraphsUsedTotalComputer() override = default;
FuncGraphToFuncGraphSetMap& func_graph_used_total_analysis() { return func_graph_used_total_analysis_; }
FuncGraphToFuncGraphSetMap &func_graph_used_total_analysis() { return func_graph_used_total_analysis_; }
size_t size() const override { return func_graph_used_total_analysis_.size(); }
@ -473,13 +473,13 @@ using RecursiveMap = OrderedMap<FuncGraphPtr, std::shared_ptr<std::list<FuncGrap
class RecursiveComputer final : public DepComputer {
public:
explicit RecursiveComputer(const FuncGraphManager* m) : DepComputer(m) {}
explicit RecursiveComputer(const FuncGraphManager *m) : DepComputer(m) {}
~RecursiveComputer() override = default;
RecursiveMap& recursive_map() { return recursive_map_; }
FuncGraphToBoolMap& recursive_analysis() { return recursive_analysis_; }
RecursiveMap &recursive_map() { return recursive_map_; }
FuncGraphToBoolMap &recursive_analysis() { return recursive_analysis_; }
void CheckRecursiveGraphs(const FuncGraphPtr& fg, std::list<FuncGraphPtr>* trace);
void CheckRecursiveGraphs(const FuncGraphPtr &fg, std::list<FuncGraphPtr> *trace);
size_t size() const override { return recursive_analysis_.size(); }
@ -497,10 +497,10 @@ class RecursiveComputer final : public DepComputer {
class FuncGraphJTotalComputer final : public DepComputer {
public:
explicit FuncGraphJTotalComputer(const FuncGraphManager* m) : DepComputer(m) {}
explicit FuncGraphJTotalComputer(const FuncGraphManager *m) : DepComputer(m) {}
~FuncGraphJTotalComputer() override = default;
FuncGraphToBoolMap& j_total_analysis() { return j_total_analysis_; }
FuncGraphToBoolMap &j_total_analysis() { return j_total_analysis_; }
size_t size() const override { return j_total_analysis_.size(); }
@ -510,12 +510,12 @@ class FuncGraphJTotalComputer final : public DepComputer {
void ExtraReset() override { j_total_analysis_.clear(); }
void RealRecompute(FuncGraphPtr fg) override;
bool SeekJ(const FuncGraphPtr& fg, const FuncGraphSetPtr& path);
bool SeekJ(const FuncGraphPtr &fg, const FuncGraphSetPtr &path);
};
class FuncGraphManager : public std::enable_shared_from_this<FuncGraphManager> {
public:
explicit FuncGraphManager(const std::vector<FuncGraphPtr>& roots, bool manage = true);
explicit FuncGraphManager(const std::vector<FuncGraphPtr> &roots, bool manage = true);
~FuncGraphManager() {
if (is_manage_) {
RemoveRoots();
@ -526,71 +526,71 @@ class FuncGraphManager : public std::enable_shared_from_this<FuncGraphManager> {
void Init();
void Clear();
void AddFuncGraph(FuncGraphPtr func_graph, bool is_root = false);
void KeepRoots(const std::vector<FuncGraphPtr>& roots = {});
void KeepRoots(const std::vector<FuncGraphPtr> &roots = {});
void RemoveRoots();
void SetParameters(const FuncGraphPtr& fg, const std::vector<AnfNodePtr>& parameters);
void MaybeDropFuncGraphs(const FuncGraphSet& func_graphs, bool ignore_users = false);
bool Replace(const AnfNodePtr& old_node, const AnfNodePtr& new_node);
void SetEdge(const AnfNodePtr& node, int index, const AnfNodePtr& value);
void MoveAllCNodeDropGraph(FuncGraphPtr source, FuncGraphPtr target, const ScopePtr& scope);
void SetParameters(const FuncGraphPtr &fg, const std::vector<AnfNodePtr> &parameters);
void MaybeDropFuncGraphs(const FuncGraphSet &func_graphs, bool ignore_users = false);
bool Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node);
void SetEdge(const AnfNodePtr &node, int index, const AnfNodePtr &value);
void MoveAllCNodeDropGraph(FuncGraphPtr source, FuncGraphPtr target, const ScopePtr &scope);
FuncGraphTransaction Transact();
void CommitChanges(const std::vector<Change>& changes);
void CommitChanges(const std::vector<Change> &changes);
bool IsManaged() const { return is_manage_; }
const FuncGraphSet& roots() const { return roots_; }
const FuncGraphSet &roots() const { return roots_; }
const FuncGraphSet& func_graphs() const { return func_graphs_; }
const FuncGraphSet &func_graphs() const { return func_graphs_; }
AnfNodeSet& all_nodes() { return all_nodes_; }
AnfNodeSet &all_nodes() { return all_nodes_; }
NodeUsersMap& node_users() { return node_users_; }
NodeUsersMap &node_users() { return node_users_; }
FuncGraphToAnfNodeMap& nodes() const { return nodes_->nodes_analysis_; }
FuncGraphToAnfNodeMap &nodes() const { return nodes_->nodes_analysis_; }
FuncGraphToAnfNodeCounterMap& valuenodes() const { return valuenodes_->count_nodes_map_; }
FuncGraphToAnfNodeCounterMap &valuenodes() const { return valuenodes_->count_nodes_map_; }
FuncGraphToAnfNodeCounterMap& free_variables_direct() const { return free_variables_direct_->count_nodes_map_; }
FuncGraphToAnfNodeCounterMap &free_variables_direct() const { return free_variables_direct_->count_nodes_map_; }
FuncGraphToAnfNodeCounterMap& func_graph_valuenodes() const { return func_graph_valuenodes_->count_nodes_map_; }
FuncGraphToAnfNodeCounterMap &func_graph_valuenodes() const { return func_graph_valuenodes_->count_nodes_map_; }
FuncGraphToFuncGraphCounterMap& func_graphs_used() const { return func_graphs_used_->count_func_graphs_map_; }
FuncGraphToFuncGraphCounterMap &func_graphs_used() const { return func_graphs_used_->count_func_graphs_map_; }
FuncGraphToFuncGraphCounterMap& func_graph_users() const { return func_graph_users_->count_func_graphs_map_; }
FuncGraphToFuncGraphCounterMap &func_graph_users() const { return func_graph_users_->count_func_graphs_map_; }
FuncGraphToAnfNodeCounterMap& func_graph_user_cnodes() const { return func_graph_user_cnodes_->count_nodes_map_; }
FuncGraphToAnfNodeCounterMap &func_graph_user_cnodes() const { return func_graph_user_cnodes_->count_nodes_map_; }
FuncGraphToFuncGraphCounterMap& func_graph_child_direct() const {
FuncGraphToFuncGraphCounterMap &func_graph_child_direct() const {
return func_graph_child_direct_->count_func_graphs_map_;
}
FuncGraphToFuncGraphCounterMap& func_graph_parents_direct() const {
FuncGraphToFuncGraphCounterMap &func_graph_parents_direct() const {
return func_graph_parents_direct_->count_func_graphs_map_;
}
FuncGraphToFuncGraphCounterMap& func_graph_j_direct() const { return func_graph_j_direct_->count_func_graphs_map_; }
FuncGraphToFuncGraphCounterMap &func_graph_j_direct() const { return func_graph_j_direct_->count_func_graphs_map_; }
FVTotalMap& free_variables_total() const;
FVTotalMap &free_variables_total() const;
FuncGraphSet& func_graph_parents_total(const FuncGraphPtr& fg) const;
FuncGraphSet &func_graph_parents_total(const FuncGraphPtr &fg) const;
FuncGraphSet& scopes(const FuncGraphPtr& fg) const;
FuncGraphSet &scopes(const FuncGraphPtr &fg) const;
FuncGraphPtr parent(const FuncGraphPtr& fg) const;
FuncGraphPtr parent(const FuncGraphPtr &fg) const;
FuncGraphSet& children(const FuncGraphPtr& fg) const;
FuncGraphSet &children(const FuncGraphPtr &fg) const;
FuncGraphSet& func_graphs_used_total(const FuncGraphPtr& fg) const;
FuncGraphSet &func_graphs_used_total(const FuncGraphPtr &fg) const;
bool recursive(const FuncGraphPtr& fg) const;
std::shared_ptr<std::list<FuncGraphPtr>> recursive_graphs(const FuncGraphPtr& fg) const;
bool recursive(const FuncGraphPtr &fg) const;
std::shared_ptr<std::list<FuncGraphPtr>> recursive_graphs(const FuncGraphPtr &fg) const;
bool func_graph_j_total(const FuncGraphPtr& fg) const;
bool func_graph_j_total(const FuncGraphPtr &fg) const;
std::shared_ptr<Signals> signals() const { return signals_; }
IncludeType Limit(const AnfNodePtr& node);
IncludeType Limit(const AnfNodePtr &node);
// Static Analysis
NodeUsersMap node_users_;
@ -610,13 +610,13 @@ class FuncGraphManager : public std::enable_shared_from_this<FuncGraphManager> {
std::shared_ptr<ParentComputer> func_graph_parent_;
private:
void AddIntoManaged(const FuncGraphPtr& fg);
void AddIntoManaged(const FuncGraphPtr &fg);
void ProcessEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction);
void ProcessInputs(const AnfNodePtr& node, EdgeProcessDirection direction);
void AcquireNodes(const std::vector<AnfNodePtr>& nodes);
FuncGraphSetPtr MaybeDropNodes(const std::vector<AnfNodePtr>& nodes);
void ParseChanges(const std::vector<Change>& changes, EdgeTupleCounter* add_edges, EdgeTupleCounter* rm_edges,
Counter<AnfNodePtr>* adds, Counter<AnfNodePtr>* rms);
void ProcessInputs(const AnfNodePtr &node, EdgeProcessDirection direction);
void AcquireNodes(const std::vector<AnfNodePtr> &nodes);
FuncGraphSetPtr MaybeDropNodes(const std::vector<AnfNodePtr> &nodes);
void ParseChanges(const std::vector<Change> &changes, EdgeTupleCounter *add_edges, EdgeTupleCounter *rm_edges,
Counter<AnfNodePtr> *adds, Counter<AnfNodePtr> *rms);
FuncGraphSet roots_; // managed roots
FuncGraphSet func_graphs_; // managed func graphs
@ -637,7 +637,7 @@ class FuncGraphManager : public std::enable_shared_from_this<FuncGraphManager> {
class FuncGraphTransaction {
public:
explicit FuncGraphTransaction(FuncGraphManager* manager) : manager_(manager), changes_() {
explicit FuncGraphTransaction(FuncGraphManager *manager) : manager_(manager), changes_() {
MS_EXCEPTION_IF_NULL(manager_);
if (!manager_->IsManaged()) {
MS_LOG(DEBUG) << "The manager is not managed yet";
@ -648,19 +648,19 @@ class FuncGraphTransaction {
~FuncGraphTransaction() { manager_ = nullptr; }
// set parameters of a func graph
void SetParameters(FuncGraphPtr fg, const std::vector<AnfNodePtr>& params);
void SetParameters(FuncGraphPtr fg, const std::vector<AnfNodePtr> &params);
// replace old_node with new_node
bool Replace(const AnfNodePtr& old_node, const AnfNodePtr& new_node);
bool Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node);
// set esge, i.e., declare setting node.inputs[key] to value.
void SetEdge(const AnfNodePtr& src_node, int k, const AnfNodePtr& v);
void SetEdge(const AnfNodePtr &src_node, int k, const AnfNodePtr &v);
// commit all changes
void Commit();
private:
FuncGraphManager* manager_;
FuncGraphManager *manager_;
std::vector<Change> changes_;
};
@ -668,9 +668,9 @@ class FuncGraphTransaction {
struct ArgsOfSetParams {
FuncGraphPtr func_graph;
std::vector<AnfNodePtr> params;
bool operator==(const ArgsOfSetParams& other) const { return &other == this; }
bool operator==(const ArgsOfSetParams &other) const { return &other == this; }
friend std::ostream& operator<<(std::ostream& os, const ArgsOfSetParams&) {
friend std::ostream &operator<<(std::ostream &os, const ArgsOfSetParams &) {
os << "[ArgsOfSetParams]";
return os;
}
@ -681,9 +681,9 @@ struct ArgsOfSetEdge {
CNodePtr root_node;
AnfNodePtr new_node;
size_t index;
bool operator==(const ArgsOfSetEdge& other) const { return &other == this; }
bool operator==(const ArgsOfSetEdge &other) const { return &other == this; }
friend std::ostream& operator<<(std::ostream& os, const ArgsOfSetEdge& other) {
friend std::ostream &operator<<(std::ostream &os, const ArgsOfSetEdge &other) {
os << "[ArgsOfSetEdge]";
return os;
}
@ -693,7 +693,7 @@ struct Change {
enum OpName { kTxSetParams, kTxSetEdge };
OpName op;
Any args;
Change(OpName name, const Any& para) : op(name), args(para) {}
Change(OpName name, const Any &para) : op(name), args(para) {}
};
} // namespace mindspore

View File

@ -42,25 +42,25 @@ namespace mindspore {
// generate a graph corresponding to these types.
class MetaFuncGraph : public FuncGraphBase {
public:
explicit MetaFuncGraph(const std::string& name) : name_(name) { cache_.clear(); }
explicit MetaFuncGraph(const std::string &name) : name_(name) { cache_.clear(); }
~MetaFuncGraph() override = default;
MS_DECLARE_PARENT(MetaFuncGraph, FuncGraphBase);
abstract::AbstractBasePtr MakeAbstractClosure(const AnfNodePtr& anf_node);
abstract::AbstractBasePtr MakeAbstractClosure(const AnfNodePtr &anf_node);
// Return normalized versions of the arguments.
// By default, this returns args unchanged.
virtual abstract::AbstractBasePtrList NormalizeArgs(const abstract::AbstractBasePtrList& args_spec_list) const {
virtual abstract::AbstractBasePtrList NormalizeArgs(const abstract::AbstractBasePtrList &args_spec_list) const {
return args_spec_list;
}
const std::vector<Signature>& signatures() const { return signatures_; }
void set_signatures(const std::vector<Signature>& signatures) { signatures_ = signatures; }
const std::vector<Signature> &signatures() const { return signatures_; }
void set_signatures(const std::vector<Signature> &signatures) { signatures_ = signatures; }
// Generate a Graph for the given abstract arguments.
virtual FuncGraphPtr GenerateFuncGraph(const abstract::AbstractBasePtrList& args_spec_list) {
virtual FuncGraphPtr GenerateFuncGraph(const abstract::AbstractBasePtrList &args_spec_list) {
TypePtrList types;
(void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(types),
[](const AbstractBasePtr& arg) -> TypePtr {
[](const AbstractBasePtr &arg) -> TypePtr {
MS_EXCEPTION_IF_NULL(arg);
return arg->BuildType();
});
@ -81,7 +81,7 @@ class MetaFuncGraph : public FuncGraphBase {
}
// Generate a Graph for this type signature.
virtual FuncGraphPtr GenerateFromTypes(const TypePtrList&) {
virtual FuncGraphPtr GenerateFromTypes(const TypePtrList &) {
MS_LOG(EXCEPTION) << "Undefine the method of generating graph from types.";
}
@ -89,8 +89,8 @@ class MetaFuncGraph : public FuncGraphBase {
std::string ToString() const override { return name_; }
std::size_t hash() const override { return tid(); }
virtual bool operator==(const MetaFuncGraph& other) const { return &other == this; }
bool operator==(const Value& other) const override {
virtual bool operator==(const MetaFuncGraph &other) const { return &other == this; }
bool operator==(const Value &other) const override {
if (other.isa<MetaFuncGraph>()) {
return &other == this;
} else {

View File

@ -31,7 +31,7 @@ namespace mindspore {
namespace tensor {
void DataBuf2Contiguous(const py::array& src, py::array* const dest) {
void DataBuf2Contiguous(const py::array &src, py::array *const dest) {
if (dest == nullptr) {
MS_LOG(EXCEPTION) << "Failed to copy data to a contiguous buffer as dest is nullptr!";
}
@ -55,9 +55,9 @@ void DataBuf2Contiguous(const py::array& src, py::array* const dest) {
// MetaTensor has default type_id_ which is TypeId::kTypeUnknown.
MetaTensor::MetaTensor() : data_type_(TypeId::kTypeUnknown) {}
MetaTensor::MetaTensor(const TypeId data_type, const std::vector<int>& shape) : data_type_(data_type), shape_(shape) {}
MetaTensor::MetaTensor(const TypeId data_type, const std::vector<int> &shape) : data_type_(data_type), shape_(shape) {}
MetaTensor::MetaTensor(const TypePtr& type_ptr, const py::tuple& shape) {
MetaTensor::MetaTensor(const TypePtr &type_ptr, const py::tuple &shape) {
TypeId data_type = TypeId::kTypeUnknown;
if (type_ptr != nullptr) {
data_type = type_ptr->type_id();
@ -69,10 +69,10 @@ MetaTensor::MetaTensor(const TypePtr& type_ptr, const py::tuple& shape) {
}
}
MetaTensor::MetaTensor(const MetaTensor& meta_tensor)
MetaTensor::MetaTensor(const MetaTensor &meta_tensor)
: Value(meta_tensor), data_type_(meta_tensor.data_type()), shape_(meta_tensor.shape()) {}
MetaTensor& MetaTensor::operator=(const MetaTensor& meta_tensor) {
MetaTensor &MetaTensor::operator=(const MetaTensor &meta_tensor) {
if (&meta_tensor == this) {
return *this;
}
@ -84,7 +84,7 @@ MetaTensor& MetaTensor::operator=(const MetaTensor& meta_tensor) {
return *this;
}
bool MetaTensor::operator==(const MetaTensor& meta_tensor) const {
bool MetaTensor::operator==(const MetaTensor &meta_tensor) const {
return data_type_ == meta_tensor.data_type() && shape_ == meta_tensor.shape();
}
@ -117,7 +117,7 @@ TypePtr MetaTensor::SetDtype(const TypePtr type_ptr) {
return type_ptr;
}
void MetaTensor::SetDeviceInfo(const std::string& format, const TypePtr& data_type) {
void MetaTensor::SetDeviceInfo(const std::string &format, const TypePtr &data_type) {
DeviceInfo info(format, data_type);
set_device_info(info);
}
@ -138,7 +138,7 @@ std::string MetaTensor::DumpText() const {
return oss.str();
}
Tensor::Tensor(const TypePtr& type_ptr, const py::tuple& shape) {
Tensor::Tensor(const TypePtr &type_ptr, const py::tuple &shape) {
TypeId data_type = TypeId::kTypeUnknown;
if (type_ptr != nullptr) {
data_type = type_ptr->type_id();
@ -151,24 +151,24 @@ Tensor::Tensor(const TypePtr& type_ptr, const py::tuple& shape) {
init(data_type_, shape_, &data_);
}
Tensor::Tensor(TypeId data_type, const std::vector<int>& shape) { init(data_type, shape, &data_); }
Tensor::Tensor(TypeId data_type, const std::vector<int> &shape) { init(data_type, shape, &data_); }
Tensor::Tensor(const py::array& input, const TypePtr& data_type) { init(input, data_type); }
Tensor::Tensor(const py::array &input, const TypePtr &data_type) { init(input, data_type); }
Tensor::Tensor(const py::list& input, const TypePtr& data_type) { init(py::array(input), data_type); }
Tensor::Tensor(const py::list &input, const TypePtr &data_type) { init(py::array(input), data_type); }
Tensor::Tensor(const py::tuple& input, const TypePtr& data_type) { init(py::array(input), data_type); }
Tensor::Tensor(const py::tuple &input, const TypePtr &data_type) { init(py::array(input), data_type); }
Tensor::Tensor(const py::float_& input, const TypePtr& data_type) { init(py::array(input), data_type); }
Tensor::Tensor(const py::float_ &input, const TypePtr &data_type) { init(py::array(input), data_type); }
Tensor::Tensor(const py::int_& input, const TypePtr& data_type) { init(py::array(input), data_type); }
Tensor::Tensor(const py::int_ &input, const TypePtr &data_type) { init(py::array(input), data_type); }
Tensor::Tensor(const Tensor& tensor, const TypePtr& data_type)
Tensor::Tensor(const Tensor &tensor, const TypePtr &data_type)
: MetaTensor(tensor), device_address_(tensor.device_address()) {
init(tensor.data_, data_type);
}
Tensor& Tensor::operator=(const Tensor& tensor) {
Tensor &Tensor::operator=(const Tensor &tensor) {
if (this != &tensor) {
MetaTensor::operator=(tensor);
dirty_ = tensor.is_dirty();
@ -178,11 +178,11 @@ Tensor& Tensor::operator=(const Tensor& tensor) {
return *this;
}
bool Tensor::operator==(const Tensor& tensor) const {
bool Tensor::operator==(const Tensor &tensor) const {
return (MetaTensor::operator==(tensor) && data_ == tensor.data_);
}
bool Tensor::ValueEqualPy(const py::object& other) const {
bool Tensor::ValueEqualPy(const py::object &other) const {
if (!py::isinstance<Tensor>(other)) {
MS_LOG(WARNING) << "compare other not a tensor";
return false;
@ -190,7 +190,7 @@ bool Tensor::ValueEqualPy(const py::object& other) const {
return ValueEqual(py::cast<Tensor>(other));
}
bool Tensor::ValueEqual(const Tensor& other) const {
bool Tensor::ValueEqual(const Tensor &other) const {
auto equal = [&other, this]() -> bool {
auto np = py::module::import("numpy");
auto equal = np.attr("equal")(data_, other.data_);
@ -218,7 +218,7 @@ int Tensor::data_type_c() const { return static_cast<int>(data_type_); }
std::vector<int> Tensor::shape_c(void) const { return shape(); }
void* Tensor::data_c(bool writable) {
void *Tensor::data_c(bool writable) {
// operand of bit operation should be unsigned int.
unsigned int flags = ((unsigned int)data_.flags()) & pybind11::detail::npy_api::NPY_ARRAY_C_CONTIGUOUS_;
bool is_c_contiguous = (flags != 0) ? true : false;
@ -231,7 +231,7 @@ void* Tensor::data_c(bool writable) {
return data_.request(writable).ptr;
}
TypeId Tensor::GetDataType(const py::buffer_info& buf) const {
TypeId Tensor::GetDataType(const py::buffer_info &buf) const {
TypeId data_type = TypeId::kTypeUnknown;
if (buf.format.compare("e") == 0) {
data_type = TypeId::kNumberTypeFloat16;
@ -263,7 +263,7 @@ TypeId Tensor::GetDataType(const py::buffer_info& buf) const {
return data_type;
}
void Tensor::init(const py::array& input, const TypePtr& type_ptr) {
void Tensor::init(const py::array &input, const TypePtr &type_ptr) {
TypeId data_type = TypeId::kTypeUnknown;
if (type_ptr != nullptr) {
data_type = type_ptr->type_id();
@ -271,7 +271,7 @@ void Tensor::init(const py::array& input, const TypePtr& type_ptr) {
init(input, data_type);
}
void Tensor::init(const py::array& input, const TypeId& data_type) {
void Tensor::init(const py::array &input, const TypeId &data_type) {
py::buffer_info buf = input.request();
data_type_ = GetDataType(buf);
@ -301,7 +301,7 @@ void Tensor::init(const py::array& input, const TypeId& data_type) {
}
}
void Tensor::init(TypeId data_type, const std::vector<int>& shape, py::array* const data) {
void Tensor::init(TypeId data_type, const std::vector<int> &shape, py::array *const data) {
data_type_ = data_type;
shape_ = shape;
switch (data_type) {
@ -368,7 +368,7 @@ TypeId Tensor::set_data_type(const TypeId data_type) {
return data_type_;
}
bool Tensor::convert_data(const py::array& in, const TypeId in_data_type, py::array* const out,
bool Tensor::convert_data(const py::array &in, const TypeId in_data_type, py::array *const out,
const TypeId out_data_type) {
if (out == nullptr) {
return false;
@ -458,7 +458,7 @@ py::array Tensor::data_sync() {
return data_;
}
REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module* m) {
REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) {
// dtype should define before Tensor, because Tensor init depend dtype
(void)py::class_<Tensor, std::shared_ptr<Tensor>>(*m, "Tensor")
.def(py::init<TypePtr, py::tuple>(), py::arg("dtype"), py::arg("shape"))
@ -541,11 +541,11 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module* m) {
.def("__repr__", &Tensor::ToStringRepr)
.def("__eq__", &Tensor::ValueEqualPy)
.def(py::pickle(
[](const Tensor& t) { // __getstate__
[](const Tensor &t) { // __getstate__
/* Return a tuple that fully encodes the state of the object */
return py::make_tuple(t.data());
},
[](const py::tuple& t) { // __setstate__
[](const py::tuple &t) { // __setstate__
if (t.size() != 1) {
throw std::runtime_error("Invalid state!");
}

View File

@ -131,16 +131,16 @@ class MetaTensor : public Value {
// information of a Tensor. The following codes will create a 2x3 float
// param data_type The data type of the tensor.
// param shape The shape of the tensor.
MetaTensor(const TypeId data_type, const std::vector<int>& shape);
MetaTensor(const TypeId data_type, const std::vector<int> &shape);
MetaTensor(const TypePtr& type_ptr, const py::tuple& shape);
MetaTensor(const TypePtr &type_ptr, const py::tuple &shape);
// brief Constructs a MetaTensor object from an existing MetaTensor instance.
//
// The constructed MetaTensor object will have the same data type and shape as the
// meta_tensor.
//
// param meta_tensor An existing MetaTensor object.
MetaTensor(const MetaTensor& meta_tensor);
MetaTensor(const MetaTensor &meta_tensor);
~MetaTensor() override = default;
MS_DECLARE_PARENT(MetaTensor, Value)
@ -149,7 +149,7 @@ class MetaTensor : public Value {
// The constructed MetaTensor object has the same type and shape with meta_tensor.
//
// param meta_tensor An existing MetaTensor object.
virtual MetaTensor& operator=(const MetaTensor& meta_tensor);
virtual MetaTensor &operator=(const MetaTensor &meta_tensor);
// brief Compares two MetaTensor objects.
//
@ -157,7 +157,7 @@ class MetaTensor : public Value {
//
// param meta_tensor The MetaTensor object to be compared.
// return true: If having same type and shape, return true, or return false.
virtual bool operator==(const MetaTensor& meta_tensor) const;
virtual bool operator==(const MetaTensor &meta_tensor) const;
// brief Returns the data type of the tensor in its MetaTensor.
//
@ -193,7 +193,7 @@ class MetaTensor : public Value {
//
// param shape The shape of the tensor.
// return The shape's size.
size_t set_shape(const std::vector<int>& shape) {
size_t set_shape(const std::vector<int> &shape) {
this->shape_ = shape;
return shape_.size();
}
@ -202,9 +202,9 @@ class MetaTensor : public Value {
DeviceInfo device_info() const { return device_info_; }
// Set tensor's device info.
void set_device_info(const DeviceInfo& device_info) { device_info_ = device_info; }
void set_device_info(const DeviceInfo &device_info) { device_info_ = device_info; }
void SetDeviceInfo(const std::string& format, const TypePtr& data_type);
void SetDeviceInfo(const std::string &format, const TypePtr &data_type);
// Get the size of a given dimension by its index number.
int DimensionSize(size_t index) const;
@ -222,9 +222,9 @@ class MetaTensor : public Value {
}
return hash_value;
}
bool operator==(const Value& other) const override {
bool operator==(const Value &other) const override {
if (other.isa<MetaTensor>()) {
auto other_ = static_cast<const MetaTensor&>(other);
auto other_ = static_cast<const MetaTensor &>(other);
return *this == other_;
} else {
return false;
@ -262,49 +262,49 @@ class Tensor : public MetaTensor {
//
// param type_ptr [TypePty] Data type of the tensor.
// param py_shape [py::tuple] The shape represented by py::tuple of the tensor.
Tensor(const TypePtr& type_ptr, const py::tuple& shape);
Tensor(const TypePtr &type_ptr, const py::tuple &shape);
// brief Constructor for C++.
//
// param data_type [TypeId] Data type of the tensor.
// param shape The shape represented by std::vector<int> of the tensor.
Tensor(TypeId data_type, const std::vector<int>& shape);
Tensor(TypeId data_type, const std::vector<int> &shape);
// brief Constructor for Python.
//
// param input [py::array] Data value of the tensor.
// param data_type [TypeId] Data type of the tensor.
explicit Tensor(const py::array& input, const TypePtr& data_type = nullptr);
explicit Tensor(const py::array &input, const TypePtr &data_type = nullptr);
// brief Constructor
//
// param input [py::list] the data for tensor
// param data_type [TypeId] data type
explicit Tensor(const py::list& input, const TypePtr& data_type = nullptr);
explicit Tensor(const py::list &input, const TypePtr &data_type = nullptr);
// brief Constructor
//
// param input [py::tuple] the data for tensor
// param data_type [TypeId] data type
explicit Tensor(const py::tuple& input, const TypePtr& data_type = nullptr);
explicit Tensor(const py::tuple &input, const TypePtr &data_type = nullptr);
// brief Constructor
//
// param input [py::float_] the data for tensor
// param data_type [TypeId] data type
explicit Tensor(const py::float_& input, const TypePtr& data_type = nullptr);
explicit Tensor(const py::float_ &input, const TypePtr &data_type = nullptr);
// brief Constructor
//
// param input [py::int_] the data for tensor
// param data_type [TypeId] data type
explicit Tensor(const py::int_& input, const TypePtr& data_type = nullptr);
explicit Tensor(const py::int_ &input, const TypePtr &data_type = nullptr);
// brief Constructor
//
// param input [Tensor] the data for tensor
// param data_type [TypeId] data type
Tensor(const Tensor& tensor, const TypePtr& data_type = nullptr);
Tensor(const Tensor &tensor, const TypePtr &data_type = nullptr);
~Tensor() override = default;
@ -315,7 +315,7 @@ class Tensor : public MetaTensor {
// The constructed Tensor object has the same type and shape with tensor.
//
// param tensor An existing Tensor object.
Tensor& operator=(const Tensor& tensor);
Tensor &operator=(const Tensor &tensor);
// brief Compares two Tensor objects.
//
@ -324,17 +324,17 @@ class Tensor : public MetaTensor {
//
// param tensor The Tensor object to be compared.
// return true: If having same type, shape and data, return true, or return false.
bool operator==(const Tensor& tensor) const;
bool operator==(const Tensor &tensor) const;
// It is different from 'operator==' which just compare shape/type/address, it do real value comparison.
bool ValueEqual(const Tensor& other) const;
bool ValueEqual(const Tensor &other) const;
// It is different from 'operator==' which just compare shape/type/address, it do real value comparison.
bool ValueEqualPy(const py::object& other) const;
bool ValueEqualPy(const py::object &other) const;
bool operator==(const Value& other) const override {
bool operator==(const Value &other) const override {
if (other.isa<Tensor>()) {
auto other_ = static_cast<const Tensor&>(other);
auto other_ = static_cast<const Tensor &>(other);
return *this == other_;
} else {
return false;
@ -375,13 +375,13 @@ class Tensor : public MetaTensor {
//
// param writable true if writable, false if read only
// return The pointer to the object
void* data_c(bool writable = false);
void *data_c(bool writable = false);
// brief Get data type from tensor data.
//
// param buf The buffer info of the py::array data.
// return The [TypeId] of the tensor data.
TypeId GetDataType(const py::buffer_info& buf) const;
TypeId GetDataType(const py::buffer_info &buf) const;
// brief Sets the data type of a tensor.
//
@ -401,23 +401,23 @@ class Tensor : public MetaTensor {
// param input [py::array] the data for tensor
// param data_type [TypeId] data type
// return true if succeed, false if failed.
void init(const py::array& input, const TypeId& data_type);
void init(const py::array& input, const TypePtr& type_ptr);
void init(const py::array &input, const TypeId &data_type);
void init(const py::array &input, const TypePtr &type_ptr);
// brief init tensor attribute
//
// param data_type [TypeId] Data type of the tensor.
// param shape [py::array] The shape of the tensor.
// return true if succeed, false if failed.
void init(TypeId data_type, const std::vector<int>& shape, py::array* data);
void init(TypeId data_type, const std::vector<int> &shape, py::array *data);
bool convert_data(const py::array& in, const TypeId in_data_type, py::array* out, const TypeId out_data_type);
bool convert_data(const py::array &in, const TypeId in_data_type, py::array *out, const TypeId out_data_type);
public:
bool is_dirty() const { return dirty_; }
void set_dirty(const bool dirty) { dirty_ = dirty; }
DeviceAddressPtr device_address() const { return device_address_; }
void set_device_address(const DeviceAddressPtr& device_address) { device_address_ = device_address; }
void set_device_address(const DeviceAddressPtr &device_address) { device_address_ = device_address; }
py::array data_sync();
private:

View File

@ -18,9 +18,9 @@
#include "pipeline/static_analysis/abstract_value.h"
namespace mindspore {
bool Named::operator==(const Value& other) const {
bool Named::operator==(const Value &other) const {
if (other.isa<Named>()) {
auto other_named = static_cast<const Named&>(other);
auto other_named = static_cast<const Named &>(other);
return *this == other_named;
} else {
return false;

View File

@ -27,18 +27,18 @@
namespace mindspore {
class Named : public Value {
public:
explicit Named(const std::string& name) : name_(name) { hash_id_ = std::hash<std::string>{}(name); }
Named(const Named& other) : Value(other) {
explicit Named(const std::string &name) : name_(name) { hash_id_ = std::hash<std::string>{}(name); }
Named(const Named &other) : Value(other) {
this->name_ = other.name_;
hash_id_ = std::hash<std::string>{}(other.name_);
}
~Named() override = default;
MS_DECLARE_PARENT(Named, Value);
const std::string& name() const { return name_; }
virtual bool operator==(const Named& other) const { return name_ == other.name(); }
bool operator==(const Value& other) const override;
Named& operator=(const Named& other) {
const std::string &name() const { return name_; }
virtual bool operator==(const Named &other) const { return name_ == other.name(); }
bool operator==(const Value &other) const override;
Named &operator=(const Named &other) {
if (&other != this) {
this->type_ = other.type_;
this->name_ = other.name_;
@ -50,7 +50,7 @@ class Named : public Value {
std::size_t Hash() const { return hash_id_; }
std::size_t hash() const override { return hash_id_; }
friend std::ostream& operator<<(std::ostream& os, const Named& nmd) {
friend std::ostream &operator<<(std::ostream &os, const Named &nmd) {
os << nmd.name();
return os;
}

View File

@ -31,7 +31,7 @@
namespace mindspore {
using mindspore::abstract::AbstractFunction;
abstract::AbstractBasePtr Primitive::ToPrimAbstract(const AnfNodePtr& anf_node) {
abstract::AbstractBasePtr Primitive::ToPrimAbstract(const AnfNodePtr &anf_node) {
auto prim_func = std::make_shared<abstract::PrimitiveAbstractClosure>(shared_from_base<Primitive>(), anf_node);
return prim_func;
}
@ -63,23 +63,23 @@ py::function Primitive::GetComputeFunction() {
return fn;
}
bool Primitive::operator==(const Value& other) const {
bool Primitive::operator==(const Value &other) const {
if (other.isa<Primitive>()) {
auto other_prim = static_cast<const Primitive&>(other);
auto other_prim = static_cast<const Primitive &>(other);
return *this == other_prim;
} else {
return false;
}
}
bool Primitive::operator==(const Primitive& other) const {
bool Primitive::operator==(const Primitive &other) const {
if (name() != other.name()) {
return false;
}
if (attrs_.size() != other.attrs_.size()) {
return false;
}
auto all = std::all_of(attrs_.begin(), attrs_.end(), [&other](const std::pair<std::string, ValuePtr>& item) -> bool {
auto all = std::all_of(attrs_.begin(), attrs_.end(), [&other](const std::pair<std::string, ValuePtr> &item) -> bool {
if (item.second == nullptr) {
return false;
}
@ -95,7 +95,7 @@ bool Primitive::operator==(const Primitive& other) const {
void Primitive::set_signatures(
std::vector<std::tuple<std::string, SignatureEnumRW, SignatureEnumKind, py::object, SignatureEnumDType>> signatures) {
signatures_.clear();
for (auto& signature : signatures) {
for (auto &signature : signatures) {
std::string name;
SignatureEnumRW rw;
SignatureEnumKind kind;
@ -114,7 +114,7 @@ std::string Primitive::GetAttrsText() const {
std::ostringstream oss;
oss << "[";
bool is_first = true;
for (auto& attr : attrs_) {
for (auto &attr : attrs_) {
if (is_first) {
is_first = false;
} else {
@ -128,7 +128,7 @@ std::string Primitive::GetAttrsText() const {
}
py::function PrimitivePy::GetBpropFunction() {
static const char* const get_bprop_func_name = "get_bprop";
static const char *const get_bprop_func_name = "get_bprop";
if (py::hasattr(python_obj_, get_bprop_func_name)) {
py::function fn = python_obj_.attr(get_bprop_func_name)().cast<py::function>();
return fn;
@ -142,7 +142,7 @@ py::function PrimitivePy::GetBpropFunction() {
}
py::function PrimitivePy::GetComputeFunction() {
static const char* const compute_func_name = "vm_impl";
static const char *const compute_func_name = "vm_impl";
if (py::hasattr(python_obj_, compute_func_name)) {
MS_LOG(INFO) << "" << name() << " compute_func_name";
@ -163,7 +163,7 @@ py::function PrimitivePy::GetComputeFunction() {
return vm_fn;
}
void PrimitivePy::AddPyAttr(const py::str& name, const py::object& obj) {
void PrimitivePy::AddPyAttr(const py::str &name, const py::object &obj) {
std::string attr_name = name;
ValuePtr converted_ret = nullptr;
if (py::isinstance<py::module>(obj)) {
@ -178,13 +178,13 @@ void PrimitivePy::AddPyAttr(const py::str& name, const py::object& obj) {
py::dict PrimitivePy::GetAttrDict() {
py::dict attr_dict;
for (auto& attr : attrs_) {
for (auto &attr : attrs_) {
attr_dict[py::str(attr.first)] = ValuePtrToPyData(attr.second);
}
return attr_dict;
}
REGISTER_PYBIND_DEFINE(Primitive_, ([](const py::module* m) {
REGISTER_PYBIND_DEFINE(Primitive_, ([](const py::module *m) {
(void)py::enum_<PrimType>(*m, "prim_type", py::arithmetic())
.value("unknown", PrimType::kPrimTypeUnknown)
.value("builtin", PrimType::kPrimTypeBuiltIn)
@ -192,7 +192,7 @@ REGISTER_PYBIND_DEFINE(Primitive_, ([](const py::module* m) {
.value("user_custom", PrimType::kPrimTypeUserCustom);
(void)py::class_<PrimitivePy, std::shared_ptr<PrimitivePy>>(*m, "Primitive_")
.def_readonly(PYTHON_PRIMITIVE_FLAG, &PrimitivePy::parse_info_)
.def(py::init<py::str&, py::object>())
.def(py::init<py::str &, py::object>())
.def("add_attr", &PrimitivePy::AddPyAttr, "add primitive attr")
.def("get_attr_dict", &PrimitivePy::GetAttrDict, "get primitive attr")
.def("set_prim_type", &PrimitivePy::set_prim_type, "Set primitive type.")

View File

@ -48,25 +48,25 @@ enum PrimType {
class Primitive : public Named {
public:
explicit Primitive(const std::string& name, const PrimType prim_type = kPrimTypeBuiltIn)
explicit Primitive(const std::string &name, const PrimType prim_type = kPrimTypeBuiltIn)
: Named(name), signatures_(), prim_type_(prim_type) {}
Primitive(const Primitive& prim)
Primitive(const Primitive &prim)
: Named(prim), attrs_(prim.attrs_), signatures_(prim.signatures_), prim_type_(prim.prim_type_) {}
MS_DECLARE_PARENT(Primitive, Named);
abstract::AbstractBasePtr ToPrimAbstract(const AnfNodePtr& anf_node);
abstract::AbstractBasePtr ToPrimAbstract(const AnfNodePtr &anf_node);
std::string ToString() const override { return name(); }
virtual py::function GetBpropFunction();
virtual py::function GetComputeFunction();
Primitive& AddAttr(const std::string& name, const ValuePtr& attr) {
Primitive &AddAttr(const std::string &name, const ValuePtr &attr) {
attrs_[name] = attr;
return *this;
}
Primitive& SetAttrs(const std::unordered_map<std::string, ValuePtr>& attrs) {
for (auto& attr : attrs) {
Primitive &SetAttrs(const std::unordered_map<std::string, ValuePtr> &attrs) {
for (auto &attr : attrs) {
attrs_[attr.first] = attr.second;
}
return *this;
@ -76,21 +76,21 @@ class Primitive : public Named {
std::vector<std::tuple<std::string, SignatureEnumRW, SignatureEnumKind, py::object, SignatureEnumDType>>
signatures);
const std::vector<Signature>& signatures() const { return signatures_; }
const std::vector<Signature> &signatures() const { return signatures_; }
void set_attr(const std::string& attrName, const ValuePtr& attr) { attrs_[attrName] = attr; }
void EraseAttr(const std::string& attrName) { (void)attrs_.erase(attrName); }
void set_attr(const std::string &attrName, const ValuePtr &attr) { attrs_[attrName] = attr; }
void EraseAttr(const std::string &attrName) { (void)attrs_.erase(attrName); }
ValuePtr GetAttr(const std::string& attrName) const {
ValuePtr GetAttr(const std::string &attrName) const {
auto iter = attrs_.find(attrName);
return iter == attrs_.cend() ? nullptr : iter->second;
}
const std::unordered_map<std::string, ValuePtr>& attrs() const { return attrs_; }
const std::unordered_map<std::string, ValuePtr> &attrs() const { return attrs_; }
// if Primitive has any attribute, for Primitives like scalar_add, return, etc, don't have any attribute.
bool HasAttr() const { return !attrs_.empty(); }
bool HasAttr(const std::string& attrName) const {
bool HasAttr(const std::string &attrName) const {
auto iter = attrs_.find(attrName);
return !(iter == attrs_.cend());
}
@ -103,8 +103,8 @@ class Primitive : public Named {
PrimType prim_type() const { return prim_type_; }
std::string instance_name() const { return instance_name_; }
std::string GetAttrsText() const;
bool operator==(const Value& other) const override;
bool operator==(const Primitive& other) const;
bool operator==(const Value &other) const override;
bool operator==(const Primitive &other) const;
~Primitive() override = default;
protected:
@ -118,18 +118,18 @@ class Primitive : public Named {
class PrimitivePy : public Primitive {
public:
PrimitivePy(const py::str& name, const py::object& python_obj) : Primitive(name), python_obj_(python_obj) {}
PrimitivePy(const py::str &name, const py::object &python_obj) : Primitive(name), python_obj_(python_obj) {}
~PrimitivePy() override = default;
MS_DECLARE_PARENT(PrimitivePy, Primitive);
py::function GetBpropFunction() override;
py::function GetComputeFunction() override;
void AddPyAttr(const py::str& name, const py::object& obj);
void AddPyAttr(const py::str &name, const py::object &obj);
py::dict GetAttrDict();
const bool parse_info_ = true;
const py::object& GetPyObj() const { return python_obj_; }
const py::object &GetPyObj() const { return python_obj_; }
bool is_tuple_input_ = false;
private:
@ -138,13 +138,13 @@ class PrimitivePy : public Primitive {
using PrimitivePyPtr = std::shared_ptr<PrimitivePy>;
inline std::ostream& operator<<(std::ostream& os, const PrimitivePtr& p) {
inline std::ostream &operator<<(std::ostream &os, const PrimitivePtr &p) {
os << *p;
return os;
}
struct PrimitiveEqual {
bool operator()(PrimitivePtr const& t1, PrimitivePtr const& t2) const {
bool operator()(PrimitivePtr const &t1, PrimitivePtr const &t2) const {
MS_EXCEPTION_IF_NULL(t1);
MS_EXCEPTION_IF_NULL(t2);
return t1->name() == t2->name();
@ -152,7 +152,7 @@ struct PrimitiveEqual {
};
struct PrimitiveHasher {
std::size_t operator()(PrimitivePtr const& prim) const {
std::size_t operator()(PrimitivePtr const &prim) const {
std::size_t hash = std::hash<std::string>()(prim->name());
return hash;
}

View File

@ -55,8 +55,8 @@ class BoolImm : public Scalar {
bool value() const { return v_; }
bool IsZero() override { return v_ == false; }
bool IsOne() override { return v_ == true; }
bool operator==(const Value& other) const override;
bool operator==(const BoolImm& other) const;
bool operator==(const Value &other) const override;
bool operator==(const BoolImm &other) const;
std::string ToString() const override {
if (v_) {
return "true";
@ -80,7 +80,7 @@ IMM_TRAITS(BoolImmPtr, bool)
class IntergerImm : public Scalar {
public:
IntergerImm() = default;
explicit IntergerImm(const TypePtr& t) : Scalar(t) {}
explicit IntergerImm(const TypePtr &t) : Scalar(t) {}
~IntergerImm() override = default;
MS_DECLARE_PARENT(IntergerImm, Scalar)
};
@ -95,8 +95,8 @@ class Int8Imm : public IntergerImm {
bool IsZero() override { return v_ == 0; }
bool IsOne() override { return v_ == 1; }
int8_t value() const { return v_; }
bool operator==(const Value& other) const override;
bool operator==(const Int8Imm& other) const;
bool operator==(const Value &other) const override;
bool operator==(const Int8Imm &other) const;
std::string ToString() const override { return std::to_string(v_); }
std::string DumpText() const override {
@ -121,8 +121,8 @@ class Int16Imm : public IntergerImm {
bool IsZero() override { return v_ == 0; }
bool IsOne() override { return v_ == 1; }
int16_t value() const { return v_; }
bool operator==(const Value& other) const override;
bool operator==(const Int16Imm& other) const;
bool operator==(const Value &other) const override;
bool operator==(const Int16Imm &other) const;
std::string ToString() const override { return std::to_string(v_); }
std::string DumpText() const override {
@ -147,8 +147,8 @@ class Int32Imm : public IntergerImm {
bool IsZero() override { return v_ == 0; }
bool IsOne() override { return v_ == 1; }
int32_t value() const { return v_; }
bool operator==(const Value& other) const override;
bool operator==(const Int32Imm& other) const;
bool operator==(const Value &other) const override;
bool operator==(const Int32Imm &other) const;
std::string ToString() const override { return std::to_string(v_); }
std::string DumpText() const override {
@ -173,8 +173,8 @@ class Int64Imm : public IntergerImm {
bool IsZero() override { return v_ == 0; }
bool IsOne() override { return v_ == 1; }
int64_t value() const { return v_; }
bool operator==(const Value& other) const override;
bool operator==(const Int64Imm& other) const;
bool operator==(const Value &other) const override;
bool operator==(const Int64Imm &other) const;
std::string ToString() const override { return std::to_string(v_); }
std::string DumpText() const override {
@ -199,8 +199,8 @@ class UInt8Imm : public IntergerImm {
bool IsZero() override { return v_ == 0; }
bool IsOne() override { return v_ == 1; }
uint8_t value() const { return v_; }
bool operator==(const Value& other) const override;
bool operator==(const UInt8Imm& other) const;
bool operator==(const Value &other) const override;
bool operator==(const UInt8Imm &other) const;
std::string ToString() const override { return std::to_string(v_); }
std::string DumpText() const override {
@ -225,8 +225,8 @@ class UInt16Imm : public IntergerImm {
bool IsZero() override { return v_ == 0; }
bool IsOne() override { return v_ == 1; }
uint16_t value() const { return v_; }
bool operator==(const Value& other) const override;
bool operator==(const UInt16Imm& other) const;
bool operator==(const Value &other) const override;
bool operator==(const UInt16Imm &other) const;
std::string ToString() const override { return std::to_string(v_); }
std::string DumpText() const override {
@ -251,8 +251,8 @@ class UInt32Imm : public IntergerImm {
bool IsZero() override { return v_ == 0; }
bool IsOne() override { return v_ == 1; }
uint32_t value() const { return v_; }
bool operator==(const Value& other) const override;
bool operator==(const UInt32Imm& other) const;
bool operator==(const Value &other) const override;
bool operator==(const UInt32Imm &other) const;
std::string ToString() const override { return std::to_string(v_); }
std::string DumpText() const override {
@ -277,8 +277,8 @@ class UInt64Imm : public IntergerImm {
bool IsZero() override { return v_ == 0; }
bool IsOne() override { return v_ == 1; }
uint64_t value() const { return v_; }
bool operator==(const Value& other) const override;
bool operator==(const UInt64Imm& other) const;
bool operator==(const Value &other) const override;
bool operator==(const UInt64Imm &other) const;
std::string ToString() const override { return std::to_string(v_); }
std::string DumpText() const override {
@ -296,7 +296,7 @@ IMM_TRAITS(UInt64ImmPtr, uint64_t);
class FloatImm : public Scalar {
public:
FloatImm() = default;
explicit FloatImm(const TypePtr& t) : Scalar(t) {}
explicit FloatImm(const TypePtr &t) : Scalar(t) {}
~FloatImm() override = default;
MS_DECLARE_PARENT(FloatImm, Scalar)
};
@ -312,8 +312,8 @@ class FP32Imm : public FloatImm {
bool IsZero() override { return fabs(v_) <= FLT_EPSILON; }
bool IsOne() override { return fabs(v_ - 1.0) <= FLT_EPSILON; }
float value() const { return v_; }
bool operator==(const Value& other) const override;
bool operator==(const FP32Imm& other) const;
bool operator==(const Value &other) const override;
bool operator==(const FP32Imm &other) const;
std::string ToString() const override { return std::to_string(v_); }
std::string DumpText() const override {
@ -338,8 +338,8 @@ class FP64Imm : public FloatImm {
bool IsZero() override { return fabs(v_) <= DBL_EPSILON; }
bool IsOne() override { return fabs(v_ - 1.0) <= DBL_EPSILON; }
double value() const { return v_; }
bool operator==(const Value& other) const override;
bool operator==(const FP64Imm& other) const;
bool operator==(const Value &other) const override;
bool operator==(const FP64Imm &other) const;
std::string ToString() const override { return std::to_string(v_); }
std::string DumpText() const override {

View File

@ -21,8 +21,8 @@
#include "pipeline/parse/data_converter.h"
namespace mindspore {
Signature::Signature(const std::string& arg_name, const SignatureEnumRW& rw_tag, const SignatureEnumKind& arg_kind,
const py::object& arg_default, const SignatureEnumDType& arg_dtype)
Signature::Signature(const std::string &arg_name, const SignatureEnumRW &rw_tag, const SignatureEnumKind &arg_kind,
const py::object &arg_default, const SignatureEnumDType &arg_dtype)
: name(arg_name), rw(rw_tag), kind(arg_kind), dtype(arg_dtype) {
if (py::isinstance<SignatureEnumKind>(arg_default) &&
py::cast<SignatureEnumKind>(arg_default) == SignatureEnumKind::kKindEmptyDefaultValue) {
@ -32,14 +32,14 @@ Signature::Signature(const std::string& arg_name, const SignatureEnumRW& rw_tag,
}
}
Signature::Signature(const std::string& arg_name, const SignatureEnumRW& rw_tag, const SignatureEnumKind& arg_kind)
Signature::Signature(const std::string &arg_name, const SignatureEnumRW &rw_tag, const SignatureEnumKind &arg_kind)
: name(arg_name),
rw(rw_tag),
kind(arg_kind),
default_value(nullptr),
dtype(SignatureEnumDType::kDTypeEmptyDefaultValue) {}
REGISTER_PYBIND_DEFINE(SignatureEnumRW, ([](const py::module* m) {
REGISTER_PYBIND_DEFINE(SignatureEnumRW, ([](const py::module *m) {
(void)py::enum_<SignatureEnumRW>(*m, "signature_rw", py::arithmetic())
.value("RW_READ", SignatureEnumRW::kRWRead)
.value("RW_WRITE", SignatureEnumRW::kRWWrite)

View File

@ -61,9 +61,9 @@ struct Signature {
SignatureEnumKind kind;
ValuePtr default_value; // nullptr for no default value
SignatureEnumDType dtype;
Signature(const std::string& arg_name, const SignatureEnumRW& rw_tag, const SignatureEnumKind& arg_kind,
const py::object& arg_default, const SignatureEnumDType& arg_dtype);
Signature(const std::string& arg_name, const SignatureEnumRW& rw_tag, const SignatureEnumKind& arg_kind);
Signature(const std::string &arg_name, const SignatureEnumRW &rw_tag, const SignatureEnumKind &arg_kind,
const py::object &arg_default, const SignatureEnumDType &arg_dtype);
Signature(const std::string &arg_name, const SignatureEnumRW &rw_tag, const SignatureEnumKind &arg_kind);
};
} // namespace mindspore

View File

@ -24,7 +24,7 @@
#include "pipeline/static_analysis/abstract_value.h"
namespace mindspore {
const ValuePtr ValueSequeue::operator[](const std::size_t& dim) const {
const ValuePtr ValueSequeue::operator[](const std::size_t &dim) const {
if (dim >= size()) {
MS_LOG(EXCEPTION) << "List index [" << dim << "] is out of range [" << size() << "].";
}
@ -40,125 +40,125 @@ bool ValueSequeue::erase(size_t idx) {
}
}
bool BoolImm::operator==(const Value& other) const {
bool BoolImm::operator==(const Value &other) const {
if (other.isa<BoolImm>()) {
auto other_ = static_cast<const BoolImm&>(other);
auto other_ = static_cast<const BoolImm &>(other);
return *this == other_;
} else {
return false;
}
}
bool BoolImm::operator==(const BoolImm& other) const { return v_ == other.v_; }
bool BoolImm::operator==(const BoolImm &other) const { return v_ == other.v_; }
bool Int8Imm::operator==(const Value& other) const {
bool Int8Imm::operator==(const Value &other) const {
if (other.isa<Int8Imm>()) {
auto other_ = static_cast<const Int8Imm&>(other);
auto other_ = static_cast<const Int8Imm &>(other);
return *this == other_;
} else {
return false;
}
}
bool Int8Imm::operator==(const Int8Imm& other) const { return v_ == other.v_; }
bool Int16Imm::operator==(const Value& other) const {
bool Int8Imm::operator==(const Int8Imm &other) const { return v_ == other.v_; }
bool Int16Imm::operator==(const Value &other) const {
if (other.isa<Int16Imm>()) {
auto other_ = static_cast<const Int16Imm&>(other);
auto other_ = static_cast<const Int16Imm &>(other);
return *this == other_;
} else {
return false;
}
}
bool Int16Imm::operator==(const Int16Imm& other) const { return v_ == other.v_; }
bool Int32Imm::operator==(const Value& other) const {
bool Int16Imm::operator==(const Int16Imm &other) const { return v_ == other.v_; }
bool Int32Imm::operator==(const Value &other) const {
if (other.isa<Int32Imm>()) {
auto other_ = static_cast<const Int32Imm&>(other);
auto other_ = static_cast<const Int32Imm &>(other);
return *this == other_;
} else {
return false;
}
}
bool Int32Imm::operator==(const Int32Imm& other) const { return v_ == other.v_; }
bool Int64Imm::operator==(const Value& other) const {
bool Int32Imm::operator==(const Int32Imm &other) const { return v_ == other.v_; }
bool Int64Imm::operator==(const Value &other) const {
if (other.isa<Int64Imm>()) {
auto other_ = static_cast<const Int64Imm&>(other);
auto other_ = static_cast<const Int64Imm &>(other);
return *this == other_;
} else {
return false;
}
}
bool Int64Imm::operator==(const Int64Imm& other) const { return v_ == other.v_; }
bool UInt8Imm::operator==(const Value& other) const {
bool Int64Imm::operator==(const Int64Imm &other) const { return v_ == other.v_; }
bool UInt8Imm::operator==(const Value &other) const {
if (other.isa<UInt8Imm>()) {
auto other_ = static_cast<const UInt8Imm&>(other);
auto other_ = static_cast<const UInt8Imm &>(other);
return *this == other_;
} else {
return false;
}
}
bool UInt8Imm::operator==(const UInt8Imm& other) const { return v_ == other.v_; }
bool UInt16Imm::operator==(const Value& other) const {
bool UInt8Imm::operator==(const UInt8Imm &other) const { return v_ == other.v_; }
bool UInt16Imm::operator==(const Value &other) const {
if (other.isa<UInt16Imm>()) {
auto other_ = static_cast<const UInt16Imm&>(other);
auto other_ = static_cast<const UInt16Imm &>(other);
return *this == other_;
} else {
return false;
}
}
bool UInt16Imm::operator==(const UInt16Imm& other) const { return v_ == other.v_; }
bool UInt32Imm::operator==(const Value& other) const {
bool UInt16Imm::operator==(const UInt16Imm &other) const { return v_ == other.v_; }
bool UInt32Imm::operator==(const Value &other) const {
if (other.isa<UInt32Imm>()) {
auto other_ = static_cast<const UInt32Imm&>(other);
auto other_ = static_cast<const UInt32Imm &>(other);
return *this == other_;
} else {
return false;
}
}
bool UInt32Imm::operator==(const UInt32Imm& other) const { return v_ == other.v_; }
bool UInt64Imm::operator==(const Value& other) const {
bool UInt32Imm::operator==(const UInt32Imm &other) const { return v_ == other.v_; }
bool UInt64Imm::operator==(const Value &other) const {
if (other.isa<UInt64Imm>()) {
auto other_ = static_cast<const UInt64Imm&>(other);
auto other_ = static_cast<const UInt64Imm &>(other);
return *this == other_;
} else {
return false;
}
}
bool UInt64Imm::operator==(const UInt64Imm& other) const { return v_ == other.v_; }
bool FP32Imm::operator==(const Value& other) const {
bool UInt64Imm::operator==(const UInt64Imm &other) const { return v_ == other.v_; }
bool FP32Imm::operator==(const Value &other) const {
if (other.isa<FP32Imm>()) {
auto other_ = static_cast<const FP32Imm&>(other);
auto other_ = static_cast<const FP32Imm &>(other);
return *this == other_;
} else {
return false;
}
}
bool FP32Imm::operator==(const FP32Imm& other) const { return fabs(v_ - other.v_) < FLT_EPSILON; }
bool FP64Imm::operator==(const Value& other) const {
bool FP32Imm::operator==(const FP32Imm &other) const { return fabs(v_ - other.v_) < FLT_EPSILON; }
bool FP64Imm::operator==(const Value &other) const {
if (other.isa<FP64Imm>()) {
auto other_ = static_cast<const FP64Imm&>(other);
auto other_ = static_cast<const FP64Imm &>(other);
return *this == other_;
} else {
return false;
}
}
bool ValueSequeue::operator==(const Value& other) const {
bool ValueSequeue::operator==(const Value &other) const {
if (other.isa<ValueSequeue>()) {
auto other_ = static_cast<const ValueSequeue&>(other);
auto other_ = static_cast<const ValueSequeue &>(other);
return *this == other_;
} else {
return false;
}
}
bool ValueSequeue::operator==(const ValueSequeue& other) const {
bool ValueSequeue::operator==(const ValueSequeue &other) const {
if (other.elements_.size() != elements_.size()) {
return false;
}
return std::equal(elements_.begin(), elements_.end(), other.elements_.begin(),
[](const ValuePtr& lhs, const ValuePtr& rhs) { return *lhs == *rhs; });
[](const ValuePtr &lhs, const ValuePtr &rhs) { return *lhs == *rhs; });
}
std::string ValueSequeue::ToString() const {
std::ostringstream buffer;
bool begin = true;
for (auto& attr : elements_) {
for (auto &attr : elements_) {
if (!begin) {
buffer << ", ";
} else {
@ -179,28 +179,28 @@ std::string ValueSequeue::DumpText() const {
return oss.str();
}
bool FP64Imm::operator==(const FP64Imm& other) const { return fabs(v_ - other.v_) < DBL_EPSILON; }
bool StringImm::operator==(const Value& other) const {
bool FP64Imm::operator==(const FP64Imm &other) const { return fabs(v_ - other.v_) < DBL_EPSILON; }
bool StringImm::operator==(const Value &other) const {
if (other.isa<StringImm>()) {
auto other_ = static_cast<const StringImm&>(other);
auto other_ = static_cast<const StringImm &>(other);
return *this == other_;
} else {
return false;
}
}
bool StringImm::operator==(const StringImm& other) const { return str_ == other.str_; }
bool StringImm::operator==(const StringImm &other) const { return str_ == other.str_; }
bool RefKey::operator==(const Value& other) const {
bool RefKey::operator==(const Value &other) const {
if (other.isa<RefKey>()) {
auto other_ = static_cast<const RefKey&>(other);
auto other_ = static_cast<const RefKey &>(other);
return *this == other_;
} else {
return false;
}
}
bool RefKey::operator==(const RefKey& other) const { return tag_ == other.tag_; }
bool RefKey::operator==(const RefKey &other) const { return tag_ == other.tag_; }
bool AnyValue::operator==(const Value& other) const {
bool AnyValue::operator==(const Value &other) const {
if (other.isa<AnyValue>()) {
return true;
} else {
@ -228,7 +228,7 @@ abstract::AbstractBasePtr AnyValue::ToAbstract() { return std::make_shared<abstr
abstract::AbstractBasePtr ValueTuple::ToAbstract() {
abstract::AbstractBasePtrList a_list;
(void)std::transform(elements_.begin(), elements_.end(), std::back_inserter(a_list), [](const ValuePtr& ele) {
(void)std::transform(elements_.begin(), elements_.end(), std::back_inserter(a_list), [](const ValuePtr &ele) {
MS_EXCEPTION_IF_NULL(ele);
return ele->ToAbstract();
});
@ -237,7 +237,7 @@ abstract::AbstractBasePtr ValueTuple::ToAbstract() {
abstract::AbstractBasePtr ValueList::ToAbstract() {
abstract::AbstractBasePtrList a_list;
(void)std::transform(elements_.begin(), elements_.end(), std::back_inserter(a_list), [](const ValuePtr& ele) {
(void)std::transform(elements_.begin(), elements_.end(), std::back_inserter(a_list), [](const ValuePtr &ele) {
MS_EXCEPTION_IF_NULL(ele);
return ele->ToAbstract();
});
@ -251,16 +251,16 @@ std::size_t ValueSlice::hash() const {
return hash_combine({tid(), start_->hash(), stop_->hash(), step_->hash()});
}
bool ValueSlice::operator==(const Value& other) const {
bool ValueSlice::operator==(const Value &other) const {
if (other.isa<ValueSlice>()) {
auto other_ = static_cast<const ValueSlice&>(other);
auto other_ = static_cast<const ValueSlice &>(other);
return *this == other_;
} else {
return false;
}
}
bool ValueSlice::operator==(const ValueSlice& other) const {
bool ValueSlice::operator==(const ValueSlice &other) const {
MS_EXCEPTION_IF_NULL(start_);
MS_EXCEPTION_IF_NULL(stop_);
MS_EXCEPTION_IF_NULL(step_);
@ -295,16 +295,16 @@ std::size_t KeywordArg::hash() const {
return hash_combine({tid(), std::hash<std::string>{}(key_), value_->hash()});
}
bool KeywordArg::operator==(const Value& other) const {
bool KeywordArg::operator==(const Value &other) const {
if (other.isa<KeywordArg>()) {
auto other_ = static_cast<const KeywordArg&>(other);
auto other_ = static_cast<const KeywordArg &>(other);
return *this == other_;
} else {
return false;
}
}
bool KeywordArg::operator==(const KeywordArg& other) const { return (other.key_ == key_ && *other.value_ == *value_); }
bool KeywordArg::operator==(const KeywordArg &other) const { return (other.key_ == key_ && *other.value_ == *value_); }
std::string KeywordArg::ToString() const {
std::ostringstream buffer;
@ -322,25 +322,25 @@ abstract::AbstractBasePtr KeywordArg::ToAbstract() {
return std::make_shared<abstract::AbstractKeywordArg>(key_, argument);
}
const ValuePtr ValueDictionary::operator[](const std::string& key) const {
const ValuePtr ValueDictionary::operator[](const std::string &key) const {
auto it = std::find_if(key_values_.begin(), key_values_.end(),
[key](const std::pair<std::string, ValuePtr>& item) { return item.first == key; });
[key](const std::pair<std::string, ValuePtr> &item) { return item.first == key; });
if (it == key_values_.end()) {
MS_LOG(EXCEPTION) << "The key " << key << " is not in the map";
}
return it->second;
}
bool ValueDictionary::operator==(const Value& other) const {
bool ValueDictionary::operator==(const Value &other) const {
if (other.isa<ValueDictionary>()) {
auto other_ = static_cast<const ValueDictionary&>(other);
auto other_ = static_cast<const ValueDictionary &>(other);
return *this == other_;
} else {
return false;
}
}
bool ValueDictionary::operator==(const ValueDictionary& other) const {
bool ValueDictionary::operator==(const ValueDictionary &other) const {
if (key_values_.size() != other.key_values_.size()) {
return false;
}
@ -359,12 +359,12 @@ abstract::AbstractBasePtr ValueDictionary::ToAbstract() {
std::vector<std::pair<std::string, abstract::AbstractBasePtr>> kv;
(void)std::transform(
key_values_.begin(), key_values_.end(), std::back_inserter(kv),
[](const std::pair<std::string, ValuePtr>& item) { return std::make_pair(item.first, item.second->ToAbstract()); });
[](const std::pair<std::string, ValuePtr> &item) { return std::make_pair(item.first, item.second->ToAbstract()); });
return std::make_shared<abstract::AbstractDictionary>(kv);
}
REGISTER_PYBIND_DEFINE(
RefKey, ([](const py::module* m) {
RefKey, ([](const py::module *m) {
(void)py::class_<RefKey, std::shared_ptr<RefKey>>(*m, "RefKey").def(py::init<std::string>(), py::arg("tag"));
}));
} // namespace mindspore

View File

@ -35,19 +35,19 @@
namespace mindspore {
class ValueSequeue : public Value {
public:
explicit ValueSequeue(const ValuePtrList& elements) : elements_(elements) {
explicit ValueSequeue(const ValuePtrList &elements) : elements_(elements) {
TypePtrList t_list;
(void)std::transform(elements.begin(), elements.end(), std::back_inserter(t_list), [](const ValuePtr& ele) {
(void)std::transform(elements.begin(), elements.end(), std::back_inserter(t_list), [](const ValuePtr &ele) {
MS_EXCEPTION_IF_NULL(ele);
return ele->type();
});
TypePtr t = std::make_shared<Tuple>(t_list);
type_ = t;
}
ValueSequeue(const std::initializer_list<ValuePtr>& elements) : elements_(elements.begin(), elements.end()) {
ValueSequeue(const std::initializer_list<ValuePtr> &elements) : elements_(elements.begin(), elements.end()) {
TypePtrList t_list;
(void)std::transform(elements_.begin(), elements_.end(), std::back_inserter(t_list),
[](const ValuePtr& ele) { return ele->type(); });
[](const ValuePtr &ele) { return ele->type(); });
TypePtr t = std::make_shared<Tuple>(t_list);
type_ = t;
}
@ -56,10 +56,10 @@ class ValueSequeue : public Value {
std::size_t hash() const override { return hash_combine(tid(), std::hash<std::size_t>{}(elements_.size())); }
std::size_t size() const { return elements_.size(); }
bool erase(size_t idx);
const ValuePtr operator[](const std::size_t& dim) const;
const ValuePtrList& value() const { return elements_; }
bool operator==(const Value& other) const override;
bool operator==(const ValueSequeue& other) const;
const ValuePtr operator[](const std::size_t &dim) const;
const ValuePtrList &value() const { return elements_; }
bool operator==(const Value &other) const override;
bool operator==(const ValueSequeue &other) const;
std::string ToString() const override;
std::string DumpText() const override;
@ -70,8 +70,8 @@ using ValueSequeuePtr = std::shared_ptr<ValueSequeue>;
class ValueTuple : public ValueSequeue {
public:
explicit ValueTuple(const std::vector<ValuePtr>& elements) : ValueSequeue(elements) {}
ValueTuple(const std::initializer_list<ValuePtr>& elements) : ValueSequeue(elements) {}
explicit ValueTuple(const std::vector<ValuePtr> &elements) : ValueSequeue(elements) {}
ValueTuple(const std::initializer_list<ValuePtr> &elements) : ValueSequeue(elements) {}
~ValueTuple() override = default;
MS_DECLARE_PARENT(ValueTuple, ValueSequeue)
abstract::AbstractBasePtr ToAbstract() override;
@ -83,8 +83,8 @@ using ValueTuplePtr = std::shared_ptr<ValueTuple>;
class ValueList : public ValueSequeue {
public:
explicit ValueList(const std::vector<ValuePtr>& elements) : ValueSequeue(elements) {}
ValueList(const std::initializer_list<ValuePtr>& elements) : ValueSequeue(elements) {}
explicit ValueList(const std::vector<ValuePtr> &elements) : ValueSequeue(elements) {}
ValueList(const std::initializer_list<ValuePtr> &elements) : ValueSequeue(elements) {}
~ValueList() override = default;
MS_DECLARE_PARENT(ValueList, ValueSequeue)
abstract::AbstractBasePtr ToAbstract() override;
@ -94,7 +94,7 @@ class ValueList : public ValueSequeue {
};
using ValueListPtr = std::shared_ptr<ValueList>;
inline ValuePtr MakeValue(const std::vector<ValuePtr>& v) { return std::make_shared<ValueTuple>(v); }
inline ValuePtr MakeValue(const std::vector<ValuePtr> &v) { return std::make_shared<ValueTuple>(v); }
inline ValuePtr MakeValue(std::initializer_list<ValuePtr> v) { return std::make_shared<ValueTuple>(v); }
template <typename T>
@ -103,7 +103,7 @@ template <typename T, typename A>
struct is_vector<std::vector<T, A>> : public std::true_type {};
template <typename T, typename U = typename std::enable_if<is_vector<T>::value, typename T::value_type>::type>
ValuePtr MakeValue(const T& vec) {
ValuePtr MakeValue(const T &vec) {
std::vector<ValuePtr> list;
(void)std::transform(vec.begin(), vec.end(), std::back_inserter(list), [](U ele) { return MakeValue(ele); });
return std::make_shared<ValueTuple>(list);
@ -111,13 +111,13 @@ ValuePtr MakeValue(const T& vec) {
class ValueSlice : public Value {
public:
ValueSlice(const ValuePtr& start, const ValuePtr& stop, const ValuePtr& step)
ValueSlice(const ValuePtr &start, const ValuePtr &stop, const ValuePtr &step)
: start_(start), stop_(stop), step_(step) {}
~ValueSlice() override = default;
MS_DECLARE_PARENT(ValueSlice, Value)
std::size_t hash() const override;
bool operator==(const Value& other) const override;
bool operator==(const ValueSlice& other) const;
bool operator==(const Value &other) const override;
bool operator==(const ValueSlice &other) const;
std::string ToString() const override;
@ -133,13 +133,13 @@ using ValueSlicePtr = std::shared_ptr<ValueSlice>;
class KeywordArg : public Value {
public:
KeywordArg(const std::string& key, const ValuePtr& value) : key_(key), value_(value) {}
KeywordArg(const std::string &key, const ValuePtr &value) : key_(key), value_(value) {}
~KeywordArg() override = default;
MS_DECLARE_PARENT(KeywordArg, Value)
std::size_t hash() const override;
ValuePtr get_value() const { return value_; }
bool operator==(const Value& other) const override;
bool operator==(const KeywordArg& other) const;
bool operator==(const Value &other) const override;
bool operator==(const KeywordArg &other) const;
std::string ToString() const override;
@ -154,31 +154,31 @@ using KeywordArgPtr = std::shared_ptr<KeywordArg>;
class ValueDictionary : public Value {
public:
explicit ValueDictionary(const std::vector<std::pair<std::string, ValuePtr>>& key_values) : key_values_(key_values) {}
explicit ValueDictionary(const std::vector<std::pair<std::string, ValuePtr>> &key_values) : key_values_(key_values) {}
~ValueDictionary() override = default;
MS_DECLARE_PARENT(ValueDictionary, Value)
std::size_t hash() const override { return hash_combine(tid(), std::hash<std::size_t>{}(key_values_.size())); }
std::size_t size() const { return key_values_.size(); }
const ValuePtr operator[](const std::string& key) const;
const std::vector<std::pair<std::string, ValuePtr>>& value() const { return key_values_; }
bool operator==(const Value& other) const override;
bool operator==(const ValueDictionary& other) const;
const ValuePtr operator[](const std::string &key) const;
const std::vector<std::pair<std::string, ValuePtr>> &value() const { return key_values_; }
bool operator==(const Value &other) const override;
bool operator==(const ValueDictionary &other) const;
std::string ToString() const override {
std::ostringstream buffer;
std::vector<std::string> keys;
std::vector<ValuePtr> values;
for (const auto& kv : key_values_) {
for (const auto &kv : key_values_) {
keys.push_back(kv.first);
values.push_back(kv.second);
}
buffer << "(Dict: "
<< " keys:(";
for (const auto& key : keys) {
for (const auto &key : keys) {
buffer << key << ", ";
}
buffer << ") values:(";
for (const auto& value : values) {
for (const auto &value : values) {
MS_EXCEPTION_IF_NULL(value);
buffer << value->DumpText() << ", ";
}
@ -195,14 +195,14 @@ using ValueDictionaryPtr = std::shared_ptr<ValueDictionary>;
class StringImm : public Value {
public:
explicit StringImm(const std::string& str) : Value(kString), str_(str), hash_(std::hash<std::string>{}(str_)) {}
explicit StringImm(const std::string &str) : Value(kString), str_(str), hash_(std::hash<std::string>{}(str_)) {}
~StringImm() override = default;
MS_DECLARE_PARENT(StringImm, Value)
std::size_t hash() const override { return hash_; }
const std::string& value() const { return str_; }
bool operator==(const Value& other) const override;
bool operator==(const StringImm& other) const;
const std::string &value() const { return str_; }
bool operator==(const Value &other) const override;
bool operator==(const StringImm &other) const;
abstract::AbstractBasePtr ToAbstract() override;
std::string ToString() const override { return str_; }
@ -218,18 +218,18 @@ class StringImm : public Value {
};
using StringImmPtr = std::shared_ptr<StringImm>;
IMM_TRAITS(StringImmPtr, std::string)
IMM_TRAITS(StringImmPtr, const char*)
IMM_TRAITS(StringImmPtr, const char *)
class RefKey : public Value {
public:
explicit RefKey(const std::string& tag) : Value(kRefKeyType), tag_(tag), hash_(std::hash<std::string>{}(tag)) {}
explicit RefKey(const std::string &tag) : Value(kRefKeyType), tag_(tag), hash_(std::hash<std::string>{}(tag)) {}
~RefKey() override = default;
MS_DECLARE_PARENT(RefKey, Value)
std::size_t hash() const override { return hash_; }
const std::string& tag() const { return tag_; }
bool operator==(const Value& other) const override;
bool operator==(const RefKey& other) const;
const std::string &tag() const { return tag_; }
bool operator==(const Value &other) const override;
bool operator==(const RefKey &other) const;
abstract::AbstractBasePtr ToAbstract() override;
std::string ToString() const override { return "RefKey[" + tag_ + "]"; }
@ -251,13 +251,13 @@ class AnyValue : public Value {
~AnyValue() override = default;
MS_DECLARE_PARENT(AnyValue, Value)
std::size_t hash() const override { return tid(); }
bool operator==(const Value& other) const override;
bool operator==(const Value &other) const override;
abstract::AbstractBasePtr ToAbstract() override;
};
extern const ValuePtr kAnyValue;
template <>
inline const char* GetValue(const ValuePtr& value) {
inline const char *GetValue(const ValuePtr &value) {
if (value == nullptr) {
MS_LOG(EXCEPTION) << "Value is nullptr";
}
@ -270,7 +270,7 @@ inline const char* GetValue(const ValuePtr& value) {
template <typename T, typename S = typename std::decay<T>::type,
typename U = typename std::enable_if<is_vector<S>::value, typename S::value_type>::type>
std::vector<U> GetValue(const ValuePtr& value) {
std::vector<U> GetValue(const ValuePtr &value) {
if (value == nullptr) {
MS_LOG(EXCEPTION) << "Value is nullptr";
}
@ -280,21 +280,21 @@ std::vector<U> GetValue(const ValuePtr& value) {
<< ">";
}
std::vector<U> rets;
const std::vector<ValuePtr>& vals = value->cast<ValueSequeuePtr>()->value();
const std::vector<ValuePtr> &vals = value->cast<ValueSequeuePtr>()->value();
(void)std::transform(vals.begin(), vals.end(), std::back_inserter(rets),
[](const ValuePtr& v) { return GetValue<U>(v); });
[](const ValuePtr &v) { return GetValue<U>(v); });
return rets;
}
inline ValueNodePtr NewValueNode(const ValuePtr& t) { return std::make_shared<ValueNode>(t); }
inline ValueNodePtr NewValueNode(const ValuePtr &t) { return std::make_shared<ValueNode>(t); }
template <typename T, typename _ = typename std::enable_if<!std::is_base_of<Value, T>::value>::type>
inline ValueNodePtr NewValueNode(const std::shared_ptr<T>& x) {
inline ValueNodePtr NewValueNode(const std::shared_ptr<T> &x) {
return NewValueNode(MakeValue(x));
}
template <typename T, typename _ = typename std::enable_if<!is_shared_ptr<T>::value>::type>
inline ValueNodePtr NewValueNode(const T& x) {
inline ValueNodePtr NewValueNode(const T &x) {
return NewValueNode(MakeValue(x));
}
} // namespace mindspore

View File

@ -22,15 +22,15 @@
#include "optimizer/opt.h"
namespace mindspore {
using VisitFuncType = std::function<void(const AnfNodePtr&)>;
using VisitFuncType = std::function<void(const AnfNodePtr &)>;
class AnfVisitor {
public:
virtual AnfNodePtr operator()(const opt::OptimizerPtr&, const AnfNodePtr&);
virtual void Visit(const AnfNodePtr&);
virtual void Visit(const CNodePtr&);
virtual void Visit(const ValueNodePtr&);
virtual void Visit(const ParameterPtr&);
VisitFuncType Match(const PrimitivePtr&, const std::vector<opt::PredicateFuncType>& = {});
virtual AnfNodePtr operator()(const opt::OptimizerPtr &, const AnfNodePtr &);
virtual void Visit(const AnfNodePtr &);
virtual void Visit(const CNodePtr &);
virtual void Visit(const ValueNodePtr &);
virtual void Visit(const ParameterPtr &);
VisitFuncType Match(const PrimitivePtr &, const std::vector<opt::PredicateFuncType> & = {});
virtual ~AnfVisitor() = default;
};
} // namespace mindspore

View File

@ -26,12 +26,12 @@
namespace mindspore {
namespace kernel {
namespace {
void FilterInvaildKernelInfo(const CNodePtr& kernel_node,
std::vector<std::shared_ptr<kernel::KernelBuildInfo>>* kernel_info_list) {
void FilterInvaildKernelInfo(const CNodePtr &kernel_node,
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list) {
MS_EXCEPTION_IF_NULL(kernel_info_list);
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> filtered_list;
(void)std::copy_if(kernel_info_list->begin(), kernel_info_list->end(), std::back_inserter(filtered_list),
[&](const std::shared_ptr<kernel::KernelBuildInfo>& kernel_build_info) {
[&](const std::shared_ptr<kernel::KernelBuildInfo> &kernel_build_info) {
return AnfAlgo::GetOutputTensorNum(kernel_node) == kernel_build_info->GetOutputNum() &&
AnfAlgo::GetInputTensorNum(kernel_node) == kernel_build_info->GetInputNum();
});
@ -46,7 +46,7 @@ void FilterInvaildKernelInfo(const CNodePtr& kernel_node,
}
}
} // namespace
void KernelQuery(const CNodePtr& kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>>* kernel_info_list) {
void KernelQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list) {
MS_EXCEPTION_IF_NULL(kernel_node);
MS_EXCEPTION_IF_NULL(kernel_info_list);
TbeMetadataInfo(kernel_node, kernel_info_list);

View File

@ -38,11 +38,11 @@ class OpAttr {
std::string value() const { return value_; }
std::string default_value() const { return default_value_; }
void set_name(const std::string& name) { name_ = name; }
void set_param_type(const std::string& param_type) { param_type_ = param_type; }
void set_type(const std::string& type) { type_ = type; }
void set_value(const std::string& value) { value_ = value; }
void set_default_value(const std::string& default_value) { default_value_ = default_value; }
void set_name(const std::string &name) { name_ = name; }
void set_param_type(const std::string &param_type) { param_type_ = param_type; }
void set_type(const std::string &type) { type_ = type; }
void set_value(const std::string &value) { value_ = value; }
void set_default_value(const std::string &default_value) { default_value_ = default_value; }
private:
std::string name_;
@ -67,13 +67,13 @@ class OpIOInfo {
std::vector<std::string> formats() const { return formats_; }
void set_index(const int index) { index_ = index; }
void set_name(const std::string& name) { name_ = name; }
void set_name(const std::string &name) { name_ = name; }
void set_need_compile(const bool need_compile) { need_compile_ = need_compile; }
void set_param_type(const std::string& param_type) { param_type_ = param_type; }
void set_reshape_type(const std::string& reshape_type) { reshape_type_ = reshape_type; }
void set_shape(const std::string& shape) { shape_ = shape; }
void set_dtypes(const std::vector<std::string>& dtype) { dtypes_ = dtype; }
void set_formats(const std::vector<std::string>& formats) { formats_ = formats; }
void set_param_type(const std::string &param_type) { param_type_ = param_type; }
void set_reshape_type(const std::string &reshape_type) { reshape_type_ = reshape_type; }
void set_shape(const std::string &shape) { shape_ = shape; }
void set_dtypes(const std::vector<std::string> &dtype) { dtypes_ = dtype; }
void set_formats(const std::vector<std::string> &formats) { formats_ = formats; }
private:
int index_ = 0;
@ -104,24 +104,24 @@ class OpInfo {
std::vector<std::shared_ptr<OpAttr>> attrs_ptr() const { return attrs_ptr_; }
std::vector<std::shared_ptr<OpIOInfo>> inputs_ptr() const { return inputs_ptr_; }
std::vector<std::shared_ptr<OpIOInfo>> outputs_ptr() const { return outputs_ptr_; }
const std::unordered_map<size_t, size_t>& ref_infos() const { return ref_infos_; }
const std::unordered_map<size_t, size_t> &ref_infos() const { return ref_infos_; }
void set_op_name(const std::string& op_name) { op_name_ = op_name; }
void set_op_name(const std::string &op_name) { op_name_ = op_name; }
void set_imply_type(const OpImplyType imply_type) { imply_type_ = imply_type; }
void set_impl_path(const std::string& impl_path) { impl_path_ = impl_path; }
void set_fusion_type(const std::string& fusion_type) { fusion_type_ = fusion_type; }
void set_impl_path(const std::string &impl_path) { impl_path_ = impl_path; }
void set_fusion_type(const std::string &fusion_type) { fusion_type_ = fusion_type; }
void set_async_flag(const bool async_flag) { async_flag_ = async_flag; }
void set_binfile_name(const std::string& binfile_name) { binfile_name_ = binfile_name; }
void set_binfile_name(const std::string &binfile_name) { binfile_name_ = binfile_name; }
void set_compute_cost(const int compute_cost) { compute_cost_ = compute_cost; }
void set_kernel_name(const std::string& kernel_name) { kernel_name_ = kernel_name; }
void set_kernel_name(const std::string &kernel_name) { kernel_name_ = kernel_name; }
void set_partial_flag(const bool partial_flag) { partial_flag_ = partial_flag; }
void set_dynamic_format(const bool dynamic_format) { dynamic_format_ = dynamic_format; }
void set_op_pattern(const std::string op_pattern) { op_pattern_ = op_pattern; }
void add_attrs_ptr(const std::shared_ptr<OpAttr>& attr) { attrs_ptr_.push_back(attr); }
void add_inputs_ptr(const std::shared_ptr<OpIOInfo>& input) { inputs_ptr_.push_back(input); }
void add_outputs_ptr(const std::shared_ptr<OpIOInfo>& output) { outputs_ptr_.push_back(output); }
void set_inputs_ptr(const std::vector<std::shared_ptr<OpIOInfo>>& inputs) { inputs_ptr_ = inputs; }
void set_outputs_ptr(const std::vector<std::shared_ptr<OpIOInfo>>& outputs) { outputs_ptr_ = outputs; }
void add_attrs_ptr(const std::shared_ptr<OpAttr> &attr) { attrs_ptr_.push_back(attr); }
void add_inputs_ptr(const std::shared_ptr<OpIOInfo> &input) { inputs_ptr_.push_back(input); }
void add_outputs_ptr(const std::shared_ptr<OpIOInfo> &output) { outputs_ptr_.push_back(output); }
void set_inputs_ptr(const std::vector<std::shared_ptr<OpIOInfo>> &inputs) { inputs_ptr_ = inputs; }
void set_outputs_ptr(const std::vector<std::shared_ptr<OpIOInfo>> &outputs) { outputs_ptr_ = outputs; }
bool is_ref() const { return !ref_infos_.empty(); }
bool has_ref_index(size_t out_index) const { return ref_infos_.find(out_index) != ref_infos_.end(); }
void add_ref_pair(size_t out_index, size_t in_index) { (void)ref_infos_.emplace(out_index, in_index); }

View File

@ -67,7 +67,7 @@ std::string ImplTypeToStr(OpImplyType impl_type) {
return "unknow";
}
}
bool OpLib::RegOp(const std::string& json_string, const std::string& impl_path) {
bool OpLib::RegOp(const std::string &json_string, const std::string &impl_path) {
bool ret = false;
try {
auto op_json = nlohmann::json::parse(json_string);
@ -88,13 +88,13 @@ bool OpLib::RegOp(const std::string& json_string, const std::string& impl_path)
if (!ret) {
MS_LOG(DEBUG) << "RegOp failed: opname:" << op_name << "imply_type" << imply_type_string;
}
} catch (const std::exception& e) {
} catch (const std::exception &e) {
MS_LOG(DEBUG) << "get op_json elements failed:" << e.what();
}
return ret;
}
void OpLib::DecodeTBESpecificInfo(const nlohmann::json& obj, const std::shared_ptr<OpInfo>& op_info) {
void OpLib::DecodeTBESpecificInfo(const nlohmann::json &obj, const std::shared_ptr<OpInfo> &op_info) {
op_info->set_async_flag(obj.at(kAsyncFlag));
op_info->set_binfile_name(obj.at(kBinfileName));
op_info->set_compute_cost(obj.at(kComputeCost));
@ -108,8 +108,8 @@ void OpLib::DecodeTBESpecificInfo(const nlohmann::json& obj, const std::shared_p
}
}
bool OpLib::DecodeOpInfo(const nlohmann::json& obj, const mindspore::kernel::OpImplyType imply_type,
const std::string& impl_path) {
bool OpLib::DecodeOpInfo(const nlohmann::json &obj, const mindspore::kernel::OpImplyType imply_type,
const std::string &impl_path) {
std::shared_ptr<OpInfo> op_info = std::make_shared<OpInfo>();
MS_EXCEPTION_IF_NULL(op_info);
op_info->set_op_name(obj.at(kOpName));
@ -120,7 +120,7 @@ bool OpLib::DecodeOpInfo(const nlohmann::json& obj, const mindspore::kernel::OpI
DecodeTBESpecificInfo(obj, op_info);
}
auto attrs = obj.at(kAttr);
for (const auto& attr : attrs) {
for (const auto &attr : attrs) {
if (!DecodeAttr(attr, imply_type, op_info)) {
MS_LOG(DEBUG) << "DecodeAttr Failed";
return false;
@ -131,14 +131,14 @@ bool OpLib::DecodeOpInfo(const nlohmann::json& obj, const mindspore::kernel::OpI
dtype_format = obj.at(kDtypeFormat);
}
auto inputs = obj.at(kIputs);
for (const auto& input : inputs) {
for (const auto &input : inputs) {
if (!DecodeInputOutput(input, imply_type, kInput, op_info, dtype_format)) {
MS_LOG(DEBUG) << "DecodeInputOutput Failed";
return false;
}
}
auto outputs = obj.at(kOutputs);
for (const auto& output : outputs) {
for (const auto &output : outputs) {
if (!DecodeInputOutput(output, imply_type, kOutput, op_info, dtype_format)) {
MS_LOG(DEBUG) << "DecodeInputOutput Failed";
return false;
@ -156,8 +156,8 @@ bool OpLib::DecodeOpInfo(const nlohmann::json& obj, const mindspore::kernel::OpI
return true;
}
bool OpLib::DecodeAttr(const nlohmann::json& obj, const OpImplyType imply_type,
const std::shared_ptr<OpInfo>& op_info) {
bool OpLib::DecodeAttr(const nlohmann::json &obj, const OpImplyType imply_type,
const std::shared_ptr<OpInfo> &op_info) {
MS_EXCEPTION_IF_NULL(op_info);
bool ret = true;
try {
@ -175,34 +175,34 @@ bool OpLib::DecodeAttr(const nlohmann::json& obj, const OpImplyType imply_type,
op_attr->set_default_value(obj.at(kDefaultValue));
}
op_info->add_attrs_ptr(op_attr);
} catch (const std::exception& e) {
} catch (const std::exception &e) {
MS_LOG(DEBUG) << "DecodeAttr failed:" << e.what();
ret = false;
}
return ret;
}
bool OpLib::DecodeDtypeFormat(const nlohmann::json& dtype_format, const std::shared_ptr<OpIOInfo>& op_io,
bool OpLib::DecodeDtypeFormat(const nlohmann::json &dtype_format, const std::shared_ptr<OpIOInfo> &op_io,
size_t index) {
bool ret = true;
try {
std::vector<std::string> dtype;
std::vector<std::string> format;
for (const auto& it : dtype_format) {
for (const auto &it : dtype_format) {
dtype.emplace_back(it[index][0]);
format.emplace_back(it[index][1]);
}
op_io->set_dtypes(dtype);
op_io->set_formats(format);
} catch (const std::exception& e) {
} catch (const std::exception &e) {
MS_LOG(ERROR) << "DecodeDtypeFormat falied" << e.what();
ret = false;
}
return ret;
}
bool OpLib::DecodeInputOutput(const nlohmann::json& obj, const OpImplyType imply_type, const OpIOType io_type,
const std::shared_ptr<OpInfo>& op_info, const nlohmann::json& dtype_format) {
bool OpLib::DecodeInputOutput(const nlohmann::json &obj, const OpImplyType imply_type, const OpIOType io_type,
const std::shared_ptr<OpInfo> &op_info, const nlohmann::json &dtype_format) {
bool ret = true;
try {
std::shared_ptr<OpIOInfo> op_io = std::make_shared<OpIOInfo>();
@ -243,14 +243,14 @@ bool OpLib::DecodeInputOutput(const nlohmann::json& obj, const OpImplyType imply
} else if (io_type == kOutput) {
op_info->add_outputs_ptr(op_io);
}
} catch (const std::exception& e) {
} catch (const std::exception &e) {
MS_LOG(DEBUG) << "DecodeInputOutput failed" << e.what();
ret = false;
}
return ret;
}
std::shared_ptr<OpInfo> OpLib::FindOp(const std::string& op_name, OpImplyType imply_type) {
std::shared_ptr<OpInfo> OpLib::FindOp(const std::string &op_name, OpImplyType imply_type) {
auto context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context);
bool is_gpu = (context->device_target() == kGPUDevice);
@ -260,7 +260,7 @@ std::shared_ptr<OpInfo> OpLib::FindOp(const std::string& op_name, OpImplyType im
<< ", current op num:" << op_info_.size();
return nullptr;
}
for (const auto& op_info : op_info_) {
for (const auto &op_info : op_info_) {
MS_EXCEPTION_IF_NULL(op_info);
if (op_info->op_name() == op_name && op_info->imply_type() == imply_type) {
return op_info;
@ -271,14 +271,14 @@ std::shared_ptr<OpInfo> OpLib::FindOp(const std::string& op_name, OpImplyType im
return nullptr;
}
bool OpLib::GetRefInfo(const std::shared_ptr<OpInfo>& op_info) {
bool OpLib::GetRefInfo(const std::shared_ptr<OpInfo> &op_info) {
MS_EXCEPTION_IF_NULL(op_info);
const auto& output_infos = op_info->outputs_ptr();
const auto& input_infos = op_info->inputs_ptr();
const auto &output_infos = op_info->outputs_ptr();
const auto &input_infos = op_info->inputs_ptr();
for (size_t out_index = 0; out_index < output_infos.size(); out_index++) {
const auto& out_name = output_infos[out_index]->name();
const auto &out_name = output_infos[out_index]->name();
for (size_t in_index = 0; in_index < input_infos.size(); in_index++) {
const auto& in_name = input_infos[in_index]->name();
const auto &in_name = input_infos[in_index]->name();
if (out_name == in_name) {
if (op_info->has_ref_index(out_index)) {
MS_LOG(DEBUG) << "The out_index" << out_index << "is already in ref_info";
@ -293,9 +293,9 @@ bool OpLib::GetRefInfo(const std::shared_ptr<OpInfo>& op_info) {
return true;
}
bool OpLib::CheckRepetition(const std::shared_ptr<OpInfo>& op_info) {
bool OpLib::CheckRepetition(const std::shared_ptr<OpInfo> &op_info) {
MS_EXCEPTION_IF_NULL(op_info);
for (const auto& exist_op_info : op_info_) {
for (const auto &exist_op_info : op_info_) {
MS_EXCEPTION_IF_NULL(exist_op_info);
if (exist_op_info->op_name() == op_info->op_name() && exist_op_info->imply_type() == op_info->imply_type() &&
exist_op_info->impl_path() != op_info->impl_path()) {

View File

@ -28,23 +28,23 @@ class OpLib {
public:
OpLib() = default;
virtual ~OpLib() = default;
bool RegOp(const std::string& json_string, const std::string& impl_path);
static std::shared_ptr<OpInfo> FindOp(const std::string& op_name, OpImplyType imply_type);
bool RegOp(const std::string &json_string, const std::string &impl_path);
static std::shared_ptr<OpInfo> FindOp(const std::string &op_name, OpImplyType imply_type);
protected:
static std::vector<std::shared_ptr<OpInfo>> op_info_;
private:
static bool DecodeOpInfo(const nlohmann::json& obj, const OpImplyType imply_type, const std::string& impl_path);
static bool DecodeAttr(const nlohmann::json& obj, const OpImplyType imply_type,
const std::shared_ptr<OpInfo>& op_info);
static bool DecodeDtypeFormat(const nlohmann::json& dtype_format, const std::shared_ptr<OpIOInfo>& op_io,
static bool DecodeOpInfo(const nlohmann::json &obj, const OpImplyType imply_type, const std::string &impl_path);
static bool DecodeAttr(const nlohmann::json &obj, const OpImplyType imply_type,
const std::shared_ptr<OpInfo> &op_info);
static bool DecodeDtypeFormat(const nlohmann::json &dtype_format, const std::shared_ptr<OpIOInfo> &op_io,
size_t index);
static void DecodeTBESpecificInfo(const nlohmann::json& obj, const std::shared_ptr<OpInfo>& op_info);
static bool DecodeInputOutput(const nlohmann::json& obj, const OpImplyType imply_type, const OpIOType io_type,
const std::shared_ptr<OpInfo>& op_info, const nlohmann::json& dtype_format);
static bool GetRefInfo(const std::shared_ptr<OpInfo>& op_info);
static bool CheckRepetition(const std::shared_ptr<OpInfo>& op_info);
static void DecodeTBESpecificInfo(const nlohmann::json &obj, const std::shared_ptr<OpInfo> &op_info);
static bool DecodeInputOutput(const nlohmann::json &obj, const OpImplyType imply_type, const OpIOType io_type,
const std::shared_ptr<OpInfo> &op_info, const nlohmann::json &dtype_format);
static bool GetRefInfo(const std::shared_ptr<OpInfo> &op_info);
static bool CheckRepetition(const std::shared_ptr<OpInfo> &op_info);
};
} // namespace kernel
} // namespace mindspore

View File

@ -19,6 +19,6 @@
namespace mindspore {
// cppcheck-suppress unusedFunction
std::string set_version(const std::string& version) { return version; }
std::string set_version(const std::string &version) { return version; }
} // namespace mindspore

View File

@ -42,11 +42,11 @@ struct OpMergedInfo {
};
using GenAttrFuncType =
std::function<void(ValuePtr, onnx::AttributeProto_AttributeType, onnx::AttributeProto*, const PrimitivePtr&)>;
std::function<void(ValuePtr, onnx::AttributeProto_AttributeType, onnx::AttributeProto *, const PrimitivePtr &)>;
template <typename T, size_t rep_cnt = 0>
void SetAttrValueToProto(const ValuePtr& value, onnx::AttributeProto_AttributeType attr_type,
onnx::AttributeProto* const attr_proto, const PrimitivePtr&) {
void SetAttrValueToProto(const ValuePtr &value, onnx::AttributeProto_AttributeType attr_type,
onnx::AttributeProto *const attr_proto, const PrimitivePtr &) {
auto casted_value = dyn_cast<T>(value);
if (casted_value == nullptr) {
MS_LOG(EXCEPTION) << "Cast value " << value->ToString() << " to type T failed.";
@ -76,8 +76,8 @@ void SetAttrValueToProto(const ValuePtr& value, onnx::AttributeProto_AttributeTy
}
template <size_t beg_idx = 0>
void SetAttrTupleValueToProto(const ValuePtr& value, onnx::AttributeProto_AttributeType attr_type,
onnx::AttributeProto* const attr_proto, const PrimitivePtr&) {
void SetAttrTupleValueToProto(const ValuePtr &value, onnx::AttributeProto_AttributeType attr_type,
onnx::AttributeProto *const attr_proto, const PrimitivePtr &) {
auto tuple_ptr = dyn_cast<ValueTuple>(value);
if (tuple_ptr == nullptr) {
MS_LOG(EXCEPTION) << "Cast value from type " << value->type_name() << " to ValueTuple failed.";
@ -99,8 +99,8 @@ void SetAttrTupleValueToProto(const ValuePtr& value, onnx::AttributeProto_Attrib
attr_proto->set_type(attr_type);
}
void SetPoolingPadMode(const ValuePtr& value, onnx::AttributeProto_AttributeType,
onnx::AttributeProto* const attr_proto, const PrimitivePtr&) {
void SetPoolingPadMode(const ValuePtr &value, onnx::AttributeProto_AttributeType,
onnx::AttributeProto *const attr_proto, const PrimitivePtr &) {
attr_proto->set_type(onnx::AttributeProto_AttributeType_STRING);
auto attr_value = GetValue<std::string>(value);
if (attr_value == "VALID") {
@ -112,16 +112,16 @@ void SetPoolingPadMode(const ValuePtr& value, onnx::AttributeProto_AttributeType
class OpAttrInfo {
public:
OpAttrInfo(const std::string& attr_name, const string& onnx_attr_name,
onnx::AttributeProto_AttributeType onnx_attr_type, const GenAttrFuncType& fn_gen_attr)
OpAttrInfo(const std::string &attr_name, const string &onnx_attr_name,
onnx::AttributeProto_AttributeType onnx_attr_type, const GenAttrFuncType &fn_gen_attr)
: attr_name_(attr_name),
onnx_attr_name_(onnx_attr_name),
onnx_attr_type_(onnx_attr_type),
fn_gen_attr_(fn_gen_attr) {}
~OpAttrInfo() {}
const std::string& attr_name() const { return attr_name_; }
const std::string& onnx_attr_name() const { return onnx_attr_name_; }
const std::string &attr_name() const { return attr_name_; }
const std::string &onnx_attr_name() const { return onnx_attr_name_; }
onnx::AttributeProto_AttributeType onnx_attr_type() const { return onnx_attr_type_; }
GenAttrFuncType fn_gen_attr() const { return fn_gen_attr_; }
@ -134,27 +134,27 @@ class OpAttrInfo {
class OpNameInfo {
public:
OpNameInfo& set_op_type(const std::string& op_type) {
OpNameInfo &set_op_type(const std::string &op_type) {
op_type_ = op_type;
return *this;
}
const std::string& op_type() const { return op_type_; }
const std::string &op_type() const { return op_type_; }
OpNameInfo& set_onnx_type(const std::string& onnx_type) {
OpNameInfo &set_onnx_type(const std::string &onnx_type) {
onnx_type_ = onnx_type;
return *this;
}
const std::string& onnx_type() const { return onnx_type_; }
const std::string &onnx_type() const { return onnx_type_; }
OpNameInfo& Attr(const std::string& attr_name, const std::string& onnx_attr_name,
onnx::AttributeProto_AttributeType onnx_attr_type, const GenAttrFuncType& fn_gen_attr) {
OpNameInfo &Attr(const std::string &attr_name, const std::string &onnx_attr_name,
onnx::AttributeProto_AttributeType onnx_attr_type, const GenAttrFuncType &fn_gen_attr) {
op_attrs_.emplace_back(OpAttrInfo(attr_name, onnx_attr_name, onnx_attr_type, fn_gen_attr));
return *this;
}
const std::vector<OpAttrInfo>& op_attrs() const { return op_attrs_; }
const std::vector<OpAttrInfo> &op_attrs() const { return op_attrs_; }
private:
std::string op_type_; // operator type of MindSpore
@ -183,8 +183,8 @@ OPERATOR_ONNX_CONVERT_DEFINE(
.Attr("group", "group", onnx::AttributeProto_AttributeType_INT, SetAttrValueToProto<Int32Imm>)
.Attr("kernel_size", "kernel_shape", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<0>)
.Attr("pad_mode", "auto_pad", onnx::AttributeProto_AttributeType_STRING,
[](ValuePtr value, onnx::AttributeProto_AttributeType, onnx::AttributeProto* const attr_proto,
const PrimitivePtr& prim) {
[](ValuePtr value, onnx::AttributeProto_AttributeType, onnx::AttributeProto *const attr_proto,
const PrimitivePtr &prim) {
attr_proto->set_type(onnx::AttributeProto_AttributeType_STRING);
auto attr_value = GetValue<std::string>(value);
if (attr_value == "valid") {
@ -220,7 +220,7 @@ OPERATOR_ONNX_CONVERT_DEFINE(Argmax, ArgMax,
SetAttrValueToProto<Int32Imm>)
.Attr("", "keepdims", onnx::AttributeProto_AttributeType_INT,
[](ValuePtr, onnx::AttributeProto_AttributeType,
onnx::AttributeProto* const attr_proto, const PrimitivePtr&) {
onnx::AttributeProto *const attr_proto, const PrimitivePtr &) {
attr_proto->set_type(onnx::AttributeProto_AttributeType_INT);
attr_proto->set_i(0);
}))
@ -242,7 +242,7 @@ OPERATOR_ONNX_CONVERT_DEFINE(
#define OP_CONVERT_FUNCTION_NAME(name) GetOpOnnxConvertInfo_##name
void RegisterOpConverters(const std::function<void(OpNameInfo&&)>& fn) {
void RegisterOpConverters(const std::function<void(OpNameInfo &&)> &fn) {
fn(OP_CONVERT_FUNCTION_NAME(TensorAdd)());
fn(OP_CONVERT_FUNCTION_NAME(Mul)());
@ -265,16 +265,16 @@ class OpConvertRegistry {
public:
~OpConvertRegistry() { Clear(); }
static void RegisterOneOpConverter(OpNameInfo&& op_info) { GetSingleton().op_map_[op_info.op_type()] = op_info; }
static void RegisterOneOpConverter(OpNameInfo &&op_info) { GetSingleton().op_map_[op_info.op_type()] = op_info; }
static void RegisterAllOpConverters() { RegisterOpConverters(RegisterOneOpConverter); }
static OpConvertRegistry& GetSingleton() {
static OpConvertRegistry &GetSingleton() {
static OpConvertRegistry registry = OpConvertRegistry();
return registry;
}
static const std::unordered_map<std::string, OpNameInfo>& GetOpConvertMap() { return GetSingleton().op_map_; }
static const std::unordered_map<std::string, OpNameInfo> &GetOpConvertMap() { return GetSingleton().op_map_; }
void Clear() noexcept { op_map_.clear(); }
@ -289,59 +289,59 @@ class OnnxExporter {
OnnxExporter() {}
~OnnxExporter() {}
std::string GetOnnxProtoString(const FuncGraphPtr& func_graph);
std::string GetOnnxProtoString(const FuncGraphPtr &func_graph);
private:
void InitModelInfo();
void ExportFuncGraph(const FuncGraphPtr& func_graph, onnx::GraphProto* graph_proto);
void ExportParameters(const FuncGraphPtr& func_graph, onnx::GraphProto* graph_proto);
void ExportFuncGraph(const FuncGraphPtr &func_graph, onnx::GraphProto *graph_proto);
void ExportParameters(const FuncGraphPtr &func_graph, onnx::GraphProto *graph_proto);
size_t ExportPrimitive(const FuncGraphPtr& func_graph, std::map<AnfNodePtr, size_t>* node_map_ptr,
const PrimitivePtr& prim, const std::vector<AnfNodePtr>& inputs,
onnx::GraphProto* graph_proto);
size_t ExportPrimitive(const FuncGraphPtr &func_graph, std::map<AnfNodePtr, size_t> *node_map_ptr,
const PrimitivePtr &prim, const std::vector<AnfNodePtr> &inputs,
onnx::GraphProto *graph_proto);
static onnx::TensorProto_DataType GetOnnxDataType(TypeId type_id);
void SetValueInfoType(const AnfNodePtr& node, onnx::ValueInfoProto* value_proto, bool is_output = false);
void SetTensorProtoInfo(const ParameterPtr& param, onnx::TensorProto* tensor_proto);
void SetValueInfoType(const AnfNodePtr &node, onnx::ValueInfoProto *value_proto, bool is_output = false);
void SetTensorProtoInfo(const ParameterPtr &param, onnx::TensorProto *tensor_proto);
void MatchAndMark(const FuncGraphPtr& func_graph, const std::vector<AnfNodePtr>& nodes,
std::unordered_map<AnfNodePtr, OpMergedInfo>* op_merged_infos_ptr);
void ExportNodes(const FuncGraphPtr& func_graph, std::map<AnfNodePtr, size_t>* node_map_ptr,
onnx::GraphProto* graph_proto);
void MatchAndMark(const FuncGraphPtr &func_graph, const std::vector<AnfNodePtr> &nodes,
std::unordered_map<AnfNodePtr, OpMergedInfo> *op_merged_infos_ptr);
void ExportNodes(const FuncGraphPtr &func_graph, std::map<AnfNodePtr, size_t> *node_map_ptr,
onnx::GraphProto *graph_proto);
void ExportCNode(const FuncGraphPtr& func_graph, const CNodePtr& node, std::map<AnfNodePtr, size_t>* node_map_ptr,
onnx::GraphProto* graph_proto);
void ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map<AnfNodePtr, size_t> *node_map_ptr,
onnx::GraphProto *graph_proto);
void ExportPrimReshape(const FuncGraphPtr& func_graph, const CNodePtr& node,
std::map<AnfNodePtr, size_t>* node_map_ptr, onnx::GraphProto* graph_proto);
void ExportPrimReduceMean(const FuncGraphPtr& func_graph, const CNodePtr& node,
std::map<AnfNodePtr, size_t>* node_map_ptr, onnx::GraphProto* graph_proto);
void ExportPrimCast(const FuncGraphPtr& func_graph, const CNodePtr& node, std::map<AnfNodePtr, size_t>* node_map_ptr,
onnx::GraphProto* graph_proto);
void ExportPrimPReLU(const FuncGraphPtr& func_graph, const CNodePtr& node, std::map<AnfNodePtr, size_t>* node_map_ptr,
onnx::GraphProto* graph_proto);
void ExportPrimReshape(const FuncGraphPtr &func_graph, const CNodePtr &node,
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *graph_proto);
void ExportPrimReduceMean(const FuncGraphPtr &func_graph, const CNodePtr &node,
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *graph_proto);
void ExportPrimCast(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map<AnfNodePtr, size_t> *node_map_ptr,
onnx::GraphProto *graph_proto);
void ExportPrimPReLU(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map<AnfNodePtr, size_t> *node_map_ptr,
onnx::GraphProto *graph_proto);
void ExportMergeConv(const FuncGraphPtr& func_graph, const CNodePtr& node, std::map<AnfNodePtr, size_t>* node_map_ptr,
onnx::GraphProto* graph_proto);
void ExportMergeGemm(const FuncGraphPtr& func_graph, const CNodePtr& node, std::map<AnfNodePtr, size_t>* node_map_ptr,
onnx::GraphProto* graph_proto);
void ExportMergeBatchNorm(const FuncGraphPtr& func_graph, const CNodePtr& node,
std::map<AnfNodePtr, size_t>* node_map_ptr, onnx::GraphProto* graph_proto);
void ExportMergeConv(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map<AnfNodePtr, size_t> *node_map_ptr,
onnx::GraphProto *graph_proto);
void ExportMergeGemm(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map<AnfNodePtr, size_t> *node_map_ptr,
onnx::GraphProto *graph_proto);
void ExportMergeBatchNorm(const FuncGraphPtr &func_graph, const CNodePtr &node,
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *graph_proto);
void ExportOutput(const FuncGraphPtr& func_graph, const CNodePtr& node, std::map<AnfNodePtr, size_t>* node_map_ptr,
onnx::GraphProto* graph_proto);
std::string GetNodeInputName(const AnfNodePtr& node, std::map<AnfNodePtr, size_t>* node_map_ptr,
onnx::GraphProto* const graph_proto);
void ExportOutput(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map<AnfNodePtr, size_t> *node_map_ptr,
onnx::GraphProto *graph_proto);
std::string GetNodeInputName(const AnfNodePtr &node, std::map<AnfNodePtr, size_t> *node_map_ptr,
onnx::GraphProto *const graph_proto);
void ConvertTupleToTensor(const ValuePtr& value, onnx::TensorProto* tensor_proto);
void SetNodeAttribute(const ValuePtr& value, onnx::NodeProto* node_proto);
void ConvertTupleToTensor(const ValuePtr &value, onnx::TensorProto *tensor_proto);
void SetNodeAttribute(const ValuePtr &value, onnx::NodeProto *node_proto);
size_t AllocateNodeIndex() { return ++onnx_node_index_; }
void ResetNodeIndex() { onnx_node_index_ = 0; }
static int GetInt32Value(const AnfNodePtr& node) {
static int GetInt32Value(const AnfNodePtr &node) {
auto value_node_ptr = dyn_cast<ValueNode>(node);
MS_EXCEPTION_IF_NULL(value_node_ptr);
return GetValue<int>(value_node_ptr->value());
@ -352,7 +352,7 @@ class OnnxExporter {
size_t onnx_node_index_ = 0;
};
std::string OnnxExporter::GetOnnxProtoString(const FuncGraphPtr& func_graph) {
std::string OnnxExporter::GetOnnxProtoString(const FuncGraphPtr &func_graph) {
if (func_graph == nullptr) {
return "";
}
@ -360,7 +360,7 @@ std::string OnnxExporter::GetOnnxProtoString(const FuncGraphPtr& func_graph) {
OpConvertRegistry::GetSingleton().Clear();
OpConvertRegistry::RegisterAllOpConverters();
InitModelInfo();
onnx::GraphProto* graph_proto = model_.mutable_graph();
onnx::GraphProto *graph_proto = model_.mutable_graph();
ExportFuncGraph(func_graph, graph_proto);
return model_.SerializeAsString();
}
@ -369,11 +369,11 @@ void OnnxExporter::InitModelInfo() {
model_.set_ir_version(onnx::IR_VERSION_2019_1_22);
model_.set_producer_name("MindSpore");
model_.set_producer_version("1.0");
onnx::OperatorSetIdProto* opset_proto = model_.add_opset_import();
onnx::OperatorSetIdProto *opset_proto = model_.add_opset_import();
opset_proto->set_version(9);
}
void OnnxExporter::ExportFuncGraph(const FuncGraphPtr& func_graph, onnx::GraphProto* const graph_proto) {
void OnnxExporter::ExportFuncGraph(const FuncGraphPtr &func_graph, onnx::GraphProto *const graph_proto) {
std::map<AnfNodePtr, size_t> node_map;
onnx_node_index_ = func_graph->parameters().size();
@ -390,14 +390,14 @@ void OnnxExporter::ExportFuncGraph(const FuncGraphPtr& func_graph, onnx::GraphPr
ExportNodes(func_graph, &node_map, graph_proto);
}
void OnnxExporter::ExportParameters(const FuncGraphPtr& func_graph, onnx::GraphProto* const graph_proto) {
for (auto& param : func_graph->parameters()) {
void OnnxExporter::ExportParameters(const FuncGraphPtr &func_graph, onnx::GraphProto *const graph_proto) {
for (auto &param : func_graph->parameters()) {
const ParameterPtr param_ptr = dyn_cast<Parameter>(param);
if (param_ptr == nullptr) {
MS_LOG(EXCEPTION) << "Parameter '" << param->ToString() << "' could not cast to parameter.";
}
onnx::ValueInfoProto* input_proto = graph_proto->add_input();
onnx::ValueInfoProto *input_proto = graph_proto->add_input();
input_proto->set_name(param_ptr->ToString());
SetValueInfoType(param_ptr, input_proto);
@ -405,7 +405,7 @@ void OnnxExporter::ExportParameters(const FuncGraphPtr& func_graph, onnx::GraphP
continue;
}
// parameter with default value is an ONNX initializer
onnx::TensorProto* initializer_proto = graph_proto->add_initializer();
onnx::TensorProto *initializer_proto = graph_proto->add_initializer();
initializer_proto->set_name(param_ptr->ToString());
SetTensorProtoInfo(param_ptr, initializer_proto);
// set value for initializer
@ -445,25 +445,25 @@ onnx::TensorProto_DataType OnnxExporter::GetOnnxDataType(TypeId type_id) {
return iter->second;
}
void OnnxExporter::SetValueInfoType(const AnfNodePtr& node, onnx::ValueInfoProto* const value_proto, bool is_output) {
void OnnxExporter::SetValueInfoType(const AnfNodePtr &node, onnx::ValueInfoProto *const value_proto, bool is_output) {
auto dtype = node->Type();
auto shape = node->Shape();
onnx::TypeProto* type_proto = value_proto->mutable_type();
onnx::TypeProto *type_proto = value_proto->mutable_type();
if (dtype->isa<TensorType>() && shape->isa<abstract::Shape>()) {
auto tensor = dyn_cast<TensorType>(dtype);
auto elem_type = tensor->element();
const auto& dims = dyn_cast<abstract::Shape>(shape)->shape();
const auto &dims = dyn_cast<abstract::Shape>(shape)->shape();
// output type of 'Argmax' of MindSpore is int32, output type of 'ArgMax' of ONNX is int64
auto type = is_output ? onnx::TensorProto_DataType_INT64 : GetOnnxDataType(elem_type->type_id());
type_proto->mutable_tensor_type()->set_elem_type(type);
for (const auto& dim : dims) {
for (const auto &dim : dims) {
type_proto->mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(dim);
}
}
}
void OnnxExporter::SetTensorProtoInfo(const ParameterPtr& param, onnx::TensorProto* const tensor_proto) {
void OnnxExporter::SetTensorProtoInfo(const ParameterPtr &param, onnx::TensorProto *const tensor_proto) {
auto dtype = param->Type();
auto shape = param->Shape();
if (!dtype->isa<TensorType>() || !shape->isa<abstract::Shape>()) {
@ -472,18 +472,18 @@ void OnnxExporter::SetTensorProtoInfo(const ParameterPtr& param, onnx::TensorPro
auto tensor = dyn_cast<TensorType>(dtype);
auto elem_type = tensor->element();
const auto& dims = dyn_cast<abstract::Shape>(shape)->shape();
const auto &dims = dyn_cast<abstract::Shape>(shape)->shape();
tensor_proto->set_data_type(GetOnnxDataType(elem_type->type_id()));
for (const auto& dim : dims) {
for (const auto &dim : dims) {
tensor_proto->add_dims(dim);
}
}
void OnnxExporter::MatchAndMark(const FuncGraphPtr& func_graph, const std::vector<AnfNodePtr>& nodes,
std::unordered_map<AnfNodePtr, OpMergedInfo>* op_merged_infos_ptr) {
std::unordered_map<AnfNodePtr, OpMergedInfo>& op_merged_infos = *op_merged_infos_ptr;
void OnnxExporter::MatchAndMark(const FuncGraphPtr &func_graph, const std::vector<AnfNodePtr> &nodes,
std::unordered_map<AnfNodePtr, OpMergedInfo> *op_merged_infos_ptr) {
std::unordered_map<AnfNodePtr, OpMergedInfo> &op_merged_infos = *op_merged_infos_ptr;
for (auto& node : nodes) {
for (auto &node : nodes) {
if (!node->isa<CNode>()) {
continue;
}
@ -492,7 +492,7 @@ void OnnxExporter::MatchAndMark(const FuncGraphPtr& func_graph, const std::vecto
// if the key `input` does not exist, just create a new one
op_merged_infos[cnode].referred_count += 1;
}
for (auto& input : cnode->inputs()) {
for (auto &input : cnode->inputs()) {
if (!input->isa<CNode>()) {
continue;
}
@ -527,14 +527,14 @@ void OnnxExporter::MatchAndMark(const FuncGraphPtr& func_graph, const std::vecto
* | +-- Parameter
* | `-- ValueNode
*/
void OnnxExporter::ExportNodes(const FuncGraphPtr& func_graph, std::map<AnfNodePtr, size_t>* node_map_ptr,
onnx::GraphProto* const graph_proto) {
void OnnxExporter::ExportNodes(const FuncGraphPtr &func_graph, std::map<AnfNodePtr, size_t> *node_map_ptr,
onnx::GraphProto *const graph_proto) {
std::vector<AnfNodePtr> nodes = TopoSort(func_graph->get_return(), SuccIncoming, AlwaysInclude);
std::unordered_map<AnfNodePtr, OpMergedInfo> op_merged_infos;
MatchAndMark(func_graph, nodes, &op_merged_infos);
for (const AnfNodePtr& node : nodes) {
for (const AnfNodePtr &node : nodes) {
if (!node->isa<CNode>()) {
continue;
}
@ -570,20 +570,20 @@ void OnnxExporter::ExportNodes(const FuncGraphPtr& func_graph, std::map<AnfNodeP
}
}
void OnnxExporter::ExportPrimReshape(const FuncGraphPtr& /*func_graph*/, const CNodePtr& node,
std::map<AnfNodePtr, size_t>* node_map_ptr, onnx::GraphProto* const graph_proto) {
void OnnxExporter::ExportPrimReshape(const FuncGraphPtr & /*func_graph*/, const CNodePtr &node,
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *const graph_proto) {
auto name_x = GetNodeInputName(node->input(1), node_map_ptr, graph_proto);
auto input_shape = node->input(2);
std::string name_shape;
if (input_shape->isa<ValueNode>()) {
auto const_node_idx = AllocateNodeIndex();
(*node_map_ptr)[input_shape] = const_node_idx;
onnx::NodeProto* node_proto = graph_proto->add_node();
onnx::NodeProto *node_proto = graph_proto->add_node();
name_shape = std::to_string(const_node_idx);
node_proto->add_output(name_shape);
node_proto->set_op_type("Constant");
onnx::AttributeProto* attr_proto = node_proto->add_attribute();
onnx::AttributeProto *attr_proto = node_proto->add_attribute();
attr_proto->set_name("value");
attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR);
@ -595,28 +595,28 @@ void OnnxExporter::ExportPrimReshape(const FuncGraphPtr& /*func_graph*/, const C
auto node_idx = AllocateNodeIndex();
(*node_map_ptr)[node] = node_idx;
onnx::NodeProto* node_proto = graph_proto->add_node();
onnx::NodeProto *node_proto = graph_proto->add_node();
node_proto->set_op_type(prim::kPrimReshape->name());
node_proto->add_output(std::to_string(node_idx));
node_proto->add_input(name_x);
node_proto->add_input(name_shape);
}
void OnnxExporter::ExportPrimReduceMean(const FuncGraphPtr& /*func_graph*/, const CNodePtr& node,
std::map<AnfNodePtr, size_t>* node_map_ptr,
onnx::GraphProto* const graph_proto) {
void OnnxExporter::ExportPrimReduceMean(const FuncGraphPtr & /*func_graph*/, const CNodePtr &node,
std::map<AnfNodePtr, size_t> *node_map_ptr,
onnx::GraphProto *const graph_proto) {
auto input_data = GetNodeInputName(node->input(1), node_map_ptr, graph_proto);
auto input_axis = node->input(2);
auto node_idx = AllocateNodeIndex();
(*node_map_ptr)[node] = node_idx;
onnx::NodeProto* node_proto = graph_proto->add_node();
onnx::NodeProto *node_proto = graph_proto->add_node();
node_proto->set_op_type(prim::kPrimReduceMean->name());
node_proto->add_output(std::to_string(node_idx));
node_proto->add_input(input_data);
if (input_axis->isa<ValueNode>()) {
onnx::AttributeProto* attr_proto = node_proto->add_attribute();
onnx::AttributeProto *attr_proto = node_proto->add_attribute();
attr_proto->set_name("axes");
attr_proto->set_type(onnx::AttributeProto_AttributeType_INTS);
auto axis_value = dyn_cast<ValueNode>(input_axis)->value();
@ -630,20 +630,20 @@ void OnnxExporter::ExportPrimReduceMean(const FuncGraphPtr& /*func_graph*/, cons
}
}
void OnnxExporter::ExportPrimCast(const FuncGraphPtr& /*func_graph*/, const CNodePtr& node,
std::map<AnfNodePtr, size_t>* node_map_ptr, onnx::GraphProto* const graph_proto) {
void OnnxExporter::ExportPrimCast(const FuncGraphPtr & /*func_graph*/, const CNodePtr &node,
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *const graph_proto) {
auto input_data = GetNodeInputName(node->input(1), node_map_ptr, graph_proto);
auto input_type = node->input(2);
auto node_idx = AllocateNodeIndex();
(*node_map_ptr)[node] = node_idx;
onnx::NodeProto* node_proto = graph_proto->add_node();
onnx::NodeProto *node_proto = graph_proto->add_node();
node_proto->set_op_type(prim::kPrimCast->name());
node_proto->add_output(std::to_string(node_idx));
node_proto->add_input(input_data);
if (input_type->isa<ValueNode>()) {
onnx::AttributeProto* attr_proto = node_proto->add_attribute();
onnx::AttributeProto *attr_proto = node_proto->add_attribute();
attr_proto->set_name("to");
attr_proto->set_type(onnx::AttributeProto_AttributeType_INT);
auto type_value = dyn_cast<ValueNode>(input_type)->value();
@ -655,8 +655,8 @@ void OnnxExporter::ExportPrimCast(const FuncGraphPtr& /*func_graph*/, const CNod
}
}
void OnnxExporter::ExportPrimPReLU(const FuncGraphPtr& /*func_graph*/, const CNodePtr& node,
std::map<AnfNodePtr, size_t>* node_map_ptr, onnx::GraphProto* const graph_proto) {
void OnnxExporter::ExportPrimPReLU(const FuncGraphPtr & /*func_graph*/, const CNodePtr &node,
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *const graph_proto) {
auto input_x = GetNodeInputName(node->input(1), node_map_ptr, graph_proto);
auto input_slope = GetNodeInputName(node->input(2), node_map_ptr, graph_proto);
@ -668,11 +668,11 @@ void OnnxExporter::ExportPrimPReLU(const FuncGraphPtr& /*func_graph*/, const CNo
// format of x is NCHW, input format is NCHW, if length of input_slope is 1, insert Unsqueeze [1,2]
if (x_shape->shape().size() == 4 && slope_shape->shape().size() == 1) {
auto node_idx = AllocateNodeIndex();
onnx::NodeProto* node_proto = graph_proto->add_node();
onnx::NodeProto *node_proto = graph_proto->add_node();
node_proto->set_op_type("Unsqueeze");
node_proto->add_output(std::to_string(node_idx));
onnx::AttributeProto* attr_proto = node_proto->add_attribute();
onnx::AttributeProto *attr_proto = node_proto->add_attribute();
attr_proto->set_type(onnx::AttributeProto_AttributeType_INTS);
attr_proto->set_name("axes");
attr_proto->add_ints(1);
@ -684,15 +684,15 @@ void OnnxExporter::ExportPrimPReLU(const FuncGraphPtr& /*func_graph*/, const CNo
auto node_idx = AllocateNodeIndex();
(*node_map_ptr)[node] = node_idx;
onnx::NodeProto* node_proto = graph_proto->add_node();
onnx::NodeProto *node_proto = graph_proto->add_node();
node_proto->set_op_type("PRelu");
node_proto->add_output(std::to_string(node_idx));
node_proto->add_input(input_x);
node_proto->add_input(input_slope);
}
void OnnxExporter::ExportCNode(const FuncGraphPtr& func_graph, const CNodePtr& node,
std::map<AnfNodePtr, size_t>* node_map_ptr, onnx::GraphProto* const graph_proto) {
void OnnxExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &node,
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *const graph_proto) {
// Type of the 2nd input of 'Reshape' of MindSpore is tuple, but ONNX's is tensor, need to do some convert
if (node->IsApply(prim::kPrimReshape)) {
return ExportPrimReshape(func_graph, node, node_map_ptr, graph_proto);
@ -735,31 +735,31 @@ void OnnxExporter::ExportCNode(const FuncGraphPtr& func_graph, const CNodePtr& n
(*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim, op_inputs, graph_proto);
}
size_t OnnxExporter::ExportPrimitive(const FuncGraphPtr& /*func_graph*/, std::map<AnfNodePtr, size_t>* node_map_ptr,
const PrimitivePtr& prim, const std::vector<AnfNodePtr>& inputs,
onnx::GraphProto* const graph_proto) {
size_t OnnxExporter::ExportPrimitive(const FuncGraphPtr & /*func_graph*/, std::map<AnfNodePtr, size_t> *node_map_ptr,
const PrimitivePtr &prim, const std::vector<AnfNodePtr> &inputs,
onnx::GraphProto *const graph_proto) {
auto op_map = OpConvertRegistry::GetOpConvertMap();
auto op_iter = op_map.find(prim->name());
if (op_iter == op_map.end()) {
MS_LOG(EXCEPTION) << "Can not find key " << prim->name() << " in convert map";
}
const OpNameInfo& op_convert_info = op_iter->second;
const OpNameInfo &op_convert_info = op_iter->second;
auto node_idx = AllocateNodeIndex();
onnx::NodeProto* node_proto = graph_proto->add_node();
onnx::NodeProto *node_proto = graph_proto->add_node();
node_proto->add_output(std::to_string(node_idx));
node_proto->set_op_type(op_convert_info.onnx_type());
// Set inputs
for (const auto& input : inputs) {
for (const auto &input : inputs) {
auto input_name = GetNodeInputName(input, node_map_ptr, graph_proto);
node_proto->add_input(input_name);
}
// Set node attribute
for (const OpAttrInfo& attr : op_convert_info.op_attrs()) {
const std::string& attr_name = attr.attr_name();
for (const OpAttrInfo &attr : op_convert_info.op_attrs()) {
const std::string &attr_name = attr.attr_name();
ValuePtr attr_value = nullptr;
if (!attr_name.empty()) {
attr_value = prim->GetAttr(attr_name);
@ -767,15 +767,15 @@ size_t OnnxExporter::ExportPrimitive(const FuncGraphPtr& /*func_graph*/, std::ma
MS_LOG(EXCEPTION) << "Primitive " << prim->name() << " does not have attribute " << attr_name;
}
}
onnx::AttributeProto* onnx_attr_proto = node_proto->add_attribute();
onnx::AttributeProto *onnx_attr_proto = node_proto->add_attribute();
onnx_attr_proto->set_name(attr.onnx_attr_name());
attr.fn_gen_attr()(attr_value, attr.onnx_attr_type(), onnx_attr_proto, prim);
}
return node_idx;
}
void OnnxExporter::ExportMergeConv(const FuncGraphPtr& func_graph, const CNodePtr& node,
std::map<AnfNodePtr, size_t>* node_map_ptr, onnx::GraphProto* const graph_proto) {
void OnnxExporter::ExportMergeConv(const FuncGraphPtr &func_graph, const CNodePtr &node,
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *const graph_proto) {
auto conv_node = dyn_cast<CNode>(node->input(1));
auto input_x = conv_node->input(1); // conv input x
auto input_w = conv_node->input(2); // conv weight(filter)
@ -786,8 +786,8 @@ void OnnxExporter::ExportMergeConv(const FuncGraphPtr& func_graph, const CNodePt
(*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim_conv, inputs, graph_proto);
}
void OnnxExporter::ExportMergeGemm(const FuncGraphPtr& func_graph, const CNodePtr& node,
std::map<AnfNodePtr, size_t>* node_map_ptr, onnx::GraphProto* const graph_proto) {
void OnnxExporter::ExportMergeGemm(const FuncGraphPtr &func_graph, const CNodePtr &node,
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *const graph_proto) {
auto matmul_node = dyn_cast<CNode>(node->input(1));
auto input_x = matmul_node->input(1); // matmul input x
auto input_y = matmul_node->input(2); // matmul input y
@ -798,9 +798,9 @@ void OnnxExporter::ExportMergeGemm(const FuncGraphPtr& func_graph, const CNodePt
(*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim_matmul, inputs, graph_proto);
}
void OnnxExporter::ExportMergeBatchNorm(const FuncGraphPtr& func_graph, const CNodePtr& node,
std::map<AnfNodePtr, size_t>* node_map_ptr,
onnx::GraphProto* const graph_proto) {
void OnnxExporter::ExportMergeBatchNorm(const FuncGraphPtr &func_graph, const CNodePtr &node,
std::map<AnfNodePtr, size_t> *node_map_ptr,
onnx::GraphProto *const graph_proto) {
auto batch_norm_node = dyn_cast<CNode>(node->input(1));
PrimitivePtr prim_batch_norm = dyn_cast<Primitive>((dyn_cast<ValueNode>(batch_norm_node->input(0)))->value());
@ -811,20 +811,20 @@ void OnnxExporter::ExportMergeBatchNorm(const FuncGraphPtr& func_graph, const CN
(*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim_batch_norm, inputs, graph_proto);
}
void OnnxExporter::ExportOutput(const FuncGraphPtr& /*func_graph*/, const CNodePtr& node,
std::map<AnfNodePtr, size_t>* node_map_ptr, onnx::GraphProto* const graph_proto) {
void OnnxExporter::ExportOutput(const FuncGraphPtr & /*func_graph*/, const CNodePtr &node,
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *const graph_proto) {
if (node->inputs().size() != 2) {
MS_LOG(EXCEPTION) << "Number of inputs of return node is not equal to 2.";
}
AnfNodePtr arg = node->input(1);
std::string name = GetNodeInputName(arg, node_map_ptr, graph_proto);
onnx::ValueInfoProto* output_proto = graph_proto->add_output();
onnx::ValueInfoProto *output_proto = graph_proto->add_output();
output_proto->set_name(name);
SetValueInfoType(arg, output_proto, false);
}
std::string OnnxExporter::GetNodeInputName(const AnfNodePtr& node, std::map<AnfNodePtr, size_t>* node_map_ptr,
onnx::GraphProto* const graph_proto) {
std::string OnnxExporter::GetNodeInputName(const AnfNodePtr &node, std::map<AnfNodePtr, size_t> *node_map_ptr,
onnx::GraphProto *const graph_proto) {
if (node->isa<CNode>()) {
auto iter = node_map_ptr->find(node);
if (iter == node_map_ptr->end()) {
@ -848,7 +848,7 @@ std::string OnnxExporter::GetNodeInputName(const AnfNodePtr& node, std::map<AnfN
(*node_map_ptr)[node] = node_idx;
std::string node_name = std::to_string(node_idx);
onnx::NodeProto* node_proto = graph_proto->add_node();
onnx::NodeProto *node_proto = graph_proto->add_node();
node_proto->add_output(node_name);
SetNodeAttribute(node->cast<ValueNodePtr>()->value(), node_proto);
@ -859,7 +859,7 @@ std::string OnnxExporter::GetNodeInputName(const AnfNodePtr& node, std::map<AnfN
MS_LOG(EXCEPTION) << "Unexpected node type " << node->type_name();
}
void OnnxExporter::ConvertTupleToTensor(const ValuePtr& value, onnx::TensorProto* const tensor_proto) {
void OnnxExporter::ConvertTupleToTensor(const ValuePtr &value, onnx::TensorProto *const tensor_proto) {
auto tuple_ptr = dyn_cast<ValueTuple>(value);
MS_EXCEPTION_IF_NULL(tuple_ptr);
if (tuple_ptr->size() == 0) {
@ -891,14 +891,14 @@ void OnnxExporter::ConvertTupleToTensor(const ValuePtr& value, onnx::TensorProto
}
}
void OnnxExporter::SetNodeAttribute(const ValuePtr& value, onnx::NodeProto* const node_proto) {
void OnnxExporter::SetNodeAttribute(const ValuePtr &value, onnx::NodeProto *const node_proto) {
node_proto->set_op_type("Constant");
onnx::AttributeProto* attr_proto = node_proto->add_attribute();
onnx::AttributeProto *attr_proto = node_proto->add_attribute();
attr_proto->set_name("value");
MS_LOG(EXCEPTION) << "Need to set value " << value->ToString() << " attribute for Constant node";
}
std::string GetOnnxProtoString(const FuncGraphPtr& func_graph) {
std::string GetOnnxProtoString(const FuncGraphPtr &func_graph) {
OnnxExporter exporter;
return exporter.GetOnnxProtoString(func_graph);
}

View File

@ -32,12 +32,12 @@ enum class DataType { kInt, kFloat, kDouble, kUnknown };
// Whether has a T type data in AnyPtrList.
template <class T>
bool HasType(const AnyPtrList& list) {
bool ret = std::any_of(list.begin(), list.end(), [](const AnyPtr& ptr) { return ptr->is<T>(); });
bool HasType(const AnyPtrList &list) {
bool ret = std::any_of(list.begin(), list.end(), [](const AnyPtr &ptr) { return ptr->is<T>(); });
return ret;
}
DataType InferType(const AnyPtrList& list) {
DataType InferType(const AnyPtrList &list) {
if (HasType<double>(list)) {
return DataType::kDouble;
} else if (HasType<float>(list)) {
@ -180,7 +180,7 @@ bool InnerScalarGe(T x, U y) {
}
#define SCALAR_OP(op_t) \
ValuePtr Scalar##op_t(const ValuePtrList& list) { \
ValuePtr Scalar##op_t(const ValuePtrList &list) { \
do { \
if (list.size() < 2) { \
MS_LOG(EXCEPTION) << "length of input list for Scalar" << #op_t << " is less than 2."; \
@ -223,7 +223,7 @@ SCALAR_OP(Pow)
SCALAR_OP(Floordiv)
#define LOGIC_OP(op_t) \
ValuePtr Scalar##op_t(const ValuePtrList& list) { \
ValuePtr Scalar##op_t(const ValuePtrList &list) { \
if (list.size() < 2) { \
MS_LOG(EXCEPTION) << "length of input list for Scalar" << #op_t << " is less than 2."; \
} \
@ -274,7 +274,7 @@ LOGIC_OP(Ne)
LOGIC_OP(Le)
LOGIC_OP(Ge)
ValuePtr ScalarUAdd(const ValuePtrList& list) {
ValuePtr ScalarUAdd(const ValuePtrList &list) {
if (list.size() != 1) {
MS_LOG(EXCEPTION) << "Input number of ScalarUAdd should be 1, but got " << list.size();
}
@ -283,7 +283,7 @@ ValuePtr ScalarUAdd(const ValuePtrList& list) {
return x;
}
ValuePtr ScalarUSub(const ValuePtrList& list) {
ValuePtr ScalarUSub(const ValuePtrList &list) {
if (list.size() != 1) {
MS_LOG(EXCEPTION) << "Input number of ScalarUSub should be 1, but got " << list.size();
}
@ -302,7 +302,7 @@ ValuePtr ScalarUSub(const ValuePtrList& list) {
MS_LOG(EXCEPTION) << "Unsported Value for ScalarUSub, x: " << x->ToString() << ".";
}
ValuePtr ScalarLog(const ValuePtrList& list) {
ValuePtr ScalarLog(const ValuePtrList &list) {
if (list.empty()) {
MS_LOG(EXCEPTION) << "Input list of ScalarLog is empty.";
}
@ -321,7 +321,7 @@ ValuePtr ScalarLog(const ValuePtrList& list) {
MS_LOG(EXCEPTION) << "Unsported Value for ScalarLog, x: " << x->ToString();
}
ValuePtr BoolNot(const ValuePtrList& list) {
ValuePtr BoolNot(const ValuePtrList &list) {
if (list.empty()) {
MS_LOG(EXCEPTION) << "value list of BoolNot is empty";
}
@ -337,7 +337,7 @@ ValuePtr BoolNot(const ValuePtrList& list) {
MS_LOG(EXCEPTION) << "Unsported Value for BoolNot, x: " << x->ToString();
}
ValuePtr BoolAnd(const ValuePtrList& list) {
ValuePtr BoolAnd(const ValuePtrList &list) {
if (list.size() < 2) {
MS_LOG(EXCEPTION) << "Input number " << list.size() << " of BoolAnd is less then 2.";
}
@ -356,7 +356,7 @@ ValuePtr BoolAnd(const ValuePtrList& list) {
MS_LOG(EXCEPTION) << "Unsported Value for BoolAnd, x: " << x->ToString() << ".";
}
ValuePtr BoolOr(const ValuePtrList& list) {
ValuePtr BoolOr(const ValuePtrList &list) {
if (list.size() < 2) {
MS_LOG(EXCEPTION) << "Input number " << list.size() << " of BoolOr is less then 2.";
}
@ -375,7 +375,7 @@ ValuePtr BoolOr(const ValuePtrList& list) {
MS_LOG(EXCEPTION) << "Unsported Value for BoolOr, x: " << x->ToString() << ".";
}
ValuePtr BoolEq(const ValuePtrList& list) {
ValuePtr BoolEq(const ValuePtrList &list) {
if (list.size() < 2) {
MS_LOG(EXCEPTION) << "Input number " << list.size() << " of BoolEq is less than 2.";
}

View File

@ -29,29 +29,29 @@ namespace prim {
using Any = mindspore::Any;
using AnyPtrList = std::vector<std::shared_ptr<Any>>;
using ValuePtrList = std::vector<ValuePtr>;
using OpsFunction = std::function<Any(const AnyPtrList&)>;
using AnfNodeOpsFunction = std::function<AnfNodePtr(const std::vector<AnfNodePtr>&)>;
using OpsFunction = std::function<Any(const AnyPtrList &)>;
using AnfNodeOpsFunction = std::function<AnfNodePtr(const std::vector<AnfNodePtr> &)>;
ValuePtr ScalarAdd(const ValuePtrList& list);
ValuePtr ScalarSub(const ValuePtrList& list);
ValuePtr ScalarMul(const ValuePtrList& list);
ValuePtr ScalarDiv(const ValuePtrList& list);
ValuePtr ScalarMod(const ValuePtrList& list);
ValuePtr ScalarPow(const ValuePtrList& list);
ValuePtr ScalarFloordiv(const ValuePtrList& list);
ValuePtr ScalarUAdd(const ValuePtrList& list);
ValuePtr ScalarUSub(const ValuePtrList& list);
ValuePtr ScalarLog(const ValuePtrList& list);
ValuePtr ScalarEq(const ValuePtrList& list);
ValuePtr ScalarLt(const ValuePtrList& list);
ValuePtr ScalarGt(const ValuePtrList& list);
ValuePtr ScalarNe(const ValuePtrList& list);
ValuePtr ScalarLe(const ValuePtrList& list);
ValuePtr ScalarGe(const ValuePtrList& list);
ValuePtr BoolNot(const ValuePtrList& list);
ValuePtr BoolAnd(const ValuePtrList& list);
ValuePtr BoolOr(const ValuePtrList& list);
ValuePtr BoolEq(const ValuePtrList& list);
ValuePtr ScalarAdd(const ValuePtrList &list);
ValuePtr ScalarSub(const ValuePtrList &list);
ValuePtr ScalarMul(const ValuePtrList &list);
ValuePtr ScalarDiv(const ValuePtrList &list);
ValuePtr ScalarMod(const ValuePtrList &list);
ValuePtr ScalarPow(const ValuePtrList &list);
ValuePtr ScalarFloordiv(const ValuePtrList &list);
ValuePtr ScalarUAdd(const ValuePtrList &list);
ValuePtr ScalarUSub(const ValuePtrList &list);
ValuePtr ScalarLog(const ValuePtrList &list);
ValuePtr ScalarEq(const ValuePtrList &list);
ValuePtr ScalarLt(const ValuePtrList &list);
ValuePtr ScalarGt(const ValuePtrList &list);
ValuePtr ScalarNe(const ValuePtrList &list);
ValuePtr ScalarLe(const ValuePtrList &list);
ValuePtr ScalarGe(const ValuePtrList &list);
ValuePtr BoolNot(const ValuePtrList &list);
ValuePtr BoolAnd(const ValuePtrList &list);
ValuePtr BoolOr(const ValuePtrList &list);
ValuePtr BoolEq(const ValuePtrList &list);
std::vector<int> BroadcastShape_(std::vector<int> s1, std::vector<int> s2);
} // namespace prim
} // namespace mindspore

View File

@ -66,7 +66,7 @@ const MetaFuncGraphPtr kTail = std::make_shared<Tail>("tail");
// Apply a function of two arguments cumulatively to the items of a sequence,
// from left to right, so as to reduce the sequence to a single value.For example,
// reduce(lambda x, y: x + y, [ 1, 2, 3, 4, 5 ]) calculates ((((1 + 2) + 3) + 4) + 5).
AnyPtr Reduce(const OpsFunction& func, const AnyPtrList& list) {
AnyPtr Reduce(const OpsFunction &func, const AnyPtrList &list) {
std::shared_ptr<Any> ret;
size_t size = list.size();
if (size < 2) {
@ -88,7 +88,7 @@ AnyPtr Reduce(const OpsFunction& func, const AnyPtrList& list) {
return ret;
}
AnfNodePtr Reduce(const AnfNodeOpsFunction& func, const std::vector<AnfNodePtr>& list) {
AnfNodePtr Reduce(const AnfNodeOpsFunction &func, const std::vector<AnfNodePtr> &list) {
size_t size = list.size();
if (size < 2) {
MS_LOG(EXCEPTION) << "length of inputs of Reduce is less than 2";
@ -121,7 +121,7 @@ void HyperMap::Init() {
{"args", SignatureEnumRW::kRWRef, SignatureEnumKind::kKindVarPositional}});
}
HyperMap::HyperMap(const std::shared_ptr<MultitypeFuncGraph>& fn_leaf)
HyperMap::HyperMap(const std::shared_ptr<MultitypeFuncGraph> &fn_leaf)
: MetaFuncGraph("hyper_map"),
fn_leaf_(fn_leaf),
broadcast_(false),
@ -129,13 +129,13 @@ HyperMap::HyperMap(const std::shared_ptr<MultitypeFuncGraph>& fn_leaf)
Init();
}
HyperMap::HyperMap(const HyperMap& h)
HyperMap::HyperMap(const HyperMap &h)
: MetaFuncGraph("hyper_map"), fn_leaf_(h.fn_leaf_), broadcast_(h.broadcast_), nonleaf_(h.nonleaf_) {
Init();
}
AnfNodePtr HyperMap::FullMake(TypePtr, const FuncGraphPtr& func_graph, const AnfNodePtr& fn_arg,
const ArgsPairList& arg_map) {
AnfNodePtr HyperMap::FullMake(TypePtr, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg,
const ArgsPairList &arg_map) {
MS_EXCEPTION_IF_NULL(func_graph);
std::vector<AnfNodePtr> inputs;
if (fn_arg != nullptr) {
@ -145,17 +145,17 @@ AnfNodePtr HyperMap::FullMake(TypePtr, const FuncGraphPtr& func_graph, const Anf
}
(void)std::transform(arg_map.begin(), arg_map.end(), std::back_inserter(inputs),
[](const std::pair<AnfNodePtr, Any>& item) { return item.first; });
[](const std::pair<AnfNodePtr, Any> &item) { return item.first; });
return func_graph->NewCNode(inputs);
}
AnfNodePtr HyperMap::FullMake(const std::shared_ptr<List>& type, const FuncGraphPtr& func_graph,
const AnfNodePtr& fn_arg, const ArgsPairList& arg_map) {
AnfNodePtr HyperMap::FullMake(const std::shared_ptr<List> &type, const FuncGraphPtr &func_graph,
const AnfNodePtr &fn_arg, const ArgsPairList &arg_map) {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(type);
std::size_t size = type->elements().size();
bool is_not_same = std::any_of(arg_map.begin(), arg_map.end(), [size](const std::pair<AnfNodePtr, TypePtr>& item) {
bool is_not_same = std::any_of(arg_map.begin(), arg_map.end(), [size](const std::pair<AnfNodePtr, TypePtr> &item) {
auto lhs = std::static_pointer_cast<List>(item.second);
MS_EXCEPTION_IF_NULL(lhs);
return lhs->elements().size() != size;
@ -179,7 +179,7 @@ AnfNodePtr HyperMap::FullMake(const std::shared_ptr<List>& type, const FuncGraph
(void)std::transform(
arg_map.begin(), arg_map.end(), std::back_inserter(inputs2),
[&func_graph, i](const std::pair<AnfNodePtr, Any>& item) {
[&func_graph, i](const std::pair<AnfNodePtr, Any> &item) {
return func_graph->NewCNode({NewValueNode(prim::kPrimListGetItem), item.first, NewValueNode(i)});
});
@ -188,13 +188,13 @@ AnfNodePtr HyperMap::FullMake(const std::shared_ptr<List>& type, const FuncGraph
return func_graph->NewCNode(inputs);
}
AnfNodePtr HyperMap::FullMake(const std::shared_ptr<Tuple>& type, const FuncGraphPtr& func_graph,
const AnfNodePtr& fn_arg, const ArgsPairList& arg_map) {
AnfNodePtr HyperMap::FullMake(const std::shared_ptr<Tuple> &type, const FuncGraphPtr &func_graph,
const AnfNodePtr &fn_arg, const ArgsPairList &arg_map) {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(type);
std::size_t size = type->elements().size();
bool is_not_same = std::any_of(arg_map.begin(), arg_map.end(), [size](const std::pair<AnfNodePtr, TypePtr>& item) {
bool is_not_same = std::any_of(arg_map.begin(), arg_map.end(), [size](const std::pair<AnfNodePtr, TypePtr> &item) {
auto lhs = std::static_pointer_cast<Tuple>(item.second);
MS_EXCEPTION_IF_NULL(lhs);
return lhs->elements().size() != size;
@ -226,8 +226,8 @@ AnfNodePtr HyperMap::FullMake(const std::shared_ptr<Tuple>& type, const FuncGrap
return func_graph->NewCNode(inputs);
}
AnfNodePtr HyperMap::FullMake(const std::shared_ptr<Class>& type, const FuncGraphPtr& func_graph,
const AnfNodePtr& fn_arg, const ArgsPairList& arg_map) {
AnfNodePtr HyperMap::FullMake(const std::shared_ptr<Class> &type, const FuncGraphPtr &func_graph,
const AnfNodePtr &fn_arg, const ArgsPairList &arg_map) {
MS_EXCEPTION_IF_NULL(type);
MS_EXCEPTION_IF_NULL(func_graph);
@ -257,11 +257,11 @@ AnfNodePtr HyperMap::FullMake(const std::shared_ptr<Class>& type, const FuncGrap
return func_graph->NewCNode(inputs);
}
AnfNodePtr HyperMap::Make(const FuncGraphPtr& func_graph, const AnfNodePtr& fn_arg, const ArgsPairList& arg_map) {
AnfNodePtr HyperMap::Make(const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, const ArgsPairList &arg_map) {
bool found = false;
TypeId id = kObjectTypeEnd;
std::pair<AnfNodePtr, TypePtr> pair;
for (auto& item : arg_map) {
for (auto &item : arg_map) {
pair = item;
id = item.second->type_id();
if (nonleaf_.count(id)) {
@ -272,7 +272,7 @@ AnfNodePtr HyperMap::Make(const FuncGraphPtr& func_graph, const AnfNodePtr& fn_a
if (found) {
// In a nonleaf situation, all arguments must have the same generic.
bool is_not_same = std::any_of(arg_map.begin(), arg_map.end(), [pair](const std::pair<AnfNodePtr, TypePtr>& item) {
bool is_not_same = std::any_of(arg_map.begin(), arg_map.end(), [pair](const std::pair<AnfNodePtr, TypePtr> &item) {
if (item.first != pair.first) {
return item.second->type_id() != pair.second->type_id();
}
@ -283,7 +283,7 @@ AnfNodePtr HyperMap::Make(const FuncGraphPtr& func_graph, const AnfNodePtr& fn_a
oss << "There are " << arg_map.size() << " inputs of `" << name_ << "`, corresponding type info:\n"
<< trace::GetDebugInfo(func_graph->debug_info()) << "\n";
int idx = 0;
for (auto& item : arg_map) {
for (auto &item : arg_map) {
oss << ++idx << ": " << item.second->ToString() << "\n";
}
MS_LOG(EXCEPTION) << "HyperMap cannot match up all input types of arguments.\n" << oss.str();
@ -308,14 +308,14 @@ AnfNodePtr HyperMap::Make(const FuncGraphPtr& func_graph, const AnfNodePtr& fn_a
}
}
ArgsPairList HyperMap::Harmonize(const FuncGraphPtr& func_graph, const ArgsPairList& args_spec_list) {
ArgsPairList HyperMap::Harmonize(const FuncGraphPtr &func_graph, const ArgsPairList &args_spec_list) {
TypePtr type_tensor = std::make_shared<TensorType>();
bool flag = std::any_of(
args_spec_list.begin(), args_spec_list.end(),
[type_tensor](const std::pair<AnfNodePtr, TypePtr>& item) { return IsSubType(item.second, type_tensor); });
[type_tensor](const std::pair<AnfNodePtr, TypePtr> &item) { return IsSubType(item.second, type_tensor); });
if (flag && broadcast_) {
ArgsPairList ret;
for (auto& item : args_spec_list) {
for (auto &item : args_spec_list) {
if (!IsSubType(item.second, type_tensor)) {
TypePtr type_tensor_ele = std::make_shared<TensorType>(item.second);
ret.push_back(
@ -329,7 +329,7 @@ ArgsPairList HyperMap::Harmonize(const FuncGraphPtr& func_graph, const ArgsPairL
return args_spec_list;
}
FuncGraphPtr HyperMap::GenerateFromTypes(const TypePtrList& args_spec_list) {
FuncGraphPtr HyperMap::GenerateFromTypes(const TypePtrList &args_spec_list) {
FuncGraphPtr ptrGraph = std::make_shared<FuncGraph>();
ptrGraph->set_flags(FUNC_GRAPH_FLAG_CORE, true);
ptrGraph->debug_info()->set_name("hyper_map");
@ -353,7 +353,7 @@ FuncGraphPtr HyperMap::GenerateFromTypes(const TypePtrList& args_spec_list) {
return ptrGraph;
}
abstract::AbstractBasePtrList HyperMap::NormalizeArgs(const AbstractBasePtrList& args_spec_list) const {
abstract::AbstractBasePtrList HyperMap::NormalizeArgs(const AbstractBasePtrList &args_spec_list) const {
if (fn_leaf_ == nullptr) {
MS_EXCEPTION_IF_NULL(args_spec_list[0]);
// Assert that hypermap's function param does not contain free variables
@ -368,20 +368,20 @@ abstract::AbstractBasePtrList HyperMap::NormalizeArgs(const AbstractBasePtrList&
AbstractBasePtrList broadened;
(void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(broadened),
[](const AbstractBasePtr& arg) -> AbstractBasePtr {
[](const AbstractBasePtr &arg) -> AbstractBasePtr {
MS_EXCEPTION_IF_NULL(arg);
return arg->Broaden();
});
return broadened;
}
REGISTER_PYBIND_DEFINE(HyperMap_, ([](const py::module* m) {
REGISTER_PYBIND_DEFINE(HyperMap_, ([](const py::module *m) {
(void)py::class_<HyperMapPy, MetaFuncGraph, std::shared_ptr<HyperMapPy>>(*m, "HyperMap_")
.def(py::init<std::shared_ptr<MultitypeFuncGraph>>(), py::arg("leaf"))
.def(py::init<>());
}));
FuncGraphPtr Tail::GenerateTupleFuncGraph(const abstract::AbstractTuplePtr& a_tuple) {
FuncGraphPtr Tail::GenerateTupleFuncGraph(const abstract::AbstractTuplePtr &a_tuple) {
MS_EXCEPTION_IF_NULL(a_tuple);
FuncGraphPtr ret = std::make_shared<FuncGraph>();
@ -401,7 +401,7 @@ FuncGraphPtr Tail::GenerateTupleFuncGraph(const abstract::AbstractTuplePtr& a_tu
return ret;
}
FuncGraphPtr Tail::GenerateListFuncGraph(const abstract::AbstractListPtr& a_list) {
FuncGraphPtr Tail::GenerateListFuncGraph(const abstract::AbstractListPtr &a_list) {
MS_EXCEPTION_IF_NULL(a_list);
FuncGraphPtr ret = std::make_shared<FuncGraph>();
@ -421,7 +421,7 @@ FuncGraphPtr Tail::GenerateListFuncGraph(const abstract::AbstractListPtr& a_list
return ret;
}
FuncGraphPtr Tail::GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) {
FuncGraphPtr Tail::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
if (args_spec_list.size() != 1) {
MS_LOG(EXCEPTION) << "tail requires a non-empty tuple.";
}
@ -441,11 +441,11 @@ FuncGraphPtr Tail::GenerateFuncGraph(const AbstractBasePtrList& args_spec_list)
}
REGISTER_PYBIND_DEFINE(
Tail_, ([](const py::module* m) {
(void)py::class_<Tail, MetaFuncGraph, std::shared_ptr<Tail>>(*m, "Tail_").def(py::init<std::string&>());
Tail_, ([](const py::module *m) {
(void)py::class_<Tail, MetaFuncGraph, std::shared_ptr<Tail>>(*m, "Tail_").def(py::init<std::string &>());
}));
FuncGraphPtr MakeTupleGradient::GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) {
FuncGraphPtr MakeTupleGradient::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
int tuple_size = SizeToInt(args_spec_list.size());
std::ostringstream ss;
@ -486,7 +486,7 @@ FuncGraphPtr MakeTupleGradient::GenerateFuncGraph(const AbstractBasePtrList& arg
return fg;
}
GradOperation::GradOperation(const std::string& name, bool get_all, bool get_by_list, bool sens_param)
GradOperation::GradOperation(const std::string &name, bool get_all, bool get_by_list, bool sens_param)
: MetaFuncGraph(name), get_all_(get_all), get_by_list_(get_by_list), sens_param_(sens_param) {
if (get_by_list) {
signatures_ =
@ -496,8 +496,8 @@ GradOperation::GradOperation(const std::string& name, bool get_all, bool get_by_
}
}
FuncGraphPtr GradOperation::GetGrad(AnfNodePtr node, const AnfNodePtr& weights,
const std::vector<AnfNodePtr>& params_list, bool applyJ) {
FuncGraphPtr GradOperation::GetGrad(AnfNodePtr node, const AnfNodePtr &weights,
const std::vector<AnfNodePtr> &params_list, bool applyJ) {
FuncGraphPtr ret = std::make_shared<FuncGraph>();
ret->set_flags(FUNC_GRAPH_FLAG_CORE, true);
@ -537,7 +537,7 @@ FuncGraphPtr GradOperation::GetGrad(AnfNodePtr node, const AnfNodePtr& weights,
return ret;
}
void GradOperation::doGetGrad(const FuncGraphPtr& func_graph, AnfNodePtr out, AnfNodePtr ptrBprop, AnfNodePtr weights,
void GradOperation::doGetGrad(const FuncGraphPtr &func_graph, AnfNodePtr out, AnfNodePtr ptrBprop, AnfNodePtr weights,
ValueNodePtr opsTupleItem) {
MS_EXCEPTION_IF_NULL(func_graph);
@ -590,7 +590,7 @@ void GradOperation::doGetGrad(const FuncGraphPtr& func_graph, AnfNodePtr out, An
}
// Generate the graph.
FuncGraphPtr GradOperation::GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) {
FuncGraphPtr GradOperation::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
if (args_spec_list.size() < 1) {
MS_LOG(EXCEPTION) << "GenerateGraph requires at least 1 parameters, while the input size is "
<< args_spec_list.size() << ".";
@ -637,21 +637,21 @@ FuncGraphPtr GradOperation::GenerateFuncGraph(const AbstractBasePtrList& args_sp
return dfBuilder;
}
REGISTER_PYBIND_DEFINE(GradOperation_, ([](const py::module* m) {
REGISTER_PYBIND_DEFINE(GradOperation_, ([](const py::module *m) {
(void)py::class_<GradOperation, MetaFuncGraph, std::shared_ptr<GradOperation>>(
*m, "GradOperation_")
.def(py::init<std::string&>(), py::arg("fn"))
.def(py::init<std::string&, bool, bool, bool>(), py::arg("fn"), py::arg("get_all"),
.def(py::init<std::string &>(), py::arg("fn"))
.def(py::init<std::string &, bool, bool, bool>(), py::arg("fn"), py::arg("get_all"),
py::arg("get_by_list"), py::arg("sens_param"));
}));
MultitypeFuncGraph::MultitypeFuncGraph(const std::string& name) : MetaFuncGraph(name) {
MultitypeFuncGraph::MultitypeFuncGraph(const std::string &name) : MetaFuncGraph(name) {
fn_cache_.clear();
signatures_ = std::vector<Signature>({// def multitype(*args:ref):
{"args", SignatureEnumRW::kRWRef, SignatureEnumKind::kKindVarPositional}});
}
void MultitypeFuncGraph::Register(const TypePtrList& types, specialize_fn s_fn) {
void MultitypeFuncGraph::Register(const TypePtrList &types, specialize_fn s_fn) {
MS_LOG(DEBUG) << "Register type (" << ::mindspore::ToString(types) << ".";
auto fn = fn_cache_.find(types);
if (fn != fn_cache_.end()) {
@ -660,7 +660,7 @@ void MultitypeFuncGraph::Register(const TypePtrList& types, specialize_fn s_fn)
fn_cache_[types] = s_fn;
}
void MultitypeFuncGraph::Register(const TypePtrList& types, const py::function& py_fn) {
void MultitypeFuncGraph::Register(const TypePtrList &types, const py::function &py_fn) {
MS_LOG(DEBUG) << "Register type (" << ::mindspore::ToString(types) << ", " << std::string(py_fn.str()) << ").";
auto fn = fn_cache_.find(types);
if (fn != fn_cache_.end()) {
@ -669,9 +669,9 @@ void MultitypeFuncGraph::Register(const TypePtrList& types, const py::function&
fn_cache_py_[types] = py_fn;
}
void MultitypeFuncGraph::Register(const std::vector<std::string>& types_name, const py::function& py_fn) {
void MultitypeFuncGraph::Register(const std::vector<std::string> &types_name, const py::function &py_fn) {
TypePtrList types;
for (auto& type_name : types_name) {
for (auto &type_name : types_name) {
auto type_ptr = StringToType(type_name);
if (type_ptr == nullptr) {
MS_LOG(EXCEPTION) << "" << type_name << " convert from string error ";
@ -681,7 +681,7 @@ void MultitypeFuncGraph::Register(const std::vector<std::string>& types_name, co
Register(types, py_fn);
}
void MultitypeFuncGraph::PyRegister(const py::tuple& tuple, const py::function& py_fn) {
void MultitypeFuncGraph::PyRegister(const py::tuple &tuple, const py::function &py_fn) {
std::vector<std::string> types_name;
for (size_t it = 0; it < tuple.size(); ++it) {
py::object name_py = tuple[it];
@ -693,16 +693,16 @@ void MultitypeFuncGraph::PyRegister(const py::tuple& tuple, const py::function&
}
Register(types_name, py_fn);
}
static TypePtr UnwrapRef(const TypePtr& type) {
static TypePtr UnwrapRef(const TypePtr &type) {
if (type->isa<RefType>()) {
return type->cast<RefTypePtr>()->subtype();
}
return type;
}
FuncGraphPtr MultitypeFuncGraph::GenerateFromTypes(const TypePtrList& types) {
FuncGraphPtr MultitypeFuncGraph::GenerateFromTypes(const TypePtrList &types) {
bool find_fn = false;
py::function py_fn;
for (auto& item : fn_cache_py_) {
for (auto &item : fn_cache_py_) {
TypePtrList sign = item.first;
if (sign.size() != types.size()) {
continue;
@ -735,7 +735,7 @@ FuncGraphPtr MultitypeFuncGraph::GenerateFromTypes(const TypePtrList& types) {
oss << "There are " << fn_cache_py_.size() << " prototypes for overload function `" << name_
<< "`, corresponding location info:\n";
int idx = 0;
for (auto& item : fn_cache_py_) {
for (auto &item : fn_cache_py_) {
FuncGraphPtr func_graph = parse::ParsePythonCode(item.second);
if (func_graph == nullptr) {
MS_LOG(WARNING) << "Fail to parse Python code for function `" << name_ << "`.";
@ -747,15 +747,15 @@ FuncGraphPtr MultitypeFuncGraph::GenerateFromTypes(const TypePtrList& types) {
<< oss.str();
}
REGISTER_PYBIND_DEFINE(MultitypeFuncGraph_, ([](const py::module* m) {
REGISTER_PYBIND_DEFINE(MultitypeFuncGraph_, ([](const py::module *m) {
(void)py::class_<MultitypeFuncGraph, MetaFuncGraph, std::shared_ptr<MultitypeFuncGraph>>(
*m, "MultitypeFuncGraph_")
.def(py::init<std::string&>())
.def(py::init<std::string &>())
.def("register_fn", &MultitypeFuncGraph::PyRegister);
}));
// Generate the ListMap func graph.
FuncGraphPtr ListMap::GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) {
FuncGraphPtr ListMap::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
size_t args_num = args_spec_list.size();
// args: fn, list1, list2, ...
if (args_num < 2) {
@ -821,8 +821,8 @@ FuncGraphPtr ListMap::GenerateFuncGraph(const AbstractBasePtrList& args_spec_lis
return fg_ptr;
}
void ListMap::MakeCond(const std::vector<AnfNodePtr>& lists, const FuncGraphPtr& fgnext_ptr,
const FuncGraphPtr& fg_ptr) {
void ListMap::MakeCond(const std::vector<AnfNodePtr> &lists, const FuncGraphPtr &fgnext_ptr,
const FuncGraphPtr &fg_ptr) {
MS_EXCEPTION_IF_NULL(fg_ptr);
AnfNodePtr fn = fg_ptr->add_parameter();
@ -858,8 +858,8 @@ void ListMap::MakeCond(const std::vector<AnfNodePtr>& lists, const FuncGraphPtr&
fgtrue_ptr->set_output(output_cnode);
}
void ListMap::MakeNext(const std::vector<AnfNodePtr>& lists, const FuncGraphPtr& fgcond_ptr,
const FuncGraphPtr& fg_ptr) {
void ListMap::MakeNext(const std::vector<AnfNodePtr> &lists, const FuncGraphPtr &fgcond_ptr,
const FuncGraphPtr &fg_ptr) {
MS_EXCEPTION_IF_NULL(fg_ptr);
AnfNodePtr fn = fg_ptr->add_parameter();
@ -893,7 +893,7 @@ void ListMap::MakeNext(const std::vector<AnfNodePtr>& lists, const FuncGraphPtr&
fg_ptr->set_output(output_cnode);
}
FuncGraphPtr TupleAdd::GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) {
FuncGraphPtr TupleAdd::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
// args: tuple1, tuple2
abstract::CheckArgsSize("TupleAdd", args_spec_list, 2);
AbstractBasePtr abs_a = args_spec_list[0];
@ -928,7 +928,7 @@ FuncGraphPtr TupleAdd::GenerateFuncGraph(const AbstractBasePtrList& args_spec_li
return ret;
}
int GetArgScalarValue(const abstract::AbstractScalarPtr& scalar, const std::string&) {
int GetArgScalarValue(const abstract::AbstractScalarPtr &scalar, const std::string &) {
MS_EXCEPTION_IF_NULL(scalar);
return GetValue<int>(scalar->BuildValue());
}
@ -942,7 +942,7 @@ int GetPositiveIndex(int index, int length) {
return index;
}
int CheckSliceMember(const AbstractBasePtr& member, int default_value, const std::string& member_name) {
int CheckSliceMember(const AbstractBasePtr &member, int default_value, const std::string &member_name) {
MS_EXCEPTION_IF_NULL(member);
if (member->isa<AbstractScalar>()) {
@ -957,8 +957,8 @@ int CheckSliceMember(const AbstractBasePtr& member, int default_value, const std
<< member->ToString();
}
void GenerateTupleSliceParameter(const AbstractTuplePtr& tuple, const AbstractSlicePtr& slice, int* start_index,
int* stop_index, int* step_value) {
void GenerateTupleSliceParameter(const AbstractTuplePtr &tuple, const AbstractSlicePtr &slice, int *start_index,
int *stop_index, int *step_value) {
MS_EXCEPTION_IF_NULL(tuple);
MS_EXCEPTION_IF_NULL(slice);
MS_EXCEPTION_IF_NULL(start_index);
@ -998,7 +998,7 @@ void GenerateTupleSliceParameter(const AbstractTuplePtr& tuple, const AbstractSl
}
}
FuncGraphPtr TupleSlice::GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) {
FuncGraphPtr TupleSlice::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
// slice a tuple
// args: tuple, start index, end index, step
const std::string op_name("TupleSlice");
@ -1032,7 +1032,7 @@ FuncGraphPtr TupleSlice::GenerateFuncGraph(const AbstractBasePtrList& args_spec_
return ret;
}
int ConvertBinaryToDecimal(const std::vector<unsigned int>& number_bin) {
int ConvertBinaryToDecimal(const std::vector<unsigned int> &number_bin) {
unsigned int number_dec = 0;
for (size_t index = 0; index < number_bin.size(); index++) {
number_dec |= number_bin[index] << index;
@ -1040,8 +1040,8 @@ int ConvertBinaryToDecimal(const std::vector<unsigned int>& number_bin) {
return static_cast<int>(number_dec);
}
void ParseSlice(const AbstractSlicePtr& slice, std::vector<int>* begin, std::vector<int>* end,
std::vector<int>* strides, int length) {
void ParseSlice(const AbstractSlicePtr &slice, std::vector<int> *begin, std::vector<int> *end,
std::vector<int> *strides, int length) {
MS_EXCEPTION_IF_NULL(slice);
MS_EXCEPTION_IF_NULL(begin);
MS_EXCEPTION_IF_NULL(end);
@ -1064,8 +1064,8 @@ void ParseSlice(const AbstractSlicePtr& slice, std::vector<int>* begin, std::vec
strides->push_back(step_value);
}
int GenerateStridedSliceParametersFromTuple(const AbstractTuplePtr& slice_tuple, const std::vector<int>& shape,
std::vector<int>* begin, std::vector<int>* end, std::vector<int>* strides) {
int GenerateStridedSliceParametersFromTuple(const AbstractTuplePtr &slice_tuple, const std::vector<int> &shape,
std::vector<int> *begin, std::vector<int> *end, std::vector<int> *strides) {
MS_EXCEPTION_IF_NULL(slice_tuple);
MS_EXCEPTION_IF_NULL(begin);
MS_EXCEPTION_IF_NULL(end);
@ -1111,8 +1111,8 @@ int GenerateStridedSliceParametersFromTuple(const AbstractTuplePtr& slice_tuple,
return ConvertBinaryToDecimal(shrink);
}
int GenerateStridedSliceParametersFromSlice(const AbstractSlicePtr& slice, const std::vector<int>& shape,
std::vector<int>* begin, std::vector<int>* end, std::vector<int>* strides) {
int GenerateStridedSliceParametersFromSlice(const AbstractSlicePtr &slice, const std::vector<int> &shape,
std::vector<int> *begin, std::vector<int> *end, std::vector<int> *strides) {
MS_EXCEPTION_IF_NULL(begin);
MS_EXCEPTION_IF_NULL(end);
MS_EXCEPTION_IF_NULL(strides);
@ -1132,9 +1132,9 @@ int GenerateStridedSliceParametersFromSlice(const AbstractSlicePtr& slice, const
return 0;
}
int GenerateStridedSliceParametersFromNumber(const AbstractScalarPtr& scalar, const std::vector<int>& shape,
std::vector<int>* begin, std::vector<int>* end,
std::vector<int>* strides) {
int GenerateStridedSliceParametersFromNumber(const AbstractScalarPtr &scalar, const std::vector<int> &shape,
std::vector<int> *begin, std::vector<int> *end,
std::vector<int> *strides) {
MS_EXCEPTION_IF_NULL(begin);
MS_EXCEPTION_IF_NULL(end);
MS_EXCEPTION_IF_NULL(strides);
@ -1153,7 +1153,7 @@ int GenerateStridedSliceParametersFromNumber(const AbstractScalarPtr& scalar, co
return 1;
}
FuncGraphPtr TensorSlice::GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) {
FuncGraphPtr TensorSlice::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
// slice a tensor
// args: tensor, slice or slice tuple
const std::string op_name = std::string("TensorSlice");
@ -1177,7 +1177,7 @@ FuncGraphPtr TensorSlice::GenerateFuncGraph(const AbstractBasePtrList& args_spec
shrink_axis_mask = GenerateStridedSliceParametersFromNumber(scalar_ptr, shape, &begin, &end, &strides);
} else {
std::ostringstream args_info;
for (const auto& arg : args_spec_list) {
for (const auto &arg : args_spec_list) {
MS_EXCEPTION_IF_NULL(arg);
args_info << arg->ToString() << "\n";
}
@ -1199,19 +1199,19 @@ FuncGraphPtr TensorSlice::GenerateFuncGraph(const AbstractBasePtrList& args_spec
return ret_graph;
}
REGISTER_PYBIND_DEFINE(
TupleAdd_, ([](const py::module* m) {
(void)py::class_<TupleAdd, MetaFuncGraph, std::shared_ptr<TupleAdd>>(*m, "TupleAdd_").def(py::init<std::string&>());
}));
REGISTER_PYBIND_DEFINE(TupleSlice_, ([](const py::module* m) {
(void)py::class_<TupleSlice, MetaFuncGraph, std::shared_ptr<TupleSlice>>(*m, "TupleSlice_")
.def(py::init<std::string&>());
REGISTER_PYBIND_DEFINE(TupleAdd_, ([](const py::module *m) {
(void)py::class_<TupleAdd, MetaFuncGraph, std::shared_ptr<TupleAdd>>(*m, "TupleAdd_")
.def(py::init<std::string &>());
}));
REGISTER_PYBIND_DEFINE(TensorSlice_, ([](const py::module* m) {
REGISTER_PYBIND_DEFINE(TupleSlice_, ([](const py::module *m) {
(void)py::class_<TupleSlice, MetaFuncGraph, std::shared_ptr<TupleSlice>>(*m, "TupleSlice_")
.def(py::init<std::string &>());
}));
REGISTER_PYBIND_DEFINE(TensorSlice_, ([](const py::module *m) {
(void)py::class_<TensorSlice, MetaFuncGraph, std::shared_ptr<TensorSlice>>(*m, "TensorSlice_")
.def(py::init<std::string&>());
.def(py::init<std::string &>());
}));
} // namespace prim
} // namespace mindspore

View File

@ -47,20 +47,20 @@ using ArgsPairList = std::vector<std::pair<AnfNodePtr, TypePtr>>;
class MultitypeFuncGraph : public MetaFuncGraph {
public:
explicit MultitypeFuncGraph(const std::string& name);
explicit MultitypeFuncGraph(const std::string &name);
~MultitypeFuncGraph() override = default;
MS_DECLARE_PARENT(MultitypeFuncGraph, MetaFuncGraph)
using specialize_fn = FuncGraph* (*)(TypePtrList);
using specialize_fn = FuncGraph *(*)(TypePtrList);
// Register a method which specialize based on types vectors;
virtual void Register(const TypePtrList& types, specialize_fn s_fn);
virtual void Register(const TypePtrList& types, const py::function& py_fn);
virtual void Register(const std::vector<std::string>& types_name, const py::function& py_fn);
virtual void PyRegister(const py::tuple& tuple, const py::function& py_fn);
virtual void Register(const TypePtrList &types, specialize_fn s_fn);
virtual void Register(const TypePtrList &types, const py::function &py_fn);
virtual void Register(const std::vector<std::string> &types_name, const py::function &py_fn);
virtual void PyRegister(const py::tuple &tuple, const py::function &py_fn);
FuncGraphPtr GenerateFromTypes(const TypePtrList& types) override;
FuncGraphPtr GenerateFromTypes(const TypePtrList &types) override;
size_t GetPyFnCacheSize() const { return fn_cache_py_.size(); }
const std::unordered_map<TypePtrList, py::function, TypeListHasher, TypeListEqual>& GetPyFunctions() const {
const std::unordered_map<TypePtrList, py::function, TypeListHasher, TypeListEqual> &GetPyFunctions() const {
return fn_cache_py_;
}
@ -72,10 +72,10 @@ using MultitypeFuncGraphPtr = std::shared_ptr<MultitypeFuncGraph>;
class HyperMap : public MetaFuncGraph {
public:
explicit HyperMap(const std::shared_ptr<MultitypeFuncGraph>& fn_leaf = nullptr);
HyperMap(const HyperMap& h);
explicit HyperMap(const std::shared_ptr<MultitypeFuncGraph> &fn_leaf = nullptr);
HyperMap(const HyperMap &h);
void Init();
HyperMap& operator=(const HyperMap& h) {
HyperMap &operator=(const HyperMap &h) {
if (this != &h) {
fn_leaf_ = h.fn_leaf_;
broadcast_ = h.broadcast_;
@ -89,21 +89,21 @@ class HyperMap : public MetaFuncGraph {
~HyperMap() override = default;
MS_DECLARE_PARENT(HyperMap, MetaFuncGraph)
abstract::AbstractBasePtrList NormalizeArgs(const abstract::AbstractBasePtrList& args_spec_list) const override;
FuncGraphPtr GenerateFromTypes(const TypePtrList& args_spec_list) override;
abstract::AbstractBasePtrList NormalizeArgs(const abstract::AbstractBasePtrList &args_spec_list) const override;
FuncGraphPtr GenerateFromTypes(const TypePtrList &args_spec_list) override;
MetaFuncGraphPtr GetFnLeaf() { return fn_leaf_; }
private:
AnfNodePtr FullMake(TypePtr type, const FuncGraphPtr& func_graph, const AnfNodePtr& fn_arg,
const ArgsPairList& arg_map);
AnfNodePtr FullMake(const std::shared_ptr<List>& type, const FuncGraphPtr& func_graph, const AnfNodePtr& fn_arg,
const ArgsPairList& arg_map);
AnfNodePtr FullMake(const std::shared_ptr<Tuple>& type, const FuncGraphPtr& func_graph, const AnfNodePtr& fn_arg,
const ArgsPairList& arg_map);
AnfNodePtr FullMake(const std::shared_ptr<Class>& type, const FuncGraphPtr& func_graph, const AnfNodePtr& fn_arg,
const ArgsPairList& arg_map);
AnfNodePtr Make(const FuncGraphPtr& graph, const AnfNodePtr& fn_arg, const ArgsPairList& arg_map);
ArgsPairList Harmonize(const FuncGraphPtr& graph, const ArgsPairList& args_spec_list);
AnfNodePtr FullMake(TypePtr type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg,
const ArgsPairList &arg_map);
AnfNodePtr FullMake(const std::shared_ptr<List> &type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg,
const ArgsPairList &arg_map);
AnfNodePtr FullMake(const std::shared_ptr<Tuple> &type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg,
const ArgsPairList &arg_map);
AnfNodePtr FullMake(const std::shared_ptr<Class> &type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg,
const ArgsPairList &arg_map);
AnfNodePtr Make(const FuncGraphPtr &graph, const AnfNodePtr &fn_arg, const ArgsPairList &arg_map);
ArgsPairList Harmonize(const FuncGraphPtr &graph, const ArgsPairList &args_spec_list);
MultitypeFuncGraphPtr fn_leaf_;
bool broadcast_;
@ -113,7 +113,7 @@ using HyperMapPtr = std::shared_ptr<HyperMap>;
class HyperMapPy : public HyperMap {
public:
explicit HyperMapPy(const std::shared_ptr<MultitypeFuncGraph>& fn_leaf = nullptr) : HyperMap(fn_leaf) {}
explicit HyperMapPy(const std::shared_ptr<MultitypeFuncGraph> &fn_leaf = nullptr) : HyperMap(fn_leaf) {}
~HyperMapPy() override = default;
MS_DECLARE_PARENT(HyperMapPy, HyperMap)
};
@ -123,56 +123,56 @@ extern ValuePtr kCompositeHyperMap;
class Tail : public MetaFuncGraph {
public:
explicit Tail(const std::string& name) : MetaFuncGraph(name) {}
explicit Tail(const std::string &name) : MetaFuncGraph(name) {}
~Tail() override = default;
MS_DECLARE_PARENT(Tail, MetaFuncGraph)
FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) override;
FuncGraphPtr GenerateTupleFuncGraph(const abstract::AbstractTuplePtr& a_tuple);
FuncGraphPtr GenerateListFuncGraph(const abstract::AbstractListPtr& a_list);
FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
FuncGraphPtr GenerateTupleFuncGraph(const abstract::AbstractTuplePtr &a_tuple);
FuncGraphPtr GenerateListFuncGraph(const abstract::AbstractListPtr &a_list);
friend bool operator==(const Tail& lhs, const Tail& rhs) { return lhs.name_ == rhs.name_; }
friend bool operator==(const Tail &lhs, const Tail &rhs) { return lhs.name_ == rhs.name_; }
};
using TailPtr = std::shared_ptr<Tail>;
class MakeTupleGradient : public MetaFuncGraph {
public:
explicit MakeTupleGradient(const std::string& name) : MetaFuncGraph(name) {}
explicit MakeTupleGradient(const std::string &name) : MetaFuncGraph(name) {}
~MakeTupleGradient() override = default;
MS_DECLARE_PARENT(MakeTupleGradient, MetaFuncGraph)
FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) override;
friend bool operator==(const MakeTupleGradient& lhs, const MakeTupleGradient& rhs) { return lhs.name_ == rhs.name_; }
FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
friend bool operator==(const MakeTupleGradient &lhs, const MakeTupleGradient &rhs) { return lhs.name_ == rhs.name_; }
};
using MakeTupleGradientPtr = std::shared_ptr<MakeTupleGradient>;
class GradOperation : public MetaFuncGraph {
public:
explicit GradOperation(const std::string& name, bool get_all = false, bool get_by_list = false,
explicit GradOperation(const std::string &name, bool get_all = false, bool get_by_list = false,
bool sens_param = false);
~GradOperation() override = default;
MS_DECLARE_PARENT(GradOperation, MetaFuncGraph)
FuncGraphPtr GetGrad(AnfNodePtr ptrNode, const AnfNodePtr& weights, const std::vector<AnfNodePtr>& ptrParams,
FuncGraphPtr GetGrad(AnfNodePtr ptrNode, const AnfNodePtr &weights, const std::vector<AnfNodePtr> &ptrParams,
bool applyJ = false);
FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) override;
FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
bool sens_param() const { return sens_param_; }
bool get_all_;
bool get_by_list_;
bool sens_param_;
private:
void doGetGrad(const FuncGraphPtr& func_graph, AnfNodePtr ptrOut, AnfNodePtr ptrBprop, AnfNodePtr weights,
void doGetGrad(const FuncGraphPtr &func_graph, AnfNodePtr ptrOut, AnfNodePtr ptrBprop, AnfNodePtr weights,
ValueNodePtr opsTupleItem);
};
using GradOperationPtr = std::shared_ptr<GradOperation>;
class ListMap {
public:
explicit ListMap(const std::string& name) : name_(name) { cache_.clear(); }
explicit ListMap(const std::string &name) : name_(name) { cache_.clear(); }
~ListMap() = default;
void MakeCond(const std::vector<AnfNodePtr>& lists, const FuncGraphPtr& gnext_ptr, const FuncGraphPtr& graph_ptr);
void MakeNext(const std::vector<AnfNodePtr>& lists, const FuncGraphPtr& gcond_ptr, const FuncGraphPtr& graph_ptr);
FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList& args_spec_list);
void MakeCond(const std::vector<AnfNodePtr> &lists, const FuncGraphPtr &gnext_ptr, const FuncGraphPtr &graph_ptr);
void MakeNext(const std::vector<AnfNodePtr> &lists, const FuncGraphPtr &gcond_ptr, const FuncGraphPtr &graph_ptr);
FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list);
private:
std::string name_;
@ -181,31 +181,31 @@ class ListMap {
class TupleAdd : public MetaFuncGraph {
public:
explicit TupleAdd(const std::string& name) : MetaFuncGraph(name) {}
explicit TupleAdd(const std::string &name) : MetaFuncGraph(name) {}
~TupleAdd() override = default;
MS_DECLARE_PARENT(TupleAdd, MetaFuncGraph)
FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) override;
friend bool operator==(const TupleAdd& lhs, const TupleAdd& rhs) { return lhs.name_ == rhs.name_; }
FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
friend bool operator==(const TupleAdd &lhs, const TupleAdd &rhs) { return lhs.name_ == rhs.name_; }
};
using TupleAddPtr = std::shared_ptr<TupleAdd>;
class TupleSlice : public MetaFuncGraph {
public:
explicit TupleSlice(const std::string& name) : MetaFuncGraph(name) {}
explicit TupleSlice(const std::string &name) : MetaFuncGraph(name) {}
~TupleSlice() override = default;
MS_DECLARE_PARENT(TupleSlice, MetaFuncGraph)
FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) override;
friend bool operator==(const TupleSlice& lhs, const TupleSlice& rhs) { return lhs.name_ == rhs.name_; }
FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
friend bool operator==(const TupleSlice &lhs, const TupleSlice &rhs) { return lhs.name_ == rhs.name_; }
};
using TupleSlicePtr = std::shared_ptr<TupleSlice>;
class TensorSlice : public MetaFuncGraph {
public:
explicit TensorSlice(const std::string& name) : MetaFuncGraph(name) {}
explicit TensorSlice(const std::string &name) : MetaFuncGraph(name) {}
~TensorSlice() override = default;
MS_DECLARE_PARENT(TensorSlice, MetaFuncGraph)
FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) override;
friend bool operator==(const TensorSlice& lhs, const TensorSlice& rhs) { return lhs.name_ == rhs.name_; }
FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
friend bool operator==(const TensorSlice &lhs, const TensorSlice &rhs) { return lhs.name_ == rhs.name_; }
};
using TensorSlicePtr = std::shared_ptr<TensorSlice>;

View File

@ -34,7 +34,7 @@ namespace prim {
namespace {
using PatternListType = std::initializer_list<BaseRef>;
const std::vector<Signature>& GetSignature(const ValuePtr& function) {
const std::vector<Signature> &GetSignature(const ValuePtr &function) {
static const auto empty = std::vector<Signature>();
if (function->isa<Primitive>()) {
return function->cast<PrimitivePtr>()->signatures();
@ -44,8 +44,8 @@ const std::vector<Signature>& GetSignature(const ValuePtr& function) {
return empty;
}
void ProcessDefault(const std::string& func_name, const AbstractBasePtrList& args_spec_list,
const std::vector<Signature>& signature, bool has_var, std::vector<AnfNodePtr>* op_inputs) {
void ProcessDefault(const std::string &func_name, const AbstractBasePtrList &args_spec_list,
const std::vector<Signature> &signature, bool has_var, std::vector<AnfNodePtr> *op_inputs) {
std::size_t sig_size = signature.size();
auto positional_size = sig_size;
if (has_var) {
@ -64,8 +64,8 @@ void ProcessDefault(const std::string& func_name, const AbstractBasePtrList& arg
}
// Get the largest type of index in the same SignatureEnumDType of arguments.
std::map<SignatureEnumDType, size_t> GetMaxDtypeIndex(const std::vector<SignatureEnumDType>& dtypes,
const abstract::AbstractBasePtrList& args_spec_list) {
std::map<SignatureEnumDType, size_t> GetMaxDtypeIndex(const std::vector<SignatureEnumDType> &dtypes,
const abstract::AbstractBasePtrList &args_spec_list) {
// record index for signature.dtypes of the same type
// eg. [T, T1, T, T2, T, T1, T3] -> {{T:(0,2,4)}, {T1:(1,5)}, {T2:(3)}, {T3:(6)}}
std::map<SignatureEnumDType, std::vector<size_t>> type_indexs;
@ -89,7 +89,7 @@ std::map<SignatureEnumDType, size_t> GetMaxDtypeIndex(const std::vector<Signatur
continue;
}
for (const auto& index : indexs) {
for (const auto &index : indexs) {
AbstractBasePtr arg_value = args_spec_list[index];
if (arg_value->isa<abstract::AbstractRef>()) {
arg_value = arg_value->cast<abstract::AbstractRefPtr>()->ref();
@ -104,7 +104,7 @@ std::map<SignatureEnumDType, size_t> GetMaxDtypeIndex(const std::vector<Signatur
return dst_type;
}
AnfNodePtr DoCast(const AnfNodePtr& param, const AnfNodePtr& source_param, const FuncGraphPtr& graph) {
AnfNodePtr DoCast(const AnfNodePtr &param, const AnfNodePtr &source_param, const FuncGraphPtr &graph) {
// op and module import path
auto prim_dtype = prim::GetPythonOps("dtype", "mindspore.ops.functional");
MS_EXCEPTION_IF_NULL(prim_dtype);
@ -116,11 +116,11 @@ AnfNodePtr DoCast(const AnfNodePtr& param, const AnfNodePtr& source_param, const
return NewCNode({cast_node, param, dtype_node}, graph);
}
void DoAutoCast(const std::vector<Signature>& signature, const abstract::AbstractBasePtrList& args_spec_list,
const FuncGraphPtr& graph, std::vector<AnfNodePtr>* op_inputs) {
void DoAutoCast(const std::vector<Signature> &signature, const abstract::AbstractBasePtrList &args_spec_list,
const FuncGraphPtr &graph, std::vector<AnfNodePtr> *op_inputs) {
std::vector<SignatureEnumDType> dtypes;
(void)std::transform(signature.begin(), signature.end(), std::back_inserter(dtypes),
[](const Signature& sig) { return sig.dtype; });
[](const Signature &sig) { return sig.dtype; });
int empty_dtype_count = std::count(dtypes.begin(), dtypes.end(), SignatureEnumDType::kDTypeEmptyDefaultValue);
if (dtypes.empty() || static_cast<int>(dtypes.size()) == empty_dtype_count) {
return;
@ -143,10 +143,10 @@ void DoAutoCast(const std::vector<Signature>& signature, const abstract::Abstrac
}
}
AnfNodePtr BuildNewCNode(const FuncGraphPtr& func_graph, const std::string& func_name, const ValuePtr& function,
const AbstractBasePtrList& args_spec_list, const std::vector<AnfNodePtr>& params_list) {
AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func_name, const ValuePtr &function,
const AbstractBasePtrList &args_spec_list, const std::vector<AnfNodePtr> &params_list) {
// args: original inputs
auto& signature = GetSignature(function);
auto &signature = GetSignature(function);
std::size_t sig_size = signature.size();
auto has_var = (sig_size > 0 && signature[sig_size - 1].kind == SignatureEnumKind::kKindVarPositional);
if (sig_size > 0) {
@ -196,13 +196,13 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr& func_graph, const std::string& func
}
} // namespace
AnfNodePtr GenerateCNode(const FuncGraphPtr& func_graph, const std::string& func_name, const ValuePtr& function,
const AbstractBasePtrList& args_spec_list, const AnfNodePtrList& old_node_inputs) {
AnfNodePtr GenerateCNode(const FuncGraphPtr &func_graph, const std::string &func_name, const ValuePtr &function,
const AbstractBasePtrList &args_spec_list, const AnfNodePtrList &old_node_inputs) {
auto new_cnode = BuildNewCNode(func_graph, func_name, function, args_spec_list, old_node_inputs);
return new_cnode;
}
FuncGraphPtr DoSignatureMetaFuncGraph::GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) {
FuncGraphPtr DoSignatureMetaFuncGraph::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
FuncGraphPtr func_graph = std::make_shared<FuncGraph>();
for (size_t i = 0; i < args_spec_list.size(); ++i) {

View File

@ -37,17 +37,17 @@ namespace mindspore {
namespace prim {
class DoSignatureMetaFuncGraph : public MetaFuncGraph {
public:
explicit DoSignatureMetaFuncGraph(const std::string& name, const ValuePtr& function)
explicit DoSignatureMetaFuncGraph(const std::string &name, const ValuePtr &function)
: MetaFuncGraph("S-" + name), function_(function) {}
~DoSignatureMetaFuncGraph() override = default;
MS_DECLARE_PARENT(DoSignatureMetaFuncGraph, MetaFuncGraph)
FuncGraphPtr GenerateFuncGraph(const abstract::AbstractBasePtrList& args_spec_list) override;
FuncGraphPtr GenerateFuncGraph(const abstract::AbstractBasePtrList &args_spec_list) override;
const ValuePtr function() const { return function_; }
friend bool operator==(const DoSignatureMetaFuncGraph& lhs, const DoSignatureMetaFuncGraph& rhs) {
friend bool operator==(const DoSignatureMetaFuncGraph &lhs, const DoSignatureMetaFuncGraph &rhs) {
return &lhs == &rhs;
}
@ -56,8 +56,8 @@ class DoSignatureMetaFuncGraph : public MetaFuncGraph {
};
using RWSignaturePtr = std::shared_ptr<DoSignatureMetaFuncGraph>;
AnfNodePtr GenerateCNode(const FuncGraphPtr& func_graph, const std::string& func_name, const ValuePtr& function,
const AbstractBasePtrList& args_spec_list, const AnfNodePtrList& old_node_inputs);
AnfNodePtr GenerateCNode(const FuncGraphPtr &func_graph, const std::string &func_name, const ValuePtr &function,
const AbstractBasePtrList &args_spec_list, const AnfNodePtrList &old_node_inputs);
} // namespace prim
} // namespace mindspore

View File

@ -27,7 +27,7 @@
namespace mindspore {
// namespace to support composite operators definition
namespace prim {
FuncGraphPtr ListAppend::GenerateFuncGraph(const abstract::AbstractBasePtrList& args_list) {
FuncGraphPtr ListAppend::GenerateFuncGraph(const abstract::AbstractBasePtrList &args_list) {
abstract::CheckArgsSize("ListAppend", args_list, 2);
AbstractBasePtr arg0 = args_list[0];
@ -52,9 +52,9 @@ FuncGraphPtr ListAppend::GenerateFuncGraph(const abstract::AbstractBasePtrList&
return ret;
}
REGISTER_PYBIND_DEFINE(ListAppend_, ([](const py::module* m) {
REGISTER_PYBIND_DEFINE(ListAppend_, ([](const py::module *m) {
(void)py::class_<ListAppend, MetaFuncGraph, std::shared_ptr<ListAppend>>(*m, "ListAppend_")
.def(py::init<std::string&>());
.def(py::init<std::string &>());
}));
} // namespace prim
} // namespace mindspore

View File

@ -28,15 +28,15 @@ namespace mindspore {
namespace prim {
class ListAppend : public MetaFuncGraph {
public:
explicit ListAppend(const std::string& name) : MetaFuncGraph(name) {}
explicit ListAppend(const std::string &name) : MetaFuncGraph(name) {}
~ListAppend() override = default;
MS_DECLARE_PARENT(ListAppend, MetaFuncGraph)
FuncGraphPtr GenerateFuncGraph(const abstract::AbstractBasePtrList& a_list) override;
friend std::ostream& operator<<(std::ostream& os, const ListAppend& list_append) {
FuncGraphPtr GenerateFuncGraph(const abstract::AbstractBasePtrList &a_list) override;
friend std::ostream &operator<<(std::ostream &os, const ListAppend &list_append) {
os << list_append.name_;
return os;
}
friend bool operator==(const ListAppend& lhs, const ListAppend& rhs) { return lhs.name_ == rhs.name_; }
friend bool operator==(const ListAppend &lhs, const ListAppend &rhs) { return lhs.name_ == rhs.name_; }
};
using ListAppendPtr = std::shared_ptr<ListAppend>;
} // namespace prim

View File

@ -40,7 +40,7 @@ using mindspore::abstract::AbstractKeywordArg;
using mindspore::abstract::AbstractTuple;
using mindspore::abstract::AbstractTuplePtr;
FuncGraphPtr UnpackCall::GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) {
FuncGraphPtr UnpackCall::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
// slice a tensor
// args: tensor, slice or slice tuple
const std::string op_name = std::string("UnpackCall");
@ -70,7 +70,7 @@ FuncGraphPtr UnpackCall::GenerateFuncGraph(const AbstractBasePtrList& args_spec_
AnfNodePtr para_dict = ret_graph->add_parameter();
auto dict_elems = arg_dict->elements();
(void)std::transform(dict_elems.begin(), dict_elems.end(), std::back_inserter(elems),
[ret_graph, para_dict](const AbstractAttribute& item) {
[ret_graph, para_dict](const AbstractAttribute &item) {
auto dict_get_item = ret_graph->NewCNode(
{NewValueNode(prim::kPrimDictGetItem), para_dict, NewValueNode(item.first)});
return ret_graph->NewCNode(
@ -85,9 +85,9 @@ FuncGraphPtr UnpackCall::GenerateFuncGraph(const AbstractBasePtrList& args_spec_
return ret_graph;
}
REGISTER_PYBIND_DEFINE(UnpackCall_, ([](const py::module* m) {
REGISTER_PYBIND_DEFINE(UnpackCall_, ([](const py::module *m) {
(void)py::class_<UnpackCall, MetaFuncGraph, std::shared_ptr<UnpackCall>>(*m, "UnpackCall_")
.def(py::init<std::string&>());
.def(py::init<std::string &>());
}));
} // namespace prim

View File

@ -40,11 +40,11 @@ namespace prim {
// and generate positional parameters and key-value pairs for function.
class UnpackCall : public MetaFuncGraph {
public:
explicit UnpackCall(const std::string& name) : MetaFuncGraph(name) {}
explicit UnpackCall(const std::string &name) : MetaFuncGraph(name) {}
~UnpackCall() override = default;
MS_DECLARE_PARENT(UnpackCall, MetaFuncGraph)
FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) override;
friend bool operator==(const UnpackCall& lhs, const UnpackCall& rhs) { return lhs.name_ == rhs.name_; }
FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
friend bool operator==(const UnpackCall &lhs, const UnpackCall &rhs) { return lhs.name_ == rhs.name_; }
};
using UnpackCallPtr = std::shared_ptr<UnpackCall>;

View File

@ -36,7 +36,7 @@ namespace prim {
using mindspore::abstract::AbstractBase;
using mindspore::abstract::AbstractTuple;
FuncGraphPtr ZipOperation::GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) {
FuncGraphPtr ZipOperation::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
// zip operation:
// input: tuple arguments
// output: tuple of items of input iterated on every input
@ -44,7 +44,7 @@ FuncGraphPtr ZipOperation::GenerateFuncGraph(const AbstractBasePtrList& args_spe
MS_LOG(EXCEPTION) << "zip arguments input should not be empty";
}
auto is_all_tuple = std::all_of(args_spec_list.begin(), args_spec_list.end(), [](const AbstractBasePtr& abs) -> bool {
auto is_all_tuple = std::all_of(args_spec_list.begin(), args_spec_list.end(), [](const AbstractBasePtr &abs) -> bool {
MS_EXCEPTION_IF_NULL(abs);
return abs->isa<AbstractTuple>();
});
@ -53,7 +53,7 @@ FuncGraphPtr ZipOperation::GenerateFuncGraph(const AbstractBasePtrList& args_spe
}
auto min_abs = std::min_element(args_spec_list.begin(), args_spec_list.end(),
[](const AbstractBasePtr& x, const AbstractBasePtr& y) {
[](const AbstractBasePtr &x, const AbstractBasePtr &y) {
return (x->cast<AbstractTuplePtr>()->size() < y->cast<AbstractTuplePtr>()->size());
});
FuncGraphPtr ret_graph = std::make_shared<FuncGraph>();
@ -81,10 +81,10 @@ FuncGraphPtr ZipOperation::GenerateFuncGraph(const AbstractBasePtrList& args_spe
return ret_graph;
}
REGISTER_PYBIND_DEFINE(ZipOperation_, ([](const py::module* m) {
REGISTER_PYBIND_DEFINE(ZipOperation_, ([](const py::module *m) {
(void)py::class_<ZipOperation, MetaFuncGraph, std::shared_ptr<ZipOperation>>(*m,
"ZipOperation_")
.def(py::init<std::string&>());
.def(py::init<std::string &>());
}));
} // namespace prim
} // namespace mindspore

View File

@ -42,15 +42,15 @@ using AbstractTuplePtr = abstract::AbstractTuplePtr;
class ZipOperation : public MetaFuncGraph {
public:
explicit ZipOperation(const std::string& name) : MetaFuncGraph(name) {}
explicit ZipOperation(const std::string &name) : MetaFuncGraph(name) {}
~ZipOperation() override = default;
MS_DECLARE_PARENT(ZipOperation, MetaFuncGraph)
FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) override;
friend std::ostream& operator<<(std::ostream& os, const ZipOperation& op) {
FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
friend std::ostream &operator<<(std::ostream &os, const ZipOperation &op) {
os << op.name_;
return os;
}
friend bool operator==(const ZipOperation& lhs, const ZipOperation& rhs) { return lhs.name_ == rhs.name_; }
friend bool operator==(const ZipOperation &lhs, const ZipOperation &rhs) { return lhs.name_ == rhs.name_; }
};
using ZipOperationPtr = std::shared_ptr<ZipOperation>;
} // namespace prim

View File

@ -238,7 +238,7 @@ const PrimitivePtr kPrimImageSummary = std::make_shared<Primitive>("ImageSummary
const PrimitivePtr kPrimTensorSummary = std::make_shared<Primitive>("TensorSummary");
const PrimitivePtr kPrimHistogramSummary = std::make_shared<Primitive>("HistogramSummary");
ValuePtr GetPythonOps(const std::string& op_name, const std::string& module_name) {
ValuePtr GetPythonOps(const std::string &op_name, const std::string &module_name) {
py::object obj = parse::python_adapter::GetPyFn(module_name, op_name);
ValuePtr node = nullptr;
bool succ = parse::ConvertData(obj, &node);

View File

@ -26,8 +26,8 @@
namespace mindspore {
// namespace to support primitive operators
namespace prim {
ValuePtr GetPythonOps(const std::string& op_name,
const std::string& module_name = "mindspore._extends.parse.standard_method");
ValuePtr GetPythonOps(const std::string &op_name,
const std::string &module_name = "mindspore._extends.parse.standard_method");
// Arithmetic
extern const PrimitivePtr kPrimScalarAdd;
@ -241,7 +241,7 @@ extern const PrimitivePtr kPrimVirtualDataset;
class DoSignaturePrimitive : public Primitive {
public:
explicit DoSignaturePrimitive(const std::string& name, const ValuePtr& function)
explicit DoSignaturePrimitive(const std::string &name, const ValuePtr &function)
: Primitive("S-Prim-" + name), function_(function) {}
~DoSignaturePrimitive() override = default;
@ -257,7 +257,7 @@ using DoSignaturePrimitivePtr = std::shared_ptr<DoSignaturePrimitive>;
class UnpackGraphPrimitive : public Primitive {
public:
explicit UnpackGraphPrimitive(const std::string& name, const bool& with_sens, const bool& need_unpack_args)
explicit UnpackGraphPrimitive(const std::string &name, const bool &with_sens, const bool &need_unpack_args)
: Primitive("UnpackGraph"), with_sens_in_args_(with_sens), need_unpack_args_(need_unpack_args) {}
~UnpackGraphPrimitive() override = default;
MS_DECLARE_PARENT(UnpackGraphPrimitive, Primitive)

View File

@ -54,7 +54,7 @@ PrimToFunction::PrimToFunction()
{"scalar_sub", kPrimTypeTwoArgs},
{"scalar_floordiv", kPrimTypeTwoArgs}}) {}
bool PrimToFunction::GetFunction(const PrimitivePtr& prim, FunctionPtr* const func) const {
bool PrimToFunction::GetFunction(const PrimitivePtr &prim, FunctionPtr *const func) const {
bool result = false;
if (func != nullptr) {
@ -79,7 +79,7 @@ bool PrimToFunction::GetFunction(const PrimitivePtr& prim, FunctionPtr* const fu
return result;
}
int PrimToFunction::GetPrimType(const PrimitivePtr& prim) const {
int PrimToFunction::GetPrimType(const PrimitivePtr &prim) const {
MS_EXCEPTION_IF_NULL(prim);
int prim_type = static_cast<int>(kPrimTypeUnknown);

View File

@ -41,21 +41,21 @@ class PrimToFunction;
class PrimToFunction {
public:
// Return a thread-safe singleton instance
static PrimToFunction& GetInstance() {
static PrimToFunction &GetInstance() {
static PrimToFunction instance;
return instance;
}
PrimToFunction(const PrimToFunction&) = delete;
PrimToFunction& operator=(const PrimToFunction&) = delete;
PrimToFunction(const PrimToFunction &) = delete;
PrimToFunction &operator=(const PrimToFunction &) = delete;
~PrimToFunction() = default;
// Get the args and return value for a primitive instance.
bool GetFunction(const PrimitivePtr& prim, FunctionPtr* func) const;
bool GetFunction(const PrimitivePtr &prim, FunctionPtr *func) const;
private:
PrimToFunction();
// Get the number of primitive arguments
int GetPrimType(const PrimitivePtr& prim) const;
int GetPrimType(const PrimitivePtr &prim) const;
const std::unordered_map<std::string, int> prim_func_type_map_;
};
} // namespace prim

View File

@ -24,7 +24,7 @@
namespace mindspore {
namespace ad {
Adjoint::Adjoint(const AnfNodePtr& primal, const AnfNodePtr& k, const FuncGraphPtr& caller)
Adjoint::Adjoint(const AnfNodePtr &primal, const AnfNodePtr &k, const FuncGraphPtr &caller)
: primal_(primal), caller_(caller), dout_(nullptr) {
if (k != nullptr) {
k_ = k;
@ -43,13 +43,13 @@ Adjoint::Adjoint(const AnfNodePtr& primal, const AnfNodePtr& k, const FuncGraphP
AnfNodePtr Adjoint::k() { return k_; }
void Adjoint::RegisterKUser(const CNodePtr& user, size_t index) { k_user_.emplace_back(std::make_pair(user, index)); }
void Adjoint::RegisterKUser(const CNodePtr &user, size_t index) { k_user_.emplace_back(std::make_pair(user, index)); }
void Adjoint::UpdateK(const AnfNodePtr& new_k) {
void Adjoint::UpdateK(const AnfNodePtr &new_k) {
MS_EXCEPTION_IF_NULL(new_k);
MS_LOG(DEBUG) << "Replace k " << k_->ToString() << " with " << new_k->ToString();
// In recursive case, it needs update.
for (auto& user : k_user_) {
for (auto &user : k_user_) {
MS_LOG(DEBUG) << "Update k user " << user.first->ToString() << " " << user.second << " input with new_k"
<< new_k->ToString();
if (user.first->input(user.second) != k_) {
@ -65,11 +65,11 @@ AnfNodePtr Adjoint::primal() { return primal_; }
AnfNodePtr Adjoint::dout() { return dout_hole_; }
void Adjoint::RegisterDoutUser(const CNodePtr& user, size_t index) {
void Adjoint::RegisterDoutUser(const CNodePtr &user, size_t index) {
dout_user_.emplace_back(std::make_pair(user, index));
}
void Adjoint::AccumulateDout(const AnfNodePtr& dout_factor) {
void Adjoint::AccumulateDout(const AnfNodePtr &dout_factor) {
if (dout_ != nullptr) {
MS_LOG(DEBUG) << "Update dout " << dout_->ToString() << " with dout_factor " << dout_factor->ToString();
auto add = prim::GetPythonOps("hyper_add");
@ -81,7 +81,7 @@ void Adjoint::AccumulateDout(const AnfNodePtr& dout_factor) {
void Adjoint::CallDoutHole() {
if (dout_ != nullptr) {
for (auto& user : dout_user_) {
for (auto &user : dout_user_) {
MS_LOG(DEBUG) << "Update dout user " << user.first->ToString() << " " << user.second << " input with dout "
<< dout_->ToString();
if (user.first->input(user.second) != dout_hole_) {

View File

@ -28,15 +28,15 @@ namespace mindspore {
namespace ad {
class Adjoint {
public:
Adjoint(const AnfNodePtr& primal, const AnfNodePtr& k, const FuncGraphPtr& caller);
Adjoint(const AnfNodePtr &primal, const AnfNodePtr &k, const FuncGraphPtr &caller);
~Adjoint() = default;
AnfNodePtr primal();
AnfNodePtr k();
void UpdateK(const AnfNodePtr& k);
void RegisterKUser(const CNodePtr& user, size_t index);
void UpdateK(const AnfNodePtr &k);
void RegisterKUser(const CNodePtr &user, size_t index);
AnfNodePtr dout();
void AccumulateDout(const AnfNodePtr& dout_factor);
void RegisterDoutUser(const CNodePtr& user, size_t index);
void AccumulateDout(const AnfNodePtr &dout_factor);
void RegisterDoutUser(const CNodePtr &user, size_t index);
void CallDoutHole();
private:

View File

@ -36,7 +36,7 @@ using mindspore::abstract::AbstractList;
using mindspore::abstract::AbstractScalar;
using mindspore::abstract::AbstractTuple;
static AbstractBasePtr Reabs(const AbstractBasePtr& t) {
static AbstractBasePtr Reabs(const AbstractBasePtr &t) {
if (t == nullptr) {
return nullptr;
}
@ -47,14 +47,14 @@ static AbstractBasePtr Reabs(const AbstractBasePtr& t) {
AbstractBasePtrList baselist;
auto attributes = abs_class->attributes();
(void)std::transform(attributes.begin(), attributes.end(), std::back_inserter(baselist),
[](const AbstractAttribute& item) { return item.second; });
[](const AbstractAttribute &item) { return item.second; });
res = std::make_shared<AbstractTuple>(baselist);
} else if (t->isa<AbstractDictionary>()) {
auto abs_dict = dyn_cast<AbstractDictionary>(t);
AbstractBasePtrList baselist;
auto elements = abs_dict->elements();
(void)std::transform(elements.begin(), elements.end(), std::back_inserter(baselist),
[](const AbstractAttribute& item) { return item.second; });
[](const AbstractAttribute &item) { return item.second; });
res = std::make_shared<AbstractTuple>(baselist);
} else if (t->isa<AbstractList>()) {
auto abs_dict = dyn_cast<AbstractList>(t);
@ -63,11 +63,11 @@ static AbstractBasePtr Reabs(const AbstractBasePtr& t) {
return res;
}
AnfNodePtr ConvertGetAttrToTupleGetItem(const CNodePtr& node) {
AnfNodePtr ConvertGetAttrToTupleGetItem(const CNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(node->func_graph());
const auto& inputs = node->inputs();
const auto &inputs = node->inputs();
// Inputs should be [getattr, data, attribute]
MS_ASSERT(inputs.size() == 3 && "GetAttr should have three inputs.");
@ -86,9 +86,9 @@ AnfNodePtr ConvertGetAttrToTupleGetItem(const CNodePtr& node) {
auto cons_str = cons_is_str ? GetValue<std::string>(GetValueNode(cons)) : "";
auto ct = dyn_cast<AbstractClass>(dt);
const auto& cmap = ct->attributes();
const auto &cmap = ct->attributes();
int count = 0;
for (auto& item : cmap) {
for (auto &item : cmap) {
if (cons_is_str && item.first == cons_str) {
break;
}
@ -102,12 +102,12 @@ AnfNodePtr ConvertGetAttrToTupleGetItem(const CNodePtr& node) {
return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleGetItem), data, idx_c});
}
AnfNodePtr ConvertDictGetItemToTupleGetItem(const CNodePtr& node) {
AnfNodePtr ConvertDictGetItemToTupleGetItem(const CNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(node->func_graph());
// Inputs should be [dict_getitem, dict, item]
const auto& inputs = node->inputs();
const auto &inputs = node->inputs();
MS_ASSERT(inputs.size() == 3 && "DictGetItem should have three inputs.");
AnfNodePtr data = inputs[1];
@ -124,9 +124,9 @@ AnfNodePtr ConvertDictGetItemToTupleGetItem(const CNodePtr& node) {
auto cons_str = cons_is_str ? GetValue<std::string>(GetValueNode(cons)) : "";
auto ct = dyn_cast<abstract::AbstractDictionary>(dt);
const auto& cmap = ct->elements();
const auto &cmap = ct->elements();
int count = 0;
for (auto& item : cmap) {
for (auto &item : cmap) {
if (cons_is_str && item.first == cons_str) {
break;
}
@ -139,7 +139,7 @@ AnfNodePtr ConvertDictGetItemToTupleGetItem(const CNodePtr& node) {
return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleGetItem), data, idx_c});
}
AnfNodePtr ConvertMakeRecordToMakeTuple(const CNodePtr& node) {
AnfNodePtr ConvertMakeRecordToMakeTuple(const CNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(node->func_graph());
@ -150,11 +150,11 @@ AnfNodePtr ConvertMakeRecordToMakeTuple(const CNodePtr& node) {
return node->func_graph()->NewCNode(inputs);
}
AnfNodePtr ErasePartialNode(const CNodePtr& node) {
AnfNodePtr ErasePartialNode(const CNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(node->func_graph());
const auto& inputs = node->inputs();
const auto &inputs = node->inputs();
// Inputs should be [partial, fn, arg1, ...], so offset by 2 to get arg;
MS_ASSERT(inputs.size() >= 2 && "Partial should have more than two inputs.");
@ -178,7 +178,7 @@ AnfNodePtr ErasePartialNode(const CNodePtr& node) {
return nullptr;
}
AnfNodePtr ConvertMakeListToMakeTuple(const CNodePtr& node) {
AnfNodePtr ConvertMakeListToMakeTuple(const CNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(node->func_graph());
@ -189,11 +189,11 @@ AnfNodePtr ConvertMakeListToMakeTuple(const CNodePtr& node) {
return node->func_graph()->NewCNode(inputs);
}
AnfNodePtr ConvertListGetItemToTupleGetItem(const CNodePtr& node) {
AnfNodePtr ConvertListGetItemToTupleGetItem(const CNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(node->func_graph());
const auto& inputs = node->inputs();
const auto &inputs = node->inputs();
// Inputs should be [list_getitem, list, item]
if (inputs.size() < 3) {
MS_LOG(EXCEPTION) << "Node's input number < 3.";
@ -208,11 +208,11 @@ AnfNodePtr ConvertListGetItemToTupleGetItem(const CNodePtr& node) {
return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleGetItem), data, cons_node});
}
AnfNodePtr ConvertListSetItemToTupleSetItem(const CNodePtr& node) {
AnfNodePtr ConvertListSetItemToTupleSetItem(const CNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(node->func_graph());
const auto& inputs = node->inputs();
const auto &inputs = node->inputs();
// Inputs should be [list_setitem, list, index, item]
if (inputs.size() < 4) {
MS_LOG(EXCEPTION) << "Node's input number < 4.";
@ -225,36 +225,36 @@ AnfNodePtr ConvertListSetItemToTupleSetItem(const CNodePtr& node) {
return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleSetItem), data, cons, value});
}
AnfNodePtr EraseMakeDictNode(const CNodePtr& node) {
AnfNodePtr EraseMakeDictNode(const CNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
const auto& inputs = node->inputs();
const auto &inputs = node->inputs();
MS_ASSERT(inputs.size() >= 3 && "MakeDict should have three inputs");
return inputs[2];
}
AnfNodePtr EraseMakeKeywordArgNode(const CNodePtr& node) {
AnfNodePtr EraseMakeKeywordArgNode(const CNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
const auto& inputs = node->inputs();
const auto &inputs = node->inputs();
// Inputs should be [make_keyword_arg, key, value]
MS_ASSERT(inputs.size() == 3 && "MakeKeyword should have three inputs");
return inputs[2];
}
AnfNodePtr EraseExtractKeywordArg(const CNodePtr& node) {
AnfNodePtr EraseExtractKeywordArg(const CNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
const auto& inputs = node->inputs();
const auto &inputs = node->inputs();
// Inputs should be [extract_keyword_arg, arg, key]
MS_ASSERT(inputs.size() == 3 && "ExtractKeyword should have three inputs");
return inputs[2];
}
ValueTuplePtr ConvertValueListToValueTuple(const ValueListPtr& value_list, int depth) {
ValueTuplePtr ConvertValueListToValueTuple(const ValueListPtr &value_list, int depth) {
const int DEPTH_MAX = 5;
if (depth > DEPTH_MAX) {
MS_LOG(EXCEPTION) << "List nesting is not allowed more than 5 levels.";
}
std::vector<ValuePtr> elements;
for (const auto& it : value_list->value()) {
for (const auto &it : value_list->value()) {
ValuePtr value = nullptr;
if (it->isa<ValueList>()) {
value = ConvertValueListToValueTuple(it->cast<ValueListPtr>(), depth + 1);
@ -266,7 +266,7 @@ ValueTuplePtr ConvertValueListToValueTuple(const ValueListPtr& value_list, int d
return std::make_shared<ValueTuple>(elements);
}
AnfNodePtr ConvertValueListNodeToValueTupleNode(const ValueNodePtr& node) {
AnfNodePtr ConvertValueListNodeToValueTupleNode(const ValueNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
ValuePtr value = node->value();
auto value_list = value->cast<ValueListPtr>();
@ -278,13 +278,13 @@ AnfNodePtr ConvertValueListNodeToValueTupleNode(const ValueNodePtr& node) {
// Convert class to Tuple
// Convert getattr to getitem
// Convert make_record to make_tuple
void SimplifyDataStructures(const FuncGraphPtr& root, const FuncGraphManagerPtr& manager) {
void SimplifyDataStructures(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager) {
MS_EXCEPTION_IF_NULL(manager);
manager->AddFuncGraph(root);
// Since `manager->Replace(...);` will modify member `all_nodes_`, so `all_node` can't be a ref var
AnfNodeSet all_node = manager->all_nodes();
for (auto& node : all_node) {
for (auto &node : all_node) {
MS_EXCEPTION_IF_NULL(node);
auto cnode = node->cast<CNodePtr>();
AnfNodePtr new_node = nullptr;
@ -320,20 +320,20 @@ void SimplifyDataStructures(const FuncGraphPtr& root, const FuncGraphManagerPtr&
}
}
for (auto& node : manager->all_nodes()) {
for (auto &node : manager->all_nodes()) {
auto ret = Reabs(node->abstract());
node->set_abstract(ret);
}
}
// expand tuples in graph parameters
static std::vector<AnfNodePtr> ExpandTuplesP(const FuncGraphManagerPtr& mng, const FuncGraphPtr& func_graph,
const std::vector<AnfNodePtr>& params) {
static std::vector<AnfNodePtr> ExpandTuplesP(const FuncGraphManagerPtr &mng, const FuncGraphPtr &func_graph,
const std::vector<AnfNodePtr> &params) {
MS_EXCEPTION_IF_NULL(mng);
MS_EXCEPTION_IF_NULL(func_graph);
std::vector<AnfNodePtr> new_params;
for (const auto& param : params) {
for (const auto &param : params) {
MS_EXCEPTION_IF_NULL(param);
auto param_abs = param->abstract();
MS_EXCEPTION_IF_NULL(param_abs);
@ -350,7 +350,7 @@ static std::vector<AnfNodePtr> ExpandTuplesP(const FuncGraphManagerPtr& mng, con
std::vector<AnfNodePtr> new_param;
std::vector<AnfNodePtr> inputs{NewValueNode(prim::kPrimMakeTuple)};
auto abs_tuple = dyn_cast<AbstractTuple>(param_abs);
for (auto& elem : abs_tuple->elements()) {
for (auto &elem : abs_tuple->elements()) {
auto np = std::make_shared<Parameter>(func_graph);
np->set_abstract(elem);
new_param.emplace_back(np);
@ -366,11 +366,11 @@ static std::vector<AnfNodePtr> ExpandTuplesP(const FuncGraphManagerPtr& mng, con
}
// expand tuples in graph applies
static std::vector<AnfNodePtr> ExpandTuplesC(const FuncGraphPtr& graph, const std::vector<AnfNodePtr>& inputs) {
static std::vector<AnfNodePtr> ExpandTuplesC(const FuncGraphPtr &graph, const std::vector<AnfNodePtr> &inputs) {
MS_EXCEPTION_IF_NULL(graph);
std::vector<AnfNodePtr> new_inputs;
for (const auto& input : inputs) {
for (const auto &input : inputs) {
MS_EXCEPTION_IF_NULL(input);
auto input_abs = input->abstract();
@ -391,7 +391,7 @@ static std::vector<AnfNodePtr> ExpandTuplesC(const FuncGraphPtr& graph, const st
int idx = 0;
std::vector<AnfNodePtr> new_input;
auto abs_tuple = dyn_cast<AbstractTuple>(input_abs);
for (auto& elem : abs_tuple->elements()) {
for (auto &elem : abs_tuple->elements()) {
auto c_node = graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), input, NewValueNode(idx)});
AbstractBasePtr aptr = std::make_shared<AbstractScalar>(std::make_shared<Int32Imm>(idx));
c_node->input(2)->set_abstract(aptr);
@ -416,19 +416,19 @@ static std::vector<AnfNodePtr> ExpandTuplesC(const FuncGraphPtr& graph, const st
// tuples in Graph's parameters: AbstractTuple (a, b, c) -->
// CNode("make_tuple", Parameter(a), Parameter(b), Parameter(c))
// cppcheck-suppress unusedFunction
void EraseTuple(const FuncGraphPtr& root, const FuncGraphManagerPtr& manager) {
void EraseTuple(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager) {
MS_EXCEPTION_IF_NULL(manager);
manager->AddFuncGraph(root);
// NOTICE: since `manager->Replace(...);` will modify member `all_nodes_`, so `all_node` can't be a ref var
AnfNodeSet all_node = manager->all_nodes();
for (auto& node : all_node) {
for (auto &node : all_node) {
auto cnode = node->cast<CNodePtr>();
if (cnode == nullptr) {
continue;
}
const auto& inputs = cnode->inputs();
const auto &inputs = cnode->inputs();
// Bypass the first input in inputs as it's fn.
if (!IsValueNode<Primitive>(inputs[0])) {
@ -466,7 +466,7 @@ void EraseTuple(const FuncGraphPtr& root, const FuncGraphManagerPtr& manager) {
}
FuncGraphSet all_graph = manager->func_graphs();
for (auto& func_graph : all_graph) {
for (auto &func_graph : all_graph) {
MS_EXCEPTION_IF_NULL(func_graph);
auto expand_p = ExpandTuplesP(manager, func_graph, func_graph->parameters());
manager->SetParameters(func_graph, expand_p);

View File

@ -22,7 +22,7 @@
namespace mindspore {
namespace opt {
// Automatically adding control depend based on effect order and side effect analysis.
void AddControlDepend(const FuncGraphPtr& graph);
void AddControlDepend(const FuncGraphPtr &graph);
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_OPTIMIZER_CONTROL_DEPEND_H_

View File

@ -44,7 +44,7 @@ static AnfNodePtr GenerateUnpackGraphNode(std::vector<AnfNodePtr> inputs_y, Func
nodes.push_back(func_node);
// {unpackcall, {GradOperation, ...}, args...}
std::transform(inputs_y.begin() + 2, inputs_y.end(), std::back_inserter(nodes),
[](const AnfNodePtr& node) { return node; });
[](const AnfNodePtr &node) { return node; });
unpack_graph_node = func_graph->NewCNode(nodes);
} else {
auto unpack_graph = std::make_shared<prim::UnpackGraphPrimitive>("unpack_graph", sens_param, false);
@ -52,14 +52,14 @@ static AnfNodePtr GenerateUnpackGraphNode(std::vector<AnfNodePtr> inputs_y, Func
nodes.push_back(func_node);
// {{GradOperation, ...}, args...}
std::transform(inputs_y.begin() + 1, inputs_y.end(), std::back_inserter(nodes),
[](const AnfNodePtr& node) { return node; });
[](const AnfNodePtr &node) { return node; });
unpack_graph_node = func_graph->NewCNode(nodes);
}
return unpack_graph_node;
}
// get metagraph of value node
MetaFuncGraphPtr GetMetaFuncGraphOfValueNode(const AnfNodePtr& node) {
MetaFuncGraphPtr GetMetaFuncGraphOfValueNode(const AnfNodePtr &node) {
ValuePtr value;
if (IsValueNode<prim::DoSignaturePrimitive>(node)) {
value = GetValueNode(node)->cast<prim::DoSignaturePrimitivePtr>()->function();
@ -73,7 +73,7 @@ MetaFuncGraphPtr GetMetaFuncGraphOfValueNode(const AnfNodePtr& node) {
}
// check if node is a specific metafuncgraph op
bool IsMetaFuncGraph(const AnfNodePtr& node, const MetaFuncGraphPtr meta_func_graph) {
bool IsMetaFuncGraph(const AnfNodePtr &node, const MetaFuncGraphPtr meta_func_graph) {
if (node != nullptr) {
auto meta_func_graph_ptr = GetMetaFuncGraphOfValueNode(node);
if (meta_func_graph_ptr == nullptr) {
@ -89,7 +89,7 @@ bool IsMetaFuncGraph(const AnfNodePtr& node, const MetaFuncGraphPtr meta_func_gr
// {{GradOperation, g, w}, Ys}
// {UnPackCall, {GradOperation, g, w}, Ys}
AnfNodePtr GradVarPrepare::operator()(const OptimizerPtr&, const AnfNodePtr& node) {
AnfNodePtr GradVarPrepare::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
if (!node->isa<CNode>() || node->func_graph() == nullptr) {
return nullptr;
}

View File

@ -31,20 +31,20 @@
namespace mindspore {
/* namespace to support opt */
namespace opt {
SubstitutionPtr MakeSubstitution(const TransformFuncType& transform, const std::string& name, const PrimitivePtr& prim,
const RenormAction& renorm_action) {
auto fn = [prim](const AnfNodePtr& node) -> bool { return IsPrimitiveCNode(node, prim); };
SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name, const PrimitivePtr &prim,
const RenormAction &renorm_action) {
auto fn = [prim](const AnfNodePtr &node) -> bool { return IsPrimitiveCNode(node, prim); };
return std::make_shared<Substitution>(transform, name, fn, renorm_action);
}
SubstitutionPtr MakeSubstitution(const TransformFuncType& transform, const std::string& name,
const std::vector<PrimitivePtr>& prims, const RenormAction& renorm_action) {
auto fn = [prims](const AnfNodePtr& node) -> bool {
SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name,
const std::vector<PrimitivePtr> &prims, const RenormAction &renorm_action) {
auto fn = [prims](const AnfNodePtr &node) -> bool {
if (!node->isa<CNode>()) {
return false;
}
for (auto& prim : prims) {
for (auto &prim : prims) {
if (IsPrimitiveCNode(node, prim)) {
return true;
}
@ -55,12 +55,12 @@ SubstitutionPtr MakeSubstitution(const TransformFuncType& transform, const std::
return std::make_shared<Substitution>(transform, name, fn, renorm_action);
}
SubstitutionPtr MakeSubstitution(const TransformFuncType& transform, const std::string& name,
const PredicateFuncType& predicate, const RenormAction& renorm_action) {
SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name,
const PredicateFuncType &predicate, const RenormAction &renorm_action) {
return std::make_shared<Substitution>(transform, name, predicate, renorm_action);
}
AnfNodePtr Substitution::operator()(const OptimizerPtr& optimizer, const AnfNodePtr& node) const {
AnfNodePtr Substitution::operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) const {
#ifdef ENABLE_PROFILE
double t = GetTime();
#endif
@ -88,8 +88,8 @@ AnfNodePtr Substitution::operator()(const OptimizerPtr& optimizer, const AnfNode
return result;
}
bool SubstitutionList::ApplyTransform(const OptimizerPtr& optimizer, const AnfNodePtr& root_node,
const SubstitutionPtr& transform) const {
bool SubstitutionList::ApplyTransform(const OptimizerPtr &optimizer, const AnfNodePtr &root_node,
const SubstitutionPtr &transform) const {
FuncGraphManagerPtr manager = optimizer->manager();
std::unordered_set<AnfNodePtr> seen_node;
std::deque<AnfNodePtr> todo{root_node};
@ -131,13 +131,13 @@ bool SubstitutionList::ApplyTransform(const OptimizerPtr& optimizer, const AnfNo
}
if (node->isa<CNode>()) {
auto& inputs = node->cast<CNodePtr>()->inputs();
auto &inputs = node->cast<CNodePtr>()->inputs();
(void)std::copy(inputs.begin(), inputs.end(), std::back_inserter(todo));
}
auto& node_users = manager->node_users();
auto &node_users = manager->node_users();
if (change && node_users.find(node) != node_users.end()) {
for (auto& use : node_users[node]) {
for (auto &use : node_users[node]) {
auto use_node = use.first;
todo.push_back(use_node);
if (seen_node.find(use_node) != seen_node.end()) {
@ -152,7 +152,7 @@ bool SubstitutionList::ApplyTransform(const OptimizerPtr& optimizer, const AnfNo
return changes;
}
bool SubstitutionList::operator()(const FuncGraphPtr& func_graph, const OptimizerPtr& optimizer) const {
bool SubstitutionList::operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer) const {
MS_EXCEPTION_IF_NULL(optimizer);
MS_EXCEPTION_IF_NULL(func_graph);
FuncGraphManagerPtr manager = optimizer->manager();
@ -163,7 +163,7 @@ bool SubstitutionList::operator()(const FuncGraphPtr& func_graph, const Optimize
do {
loop = false;
for (auto const& transform : list_) {
for (auto const &transform : list_) {
auto change = ApplyTransform(optimizer, func_graph->output(), transform);
changes = changes || change;
loop = loop || change;

View File

@ -28,7 +28,7 @@
namespace mindspore {
namespace parallel {
std::unordered_set<CNodePtr> FindCNodesWithPara(const AnfNodePtr& para, uint32_t recursive_times = 0) {
std::unordered_set<CNodePtr> FindCNodesWithPara(const AnfNodePtr &para, uint32_t recursive_times = 0) {
if (recursive_times > MAX_RECURSIVE_CALL_TIMES) {
MS_LOG(EXCEPTION) << "FindCNodesWithPara exceeds max recursive call times! Max recursive call times is "
<< MAX_RECURSIVE_CALL_TIMES;
@ -39,7 +39,7 @@ std::unordered_set<CNodePtr> FindCNodesWithPara(const AnfNodePtr& para, uint32_t
MS_EXCEPTION_IF_NULL(manager);
auto node_set = manager->node_users()[para];
std::unordered_set<CNodePtr> cnode_set;
for (auto& node_pair : node_set) {
for (auto &node_pair : node_set) {
auto cnode = node_pair.first->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (!IsValueNode<Primitive>(cnode->input(0))) {
@ -54,7 +54,7 @@ std::unordered_set<CNodePtr> FindCNodesWithPara(const AnfNodePtr& para, uint32_t
(void)cnode_set.emplace(cnode);
} else {
auto cnode_set_sub = FindCNodesWithPara(node_pair.first, recursive_times + 1);
for (auto& cnode_sub : cnode_set_sub) {
for (auto &cnode_sub : cnode_set_sub) {
(void)cnode_set.emplace(cnode_sub);
}
}
@ -63,8 +63,8 @@ std::unordered_set<CNodePtr> FindCNodesWithPara(const AnfNodePtr& para, uint32_t
}
Status AllreduceFusion::AddNodeToGraph() {
const auto& parameters = root_graph_->parameters();
for (auto& parameter : parameters) {
const auto &parameters = root_graph_->parameters();
for (auto &parameter : parameters) {
if (!ParameterRequireGrad(parameter)) {
continue;
}
@ -72,7 +72,7 @@ Status AllreduceFusion::AddNodeToGraph() {
if (cnode_set.empty()) {
continue;
}
for (auto& cnode : cnode_set) {
for (auto &cnode : cnode_set) {
MS_LOG(DEBUG) << "AddNode " << cnode->DebugString();
if (allreduce_graph_.AddNode(cnode, parameter) != SUCCESS) {
MS_LOG(ERROR) << "AddNode failed! cnode: " << cnode->DebugString();
@ -83,7 +83,7 @@ Status AllreduceFusion::AddNodeToGraph() {
return SUCCESS;
}
CNodeCostMap AllreduceFusion::FindCNode(const AnfNodePtr& from, uint32_t recursive_times) const {
CNodeCostMap AllreduceFusion::FindCNode(const AnfNodePtr &from, uint32_t recursive_times) const {
if (recursive_times > MAX_RECURSIVE_CALL_TIMES) {
MS_LOG(EXCEPTION) << "FindCNode exceeds max recursive call times! Max recursive call times is "
<< MAX_RECURSIVE_CALL_TIMES;
@ -110,30 +110,30 @@ CNodeCostMap AllreduceFusion::FindCNode(const AnfNodePtr& from, uint32_t recursi
return cnode_dist;
} else {
auto cnode_dist_next = FindNextCNodes(cnode, recursive_times + 1);
for (auto& ele : cnode_dist_next) {
for (auto &ele : cnode_dist_next) {
cnode_dist[ele.first] = cost + ele.second;
}
}
} else {
auto cnode_dist_next = FindNextCNodes(cnode);
for (auto& ele : cnode_dist_next) {
for (auto &ele : cnode_dist_next) {
cnode_dist[ele.first] = ele.second;
}
}
return cnode_dist;
}
CNodeCostMap AllreduceFusion::FindNextCNodes(const CNodePtr& from, uint32_t recursive_times) const {
CNodeCostMap AllreduceFusion::FindNextCNodes(const CNodePtr &from, uint32_t recursive_times) const {
if (recursive_times > MAX_RECURSIVE_CALL_TIMES) {
MS_LOG(EXCEPTION) << "FindNextCNodes exceeds max recursive call times! Max recursive call times is "
<< MAX_RECURSIVE_CALL_TIMES;
}
const auto& from_inputs = from->inputs();
const auto &from_inputs = from->inputs();
std::unordered_map<CNodePtr, double> dist_map;
MS_LOG(DEBUG) << "from cnode " << from->DebugString() << " has " << from_inputs.size() << " inputs";
for (auto& input_node : from_inputs) {
for (auto &input_node : from_inputs) {
auto cnode_dist = FindCNode(input_node, recursive_times + 1);
for (auto& ele : cnode_dist) {
for (auto &ele : cnode_dist) {
(void)dist_map.emplace(ele);
}
}
@ -142,11 +142,11 @@ CNodeCostMap AllreduceFusion::FindNextCNodes(const CNodePtr& from, uint32_t recu
Status AllreduceFusion::AddEdgeToGraph() {
std::unordered_map<CNodePtr, int32_t> cnode_state_map;
const auto& cnodes = allreduce_graph_.cnode_set();
for (auto& cnode : cnodes) {
const auto &cnodes = allreduce_graph_.cnode_set();
for (auto &cnode : cnodes) {
cnode_state_map[cnode] = 0;
}
const auto& head_cnode = allreduce_graph_.head_cnode();
const auto &head_cnode = allreduce_graph_.head_cnode();
std::queue<CNodePtr> cnode_queue;
cnode_queue.emplace(head_cnode);
cnode_state_map[head_cnode] = 1;
@ -156,9 +156,9 @@ Status AllreduceFusion::AddEdgeToGraph() {
cnode_queue.pop();
cnode_state_map[cur_cnode] = 2;
auto next = FindNextCNodes(cur_cnode);
for (auto& ele : next) {
auto& cnode = ele.first;
auto& dist = ele.second;
for (auto &ele : next) {
auto &cnode = ele.first;
auto &dist = ele.second;
if (cnode_state_map[cnode] == 0) {
cnode_queue.emplace(cnode);
cnode_state_map[cnode] = 1;
@ -173,7 +173,7 @@ Status AllreduceFusion::AddEdgeToGraph() {
return SUCCESS;
}
std::vector<CNodePtr> FindMirror(const AnfNodePtr& para, uint32_t recursive_times = 0) {
std::vector<CNodePtr> FindMirror(const AnfNodePtr &para, uint32_t recursive_times = 0) {
if (recursive_times > MAX_RECURSIVE_CALL_TIMES) {
MS_LOG(EXCEPTION) << "FindMirror exceeds max recursive call times! Max recursive call times is "
<< MAX_RECURSIVE_CALL_TIMES;
@ -184,7 +184,7 @@ std::vector<CNodePtr> FindMirror(const AnfNodePtr& para, uint32_t recursive_time
MS_EXCEPTION_IF_NULL(manager);
AnfNodeIndexSet node_set = manager->node_users()[para];
std::vector<CNodePtr> cnode_list;
for (auto& node_pair : node_set) {
for (auto &node_pair : node_set) {
auto cnode = node_pair.first->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (!IsValueNode<Primitive>(cnode->input(0))) {
@ -210,7 +210,7 @@ std::vector<CNodePtr> FindMirror(const AnfNodePtr& para, uint32_t recursive_time
return cnode_list;
}
void SetMirrorFusion(const CNodePtr& mirror_cnode, int32_t fusion, const std::string& parameter_name) {
void SetMirrorFusion(const CNodePtr &mirror_cnode, int32_t fusion, const std::string &parameter_name) {
MS_EXCEPTION_IF_NULL(mirror_cnode);
MS_LOG(DEBUG) << "Set Mirror " << mirror_cnode->DebugString() << " fusion " << fusion;
auto node_prim = GetValueNode<PrimitivePtr>(mirror_cnode->input(0));
@ -227,14 +227,14 @@ void SetMirrorFusion(const CNodePtr& mirror_cnode, int32_t fusion, const std::st
(void)node_prim->AddAttr(PARAMETER, MakeValue(std::make_shared<StringImm>(parameter_name)));
}
Status FindMirrorAndSetFusion(const AnfNodePtr& para, int32_t fusion) {
Status FindMirrorAndSetFusion(const AnfNodePtr &para, int32_t fusion) {
auto mirror_cnodes = FindMirror(para);
if (mirror_cnodes.empty()) {
MS_LOG(WARNING) << para->ToString() << " 0 Mirror CNode found.";
return SUCCESS;
}
if (mirror_cnodes.size() > 2) {
for (auto& mirror_cnode : mirror_cnodes) {
for (auto &mirror_cnode : mirror_cnodes) {
MS_EXCEPTION_IF_NULL(mirror_cnode);
MS_LOG(INFO) << mirror_cnode->DebugString();
}
@ -243,15 +243,15 @@ Status FindMirrorAndSetFusion(const AnfNodePtr& para, int32_t fusion) {
<< "Mirror CNode found.";
return FAILED;
}
for (auto& mirror_cnode : mirror_cnodes) {
for (auto &mirror_cnode : mirror_cnodes) {
auto parameter_name = ParameterName(para);
SetMirrorFusion(mirror_cnode, fusion, parameter_name);
}
return SUCCESS;
}
Status FindMirrorAndSetFusion(const std::vector<AnfNodePtr>& paras, int32_t fusion) {
for (auto& param_node : paras) {
Status FindMirrorAndSetFusion(const std::vector<AnfNodePtr> &paras, int32_t fusion) {
for (auto &param_node : paras) {
if (FindMirrorAndSetFusion(param_node, fusion) != SUCCESS) {
MS_LOG(ERROR) << "FindMirrorAndSetFusion failed";
return FAILED;
@ -260,7 +260,7 @@ Status FindMirrorAndSetFusion(const std::vector<AnfNodePtr>& paras, int32_t fusi
return SUCCESS;
}
Status AllreduceFusion::SetFusion(const std::vector<double>& cost_map) {
Status AllreduceFusion::SetFusion(const std::vector<double> &cost_map) {
if (cost_map.size() < 2) {
MS_LOG(ERROR) << "cost_map must has at least 2 items, cost_map size is " << cost_map.size();
return FAILED;
@ -386,7 +386,7 @@ Status AllreduceFusion::SetFusionByAlgorithm(int32_t algorithm) {
return SetFusionByBackwardCompAndAllreduceTime();
}
Status AllreduceFusion::ProcessAllreduceFusion(const CNodePtr& ret) {
Status AllreduceFusion::ProcessAllreduceFusion(const CNodePtr &ret) {
if (ret == nullptr) {
MS_LOG(ERROR) << "ret is nullptr.";
return FAILED;

View File

@ -50,15 +50,15 @@ class AllreduceFusion {
allreduce_bandwidth_(0),
computation_time_parameter_(0) {}
virtual ~AllreduceFusion() = default;
Status ProcessAllreduceFusion(const CNodePtr& ret);
Status ProcessAllreduceFusion(const CNodePtr &ret);
private:
Status AddNodeToGraph();
CNodeCostMap FindCNode(const AnfNodePtr& from, uint32_t recursive_times = 0) const;
CNodeCostMap FindNextCNodes(const CNodePtr& from, uint32_t recursive_times = 0) const;
CNodeCostMap FindCNode(const AnfNodePtr &from, uint32_t recursive_times = 0) const;
CNodeCostMap FindNextCNodes(const CNodePtr &from, uint32_t recursive_times = 0) const;
Status AddEdgeToGraph();
std::vector<double> GenerateCostMap(int32_t fusion_times, double tail_percent) const;
Status SetFusion(const std::vector<double>& cost_map);
Status SetFusion(const std::vector<double> &cost_map);
Status SetFusionByAlgorithm(int32_t algorithm);
Status SetFusionByBackwardCompTime();
Status SetFusionByBackwardCompAndAllreduceTime();

View File

@ -23,7 +23,7 @@
namespace mindspore {
namespace parallel {
Status AllreduceGraph::AddNode(const CNodePtr& node, const AnfNodePtr& para) {
Status AllreduceGraph::AddNode(const CNodePtr &node, const AnfNodePtr &para) {
AllreduceNodePtr arnode;
auto cnode_emplace_return = cnode_set_.emplace(node);
if (!cnode_emplace_return.second) {
@ -64,7 +64,7 @@ Status AllreduceGraph::AddNode(const CNodePtr& node, const AnfNodePtr& para) {
return SUCCESS;
}
Status AllreduceGraph::AddEdge(const CNodePtr& from, const CNodePtr& to, double dist) {
Status AllreduceGraph::AddEdge(const CNodePtr &from, const CNodePtr &to, double dist) {
auto from_arnode_iter = cnode_arnode_map_.find(from);
if (from_arnode_iter == cnode_arnode_map_.end()) {
MS_LOG(ERROR) << "cnode from: " << from->DebugString() << "has not been added";
@ -94,14 +94,14 @@ Status AllreduceGraph::AddEdge(const CNodePtr& from, const CNodePtr& to, double
return SUCCESS;
}
bool AllreduceGraph::NodeInGraph(const CNodePtr& node) const {
bool AllreduceGraph::NodeInGraph(const CNodePtr &node) const {
auto cnode_iter = cnode_set_.find(node);
return !(cnode_iter == cnode_set_.end());
}
std::vector<AnfNodePtr> AllreduceGraph::GetParaByCost(double from, double to) {
std::vector<AnfNodePtr> nodes;
for (auto& cnode_arnode : cnode_arnode_map_) {
for (auto &cnode_arnode : cnode_arnode_map_) {
MS_LOG(DEBUG) << "cnode: " << cnode_arnode.first->DebugString()
<< ", depend_feat_size: " << cnode_arnode.second->depend_feat_size()
<< " curr_para_size: " << cnode_arnode.second->curr_para_size();
@ -117,7 +117,7 @@ std::pair<std::vector<AnfNodePtr>, double> AllreduceGraph::GetParaByParaSize(dou
std::vector<AnfNodePtr> nodes;
double cur_para_size = 0;
double from = to;
for (auto& arnode : arnode_vec_) {
for (auto &arnode : arnode_vec_) {
if (arnode.depend_feat_size() != max_ && arnode.depend_feat_size() >= to) {
continue;
}
@ -135,14 +135,14 @@ std::pair<std::vector<AnfNodePtr>, double> AllreduceGraph::GetParaByParaSize(dou
void AllreduceGraph::PrintCNodeSet() const {
MS_LOG(INFO) << "CNodeSet:";
for (auto& cnode : cnode_set_) {
for (auto &cnode : cnode_set_) {
MS_LOG(INFO) << cnode->DebugString();
}
}
void AllreduceGraph::PrintAllredueGraphInfo() const {
MS_LOG(INFO) << "max: " << max_;
for (auto& cnode_arnode : cnode_arnode_map_) {
for (auto &cnode_arnode : cnode_arnode_map_) {
MS_LOG(INFO) << "cnode: " << cnode_arnode.first->DebugString();
MS_LOG(INFO) << "arnode info: ";
cnode_arnode.second->ToString();
@ -151,21 +151,21 @@ void AllreduceGraph::PrintAllredueGraphInfo() const {
void AllreduceGraph::PrintArnodeVec() const {
MS_LOG(INFO) << "ArnodeVec:";
for (auto& arnode : arnode_vec_) {
for (auto &arnode : arnode_vec_) {
arnode.ToString();
}
}
void AllreduceGraph::PrintArnodeSet() const {
MS_LOG(INFO) << "ArnodeSet:";
for (auto& arnode : arnode_set_) {
for (auto &arnode : arnode_set_) {
arnode->ToString();
}
}
void AllreduceGraph::SortArnode() {
arnode_vec_.clear();
for (auto& node : arnode_set_) {
for (auto &node : arnode_set_) {
arnode_vec_.emplace_back(*node);
}
std::sort(arnode_vec_.begin(), arnode_vec_.end(), std::greater<>());
@ -173,8 +173,8 @@ void AllreduceGraph::SortArnode() {
Status AllreduceGraph::RemoveExtraParas() {
std::unordered_set<AnfNodePtr> para_map;
for (auto& node : arnode_vec_) {
for (auto& para : node.paras()) {
for (auto &node : arnode_vec_) {
for (auto &para : node.paras()) {
auto emplac_result = para_map.emplace(para);
if (!emplac_result.second) {
MS_LOG(DEBUG) << "parameter: " << para->fullname_with_scope() << "in arnode";
@ -188,7 +188,7 @@ Status AllreduceGraph::RemoveExtraParas() {
return SUCCESS;
}
Status AllreduceGraph::set_head_cnode(const CNodePtr& node) {
Status AllreduceGraph::set_head_cnode(const CNodePtr &node) {
auto arnode = std::make_shared<AllreduceNode>(AllreduceNode());
if (arnode->Init(node) != SUCCESS) {
MS_LOG(ERROR) << "AllreduceNode Init failed";

View File

@ -42,9 +42,9 @@ class AllreduceGraph {
cnode_arnode_map_(),
max_(0) {}
virtual ~AllreduceGraph() = default;
Status AddNode(const CNodePtr& node, const AnfNodePtr& para);
Status AddEdge(const CNodePtr& from, const CNodePtr& to, double dist);
bool NodeInGraph(const CNodePtr& node) const;
Status AddNode(const CNodePtr &node, const AnfNodePtr &para);
Status AddEdge(const CNodePtr &from, const CNodePtr &to, double dist);
bool NodeInGraph(const CNodePtr &node) const;
std::vector<AnfNodePtr> GetParaByCost(double from, double to);
// Find the first several AllreduceNode whose depend_feat_size is less than to, the sum of whose parameter size is
// over para_size.
@ -60,9 +60,9 @@ class AllreduceGraph {
void PrintAllredueGraphInfo() const;
void PrintArnodeVec() const;
void PrintArnodeSet() const;
const std::unordered_set<CNodePtr>& cnode_set() const { return cnode_set_; }
const std::unordered_set<CNodePtr> &cnode_set() const { return cnode_set_; }
CNodePtr head_cnode() const { return head_cnode_; }
Status set_head_cnode(const CNodePtr& node);
Status set_head_cnode(const CNodePtr &node);
double max() const { return max_; }
private:

Some files were not shown because too many files have changed in this diff Show More